[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu 2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@ -504,62 +504,72 @@ public:
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc);
DG_HOST_ASSERT(status == cudaSuccess);
}
};
template <typename T>
static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m) {
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_K, BLOCK_M, BLOCK_K);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_k, block_m, block_k);
}
template <typename T>
static CUtensorMap make_2d_tma_b_desc(T* global_address) {
return make_2d_tma_desc(global_address, Layout::ColMajor,
SHAPE_K, SHAPE_N * (kGemmType != GemmType::Normal ? kNumGroups : 1), BLOCK_K, BLOCK_N);
}
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_b_desc(T *global_address, uint32_t shape_k, uint32_t shape_n, uint32_t block_k, uint32_t block_n, uint32_t num_groups = 1) {
return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k,
shape_n * (kGemmType != GemmType::Normal ? num_groups : 1),
block_k, block_n);
}
template <typename T>
static CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m) {
template <typename T, GemmType kGemmType, uint32_t kSwizzleDMode>
static CUtensorMap make_2d_tma_d_desc(T *global_address, uint32_t shape_m, uint32_t shape_n, uint32_t block_m, uint32_t block_n, uint32_t num_groups = 1) {
auto swizzle_mode = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
if constexpr (kSwizzleDMode == 32) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128) swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
if constexpr (kSwizzleDMode == 32)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_32B;
if constexpr (kSwizzleDMode == 64)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_64B;
if constexpr (kSwizzleDMode == 128)
swizzle_mode = CU_TENSOR_MAP_SWIZZLE_128B;
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode` bytes
// So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), SHAPE_N,
BLOCK_M, kSwizzleDMode == 0 ? BLOCK_N : kSwizzleDMode / sizeof(T),
swizzle_mode);
}
// Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
// bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(
global_address, Layout::RowMajor,
shape_m * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
shape_n, block_m,
kSwizzleDMode == 0 ? block_n : kSwizzleDMode / sizeof(T), swizzle_mode);
}
template <typename T>
static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m) {
template <typename T, GemmType kGemmType>
static CUtensorMap make_2d_tma_scales_a_desc(T *global_address, uint32_t shape_m, uint32_t shape_k, uint32_t block_m, uint32_t block_k, uint32_t num_groups = 1) {
// Make TMA aligned to 16 bytes
constexpr uint32_t kAlignment = 16 / sizeof(T);
shape_m = ceil_div(shape_m, kAlignment) * kAlignment;
return make_2d_tma_desc(global_address, Layout::ColMajor,
shape_m, ceil_div(SHAPE_K, BLOCK_K) * (kGemmType == GemmType::GroupedMasked ? kNumGroups : 1), BLOCK_M, 1,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
return make_2d_tma_desc(
global_address, Layout::ColMajor, shape_m,
ceil_div(shape_k, block_k) * (kGemmType == GemmType::GroupedMasked ? num_groups : 1),
block_m, 1, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
}
template <typename T>
static CUtensorMap make_2d_tma_desc(
T* global_address, Layout layout,
uint32_t gmem_rows, uint32_t gmem_cols,
uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
template <typename T>
static CUtensorMap
make_2d_tma_desc(T *global_address, Layout layout, uint32_t gmem_rows,
uint32_t gmem_cols, uint32_t smem_rows, uint32_t smem_cols,
CUtensorMapSwizzle swizzle_type =
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) {
if (layout == Layout::RowMajor) {
uint64_t gmem_dim[2] = {gmem_cols, gmem_rows};
uint32_t smem_dim[2] = {smem_cols, smem_rows};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_cols * sizeof(T), smem_dim, swizzle_type);
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_cols * sizeof(T), smem_dim, swizzle_type);
} else {
uint64_t gmem_dim[2] = {gmem_rows, gmem_cols};
uint32_t smem_dim[2] = {smem_rows, smem_cols};
return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, swizzle_type);
return make_2d_tma_copy_desc(global_address, gmem_dim,
gmem_rows * sizeof(T), smem_dim, swizzle_type);
}
}
};
}
}; // namespace deep_gemm

View File

@ -1,6 +1,8 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cuda.h>
#endif
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>

View File

@ -0,0 +1,69 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef NVRTC_JIT_COMPILATION
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;
namespace std
{
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
using value_type = T;
using type = integral_constant; // using injected-class-name
__device__ constexpr operator value_type() const noexcept
{
return value;
}
__device__ constexpr value_type operator()() const noexcept
{
return value;
} // since c++14
};
using false_type = integral_constant<bool, false>;
using true_type = integral_constant<bool, true>;
template <class T, class U>
struct is_same : false_type
{
};
template <class T>
struct is_same<T, T> : true_type
{
};
template <class T, class U>
inline constexpr bool is_same_v = is_same<T, U>::value;
} // namespace std
#endif

View File

@ -1,6 +1,9 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <cassert>
#endif
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_fp8.h>

View File

@ -1,5 +1,6 @@
#pragma once
#ifndef NVRTC_JIT_COMPILATION
#include <exception>
#ifdef __CLION_IDE__
@ -16,8 +17,12 @@ public:
const char *what() const noexcept override { return message.c_str(); }
};
#endif
#ifndef DG_HOST_ASSERT
#ifdef NVRTC_JIT_COMPILATION
#define DG_HOST_ASSERT(cond) ((void)0)
#else
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
@ -27,6 +32,7 @@ do { \
} \
} while (0)
#endif
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \

View File

@ -1,3 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .template import generate
from .runtime import Runtime

View File

@ -3,13 +3,13 @@ import functools
import os
import re
import subprocess
import time
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
@ -94,11 +94,11 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
# Compile into a temporary SO file
so_path = f'{path}/kernel.so'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
# Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
interleave_ffma.process(tmp_cubin_path)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Atomic replace CU file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)

View File

@ -1,16 +1,18 @@
import ctypes
import os
import time
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from typing import Optional
from .template import map_ctype
from .utils import run_gemm
class Runtime:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.args = None
self.kernel = None
assert self.is_path_valid(self.path)
@ -21,29 +23,66 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.args', 'kernel.so']
files = ['kernel.cu', 'kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
self.args = eval(f.read())
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN
if self.lib is None:
start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
# Check args and launch
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
if self.kernel is not None:
self.lib = lib
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
cargs.append(map_ctype(arg))
raise Exception("Failed to find fp8 gemm kernel")
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
self.kernel,
kwargs['NUM_TMA_MULTICAST'],
kwargs['M'],
kwargs['BLOCK_M'],
kwargs['GMEM_D'],
kwargs['SCALES_B'],
kwargs['GROUPED_LAYOUT'],
kwargs['NUM_SMS'],
kwargs['SMEM_SIZE'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
kwargs['STREAM'],
)
def __del__(self) -> None:
if self.lib is not None:
res = cuda.cuLibraryUnload(self.lib)[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to unload library {self.path}: {res}")
class RuntimeCache:

View File

@ -1,111 +1,48 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
#include <deep_gemm/nvrtc_std.cuh>
#else
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
#include <string>
#include <cuda.h>
#endif
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
namespace deep_gemm {{
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
}}
'''
# Debug print
if os.getenv('DG_JIT_DEBUG', None):

164
deep_gemm/jit/utils.py Normal file
View File

@ -0,0 +1,164 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple
import cuda.bindings.driver as cuda
import torch
class Layout(Enum):
RowMajor = 0
ColMajor = 1
class GemmType(Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
def __str__(self) -> str:
return {
0: 'Normal',
1: 'GroupedContiguous',
2: 'GroupedMasked',
}[self.value]
typename_map: Dict[Any, str] = {
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
swizzle_map = {
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}
def get_num_math_warpgroups(block_m: int) -> int:
return 1 if block_m == 64 else 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
tensor_dtype = typename_map[global_address.dtype]
res, tensor_map = cuda.cuTensorMapEncodeTiled(
tensor_dtype,
2, # tensor rank
global_address.data_ptr(),
gmem_dim,
(stride_in_bytes,), # global strides
smem_dim,
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to encode tensor map: {res}")
return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
# bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Make TMA aligned to 16 bytes
kAlignment = 16 / global_address.element_size()
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to set max dynamic shared memory size: {res}")
attr_val = cuda.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cuda.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
kernelValues = (
gmem_d.data_ptr(),
scales_b.data_ptr(),
grouped_layout.data_ptr(),
shape_m,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_d,
)
kernelTypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
None,
None,
None,
None,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)

View File

@ -3,40 +3,11 @@ import torch
from functools import lru_cache
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = 1;
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, nullptr,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
@ -64,7 +35,8 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int:
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
# Try swizzle first, as it does not waste shared memory
swizzle_mode = get_swizzle_mode(block_n)
block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
block_n_padding = get_block_n_padding_for_smem_d(
block_n) if swizzle_mode == 0 else 0
smem_d = block_m * (block_n + block_n_padding) * 2
smem_a_per_stage = block_m * block_k
@ -78,7 +50,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
smem_size += ceil_div(smem_scales_b * (1 if block_k %
block_n == 0 else 2), 8) * 8
smem_size += smem_barrier
# Swizzle and padding are not compatible
@ -97,9 +70,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
def fix_wave_saturate(x): return num_sms if x == 0 else x
def get_num_waves(bm, bn): return (ceil_div(
ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
def get_last_wave_util(bm, bn): return fix_wave_saturate(
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
# Decide block sizes by waves
best_block_m, best_block_n = None, None
@ -107,7 +84,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# NOTES: the block sizes can not be too large, so at least one dim less than 128
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
success = False
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
num_waves, best_num_waves = get_num_waves(
block_m, block_n), get_num_waves(best_block_m, best_block_n)
if best_block_m is None or best_block_n is None:
success = True
elif num_waves < best_num_waves:
@ -124,7 +102,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
success |= block_n == best_block_n and block_m < best_block_m
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
success |= block_m != best_block_m and block_n > best_block_n
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
best_block_m, best_block_n = (block_m, block_n) if success else (
best_block_m, best_block_n)
assert best_block_m is not None and best_block_n is not None
# Always pick the longest one
@ -135,7 +114,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, 3)
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
best_smem_config = get_smem_config(
num_stages, k, best_block_m, best_block_n)
if best_smem_config[0] <= sm90_capacity:
best_num_stages = num_stages
break
@ -158,8 +138,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Recompute the minimal number of SMs required
# NOTES: less L2 cache usage and less GPU frequency drop
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
num_min_sms = ceil_div(ceil_div(m, best_block_m) *
ceil_div(n, best_block_n) * num_groups, num_waves)
num_min_sms = ceil_div(
num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
assert num_min_sms <= num_sms
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
@ -210,11 +192,42 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.Normal, lhs, m, k, block_m, block_k)
tensor_map_b = make_2d_tma_b_desc(
GemmType.Normal, rhs, k, n, block_k, block_n)
tensor_map_d = make_2d_tma_d_desc(
GemmType.Normal, smem_config[1], out, m, n, block_m, block_n)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.Normal, lhs_scales, m, k, block_m, block_k)
kwargs = {
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'NUM_GROUPS': 1,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -223,14 +236,8 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)

View File

@ -1,41 +1,12 @@
import torch
from typing import Tuple
from ..jit.utils import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .gemm import get_best_configs, get_block_n_padding_for_smem_d
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
# C++ code templates
includes = ('"deep_gemm/fp8_gemm.cuh"', )
template = """
using namespace deep_gemm;
// Templated args from Python JIT call
constexpr auto N = {N}, K = {K};
constexpr auto BLOCK_M = {BLOCK_M};
constexpr auto BLOCK_N = {BLOCK_N};
constexpr auto BLOCK_K = 128;
constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
constexpr auto kNumGroups = {NUM_GROUPS};
constexpr auto kNumStages = {NUM_STAGES};
constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
// Make a templated grouped GEMM
using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
// Launch kernel
auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
auto tma_b_desc = gemm_t::make_2d_tma_b_desc(rhs);
auto tma_scales_a_desc = gemm_t::make_2d_tma_scales_a_desc(lhs_scales, m);
auto tma_d_desc = gemm_t::make_2d_tma_d_desc(out, m);
gemm_t::run(out, rhs_scales, grouped_layout,
m,
tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc,
stream, num_sms, smem_size);
"""
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
@ -87,13 +58,40 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
return
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
args = (lhs, lhs_scales, rhs, rhs_scales, out,
m_indices, m, num_groups,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, k, 1, num_sms, is_grouped_contiguous=True)
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedContiguous, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedContiguous, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedContiguous, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedContiguous, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': m_indices,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -102,20 +100,13 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedContiguous'},
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int), ('num_groups', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
@ -168,16 +159,44 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
# Auto-tuning with compilation
global includes, template
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_config[0])
runtime = jit_tuner.compile_and_tune(
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
tensor_map_a = make_2d_tma_a_desc(
GemmType.GroupedMasked, lhs, m, k, block_m, block_k, num_groups)
tensor_map_b = make_2d_tma_b_desc(
GemmType.GroupedMasked, rhs, k, n, block_k, block_n, num_groups)
tensor_map_d = make_2d_tma_d_desc(
GemmType.GroupedMasked, smem_config[1], out, m, n, block_m, block_n, num_groups)
tensor_map_scales_a = make_2d_tma_scales_a_desc(
GemmType.GroupedMasked, lhs_scales, m, k, block_m, block_k, num_groups)
kwargs = {
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m,
'BLOCK_K': block_k,
'GMEM_D': out,
'SCALES_B': rhs_scales,
'GROUPED_LAYOUT': masked_m,
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
}
runtime, best_keys = jit_tuner.compile_and_tune(
name='m_grouped_gemm_fp8_fp8_bf16_nt',
keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
'SWIZZLE_D_MODE': smem_config[1],
@ -186,17 +205,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
'NUM_STAGES': num_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
'GEMM_TYPE': 'GroupedMasked'},
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
includes=includes,
arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
('out', torch.bfloat16),
('grouped_layout', torch.int32), ('m', int),
('stream', torch.cuda.Stream), ('num_sms', int), ('smem_size', int)),
template=template,
args=args
kwargs=kwargs,
)
# Run the kernel
runtime(*args)
runtime(**best_keys, **kwargs)

View File

@ -3,15 +3,16 @@ import os
import torch
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
import cuda.bindings.driver as cuda
from ..jit import build, generate, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
# NOTES: the function must have no accumulated side effects
@ -26,7 +27,7 @@ class JITTuner:
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
assert args is not None
assert kwargs is not None
space = (dict(), ) if len(space) == 0 else space
kernels = []
@ -34,30 +35,31 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(includes, arg_defs, cpp_format(template, full_keys))
# Illegal build must raise errors
kernels.append((build(name, arg_defs, code), tuned_keys))
code = generate(**kwargs, **full_keys)
kernels.append((build(name, code), full_keys))
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(*args)
if return_code != 0:
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cuda.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
print(
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
torch.empty(int(256e6 // 4), dtype=torch.int,
device='cuda').zero_()
torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(*args) == 0
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
@ -68,14 +70,16 @@ class JITTuner:
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
print(
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = best_runtime
return best_runtime
print(
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)
return best_runtime, best_keys
jit_tuner = JITTuner()