mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .compiler import get_nvcc_compiler, build
|
||||
from .template import cpp_format, generate
|
||||
from .template import generate
|
||||
from .runtime import Runtime
|
||||
|
||||
@@ -3,13 +3,13 @@ import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from typing import Tuple
|
||||
|
||||
from . import interleave_ffma
|
||||
from .runtime import Runtime, RuntimeCache
|
||||
from .template import typename_map
|
||||
|
||||
runtime_cache = RuntimeCache()
|
||||
|
||||
@@ -94,11 +94,11 @@ def put(path, data, is_binary=False):
|
||||
os.replace(tmp_file_path, path)
|
||||
|
||||
|
||||
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
def build(name: str, code: str) -> Runtime:
|
||||
# Compiler flags
|
||||
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
|
||||
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
|
||||
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
|
||||
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
|
||||
'--diag-suppress=39,174,177,940']
|
||||
@@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
|
||||
|
||||
# Write the code
|
||||
os.makedirs(path, exist_ok=True)
|
||||
args_path = f'{path}/kernel.args'
|
||||
src_path = f'{path}/kernel.cu'
|
||||
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
|
||||
put(src_path, code)
|
||||
|
||||
# Compile into a temporary SO file
|
||||
so_path = f'{path}/kernel.so'
|
||||
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
|
||||
# Compile into a temporary CU file
|
||||
cubin_path = f'{path}/kernel.cubin'
|
||||
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
|
||||
|
||||
# Compile
|
||||
command = [get_nvcc_compiler()[0],
|
||||
src_path, '-o', tmp_so_path,
|
||||
src_path, '-o', tmp_cubin_path,
|
||||
*flags,
|
||||
*[f'-I{d}' for d in include_dirs]]
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
start_time = time.time()
|
||||
return_code = subprocess.check_call(command)
|
||||
end_time = time.time()
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# Print elapsed time if debug is enabled
|
||||
elapsed_time = end_time - start_time
|
||||
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
|
||||
|
||||
# Interleave FFMA reuse
|
||||
if enable_sass_opt:
|
||||
interleave_ffma.process(tmp_so_path)
|
||||
interleave_ffma.process(tmp_cubin_path)
|
||||
|
||||
# Atomic replace SO file
|
||||
os.replace(tmp_so_path, so_path)
|
||||
# Atomic replace CU file
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime_cache[path] = Runtime(path)
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import ctypes
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
self.kernel = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@@ -21,29 +23,66 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
||||
files = ['kernel.cu', 'kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
||||
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
||||
self.args = eval(f.read())
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.lib is None:
|
||||
start_time = time.time_ns()
|
||||
res, lib = cuda.cuLibraryLoadFromFile(
|
||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel count: {res}")
|
||||
|
||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to enumerate kernels: {res}")
|
||||
|
||||
for kernel in kernels:
|
||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel name: {res}")
|
||||
if b"fp8" in kernel_name:
|
||||
self.kernel = kernel
|
||||
break
|
||||
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
cargs.append(map_ctype(arg))
|
||||
raise Exception("Failed to find fp8 gemm kernel")
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
self.kernel,
|
||||
kwargs['NUM_TMA_MULTICAST'],
|
||||
kwargs['M'],
|
||||
kwargs['BLOCK_M'],
|
||||
kwargs['GMEM_D'],
|
||||
kwargs['SCALES_B'],
|
||||
kwargs['GROUPED_LAYOUT'],
|
||||
kwargs['NUM_SMS'],
|
||||
kwargs['SMEM_SIZE'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
kwargs['STREAM'],
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.lib is not None:
|
||||
res = cuda.cuLibraryUnload(self.lib)[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to unload library {self.path}: {res}")
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
|
||||
@@ -1,111 +1,48 @@
|
||||
import copy
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict, Iterable, Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
# Name map for Python `eval`
|
||||
typename_map: Dict[Any, str] = {
|
||||
**{t: t.__name__ for t in (bool, int, float)},
|
||||
torch.int: 'torch.int',
|
||||
torch.float: 'torch.float',
|
||||
torch.bfloat16: 'torch.bfloat16',
|
||||
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
|
||||
torch.cuda.Stream: 'torch.cuda.Stream',
|
||||
}
|
||||
def generate(**kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#define NVRTC_JIT_COMPILATION
|
||||
#endif
|
||||
|
||||
# `ctype` map for Python casting
|
||||
ctype_map: Dict[Any, Any] = {
|
||||
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
|
||||
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
|
||||
}
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
|
||||
#else
|
||||
|
||||
# Type map for both Python API and source code usages
|
||||
genc_map = {
|
||||
bool: ('bool', 'bool'),
|
||||
int: ('int', 'int'),
|
||||
float: ('float', 'float'),
|
||||
torch.int: ('void*', 'int*'),
|
||||
torch.float: ('void*', 'float*'),
|
||||
torch.bfloat16: ('void*', '__nv_bfloat16*'),
|
||||
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
|
||||
torch.cuda.Stream: ('void*', 'cudaStream_t'),
|
||||
}
|
||||
#include <string>
|
||||
#include <cuda.h>
|
||||
|
||||
#endif
|
||||
|
||||
def map_ctype(value: Any) -> Any:
|
||||
if hasattr(value, 'data_ptr'):
|
||||
if value.dtype == torch.int:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.bfloat16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float16:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
elif value.dtype == torch.float8_e4m3fn:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
else:
|
||||
return ctypes.c_void_p(value.data_ptr())
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
if hasattr(value, 'cuda_stream'):
|
||||
return ctypes.c_void_p(value.cuda_stream)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return ctypes.c_bool(value)
|
||||
elif isinstance(value, int):
|
||||
return ctypes.c_int(value)
|
||||
elif isinstance(value, float):
|
||||
return ctypes.c_float(value)
|
||||
|
||||
return ctype_map[type(value)](value)
|
||||
|
||||
|
||||
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
|
||||
# We don't use `str.format` because it's not safe for C++ {} braces
|
||||
new_template = copy.deepcopy(template)
|
||||
for key, value in keys.items():
|
||||
value_str = str(value)
|
||||
if isinstance(value, bool):
|
||||
value_str = value_str.lower()
|
||||
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
|
||||
return new_template
|
||||
|
||||
|
||||
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
|
||||
# Common prefix
|
||||
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
|
||||
|
||||
# Includes
|
||||
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
|
||||
preload_package_includes = ['"cutlass/cutlass.h"']
|
||||
|
||||
assert isinstance(includes, list) or isinstance(includes, tuple)
|
||||
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
|
||||
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
|
||||
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
|
||||
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
|
||||
|
||||
# Function signature
|
||||
raw = '__raw_'
|
||||
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
|
||||
code += f'extern "C" void launch('
|
||||
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
|
||||
code += ') {\n'
|
||||
|
||||
# Cast raw types
|
||||
code += ' // Cast raw types (if needed)\n'
|
||||
for arg_name, arg_type in arg_defs:
|
||||
if genc_map[arg_type][0] != genc_map[arg_type][1]:
|
||||
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
|
||||
|
||||
# Function body
|
||||
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
|
||||
|
||||
# End the function
|
||||
code += '}\n\n'
|
||||
namespace deep_gemm {{
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
}}
|
||||
'''
|
||||
|
||||
# Debug print
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
|
||||
164
deep_gemm/jit/utils.py
Normal file
164
deep_gemm/jit/utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import ctypes
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
|
||||
class Layout(Enum):
|
||||
RowMajor = 0
|
||||
ColMajor = 1
|
||||
|
||||
|
||||
class GemmType(Enum):
|
||||
Normal = 0
|
||||
GroupedContiguous = 1
|
||||
GroupedMasked = 2
|
||||
|
||||
def __str__(self) -> str:
|
||||
return {
|
||||
0: 'Normal',
|
||||
1: 'GroupedContiguous',
|
||||
2: 'GroupedMasked',
|
||||
}[self.value]
|
||||
|
||||
|
||||
typename_map: Dict[Any, str] = {
|
||||
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
|
||||
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
|
||||
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
|
||||
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
|
||||
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
|
||||
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
|
||||
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
|
||||
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
|
||||
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
|
||||
}
|
||||
|
||||
swizzle_map = {
|
||||
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
|
||||
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
|
||||
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
|
||||
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
|
||||
}
|
||||
|
||||
def get_num_math_warpgroups(block_m: int) -> int:
|
||||
return 1 if block_m == 64 else 2
|
||||
|
||||
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
|
||||
assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
|
||||
tensor_dtype = typename_map[global_address.dtype]
|
||||
res, tensor_map = cuda.cuTensorMapEncodeTiled(
|
||||
tensor_dtype,
|
||||
2, # tensor rank
|
||||
global_address.data_ptr(),
|
||||
gmem_dim,
|
||||
(stride_in_bytes,), # global strides
|
||||
smem_dim,
|
||||
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
|
||||
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
swizzle_type,
|
||||
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
|
||||
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
|
||||
)
|
||||
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to encode tensor map: {res}")
|
||||
|
||||
return tensor_map
|
||||
|
||||
|
||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
|
||||
if layout == Layout.RowMajor:
|
||||
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
|
||||
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
|
||||
else:
|
||||
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
|
||||
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
# Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
|
||||
# bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
|
||||
|
||||
|
||||
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
kAlignment = 16 / global_address.element_size()
|
||||
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to set max dynamic shared memory size: {res}")
|
||||
|
||||
attr_val = cuda.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cuda.CUlaunchAttribute()
|
||||
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
||||
attr.value = attr_val
|
||||
|
||||
config = cuda.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
|
||||
kernelValues = (
|
||||
gmem_d.data_ptr(),
|
||||
scales_b.data_ptr(),
|
||||
grouped_layout.data_ptr(),
|
||||
shape_m,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_d,
|
||||
)
|
||||
kernelTypes = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)
|
||||
Reference in New Issue
Block a user