[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@@ -1,3 +1,3 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .template import generate
from .runtime import Runtime

View File

@@ -3,13 +3,13 @@ import functools
import os
import re
import subprocess
import time
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
@@ -94,11 +94,11 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
# Compile into a temporary SO file
so_path = f'{path}/kernel.so'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
# Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
interleave_ffma.process(tmp_cubin_path)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Atomic replace CU file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)

View File

@@ -1,16 +1,18 @@
import ctypes
import os
import time
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
import torch
from typing import Optional
from .template import map_ctype
from .utils import run_gemm
class Runtime:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.args = None
self.kernel = None
assert self.is_path_valid(self.path)
@@ -21,29 +23,66 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.args', 'kernel.so']
files = ['kernel.cu', 'kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
self.args = eval(f.read())
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN
if self.lib is None:
start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
# Check args and launch
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
if self.kernel is not None:
self.lib = lib
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
cargs.append(map_ctype(arg))
raise Exception("Failed to find fp8 gemm kernel")
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1000
print(
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return run_gemm(
self.kernel,
kwargs['NUM_TMA_MULTICAST'],
kwargs['M'],
kwargs['BLOCK_M'],
kwargs['GMEM_D'],
kwargs['SCALES_B'],
kwargs['GROUPED_LAYOUT'],
kwargs['NUM_SMS'],
kwargs['SMEM_SIZE'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
kwargs['STREAM'],
)
def __del__(self) -> None:
if self.lib is not None:
res = cuda.cuLibraryUnload(self.lib)[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to unload library {self.path}: {res}")
class RuntimeCache:

View File

@@ -1,111 +1,48 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
#include <deep_gemm/nvrtc_std.cuh>
#else
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
#include <string>
#include <cuda.h>
#endif
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
namespace deep_gemm {{
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
}}
'''
# Debug print
if os.getenv('DG_JIT_DEBUG', None):

164
deep_gemm/jit/utils.py Normal file
View File

@@ -0,0 +1,164 @@
import ctypes
from enum import Enum
from typing import Any, Dict, Tuple
import cuda.bindings.driver as cuda
import torch
class Layout(Enum):
RowMajor = 0
ColMajor = 1
class GemmType(Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
def __str__(self) -> str:
return {
0: 'Normal',
1: 'GroupedContiguous',
2: 'GroupedMasked',
}[self.value]
typename_map: Dict[Any, str] = {
torch.int8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cuda.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
swizzle_map = {
128: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
64: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
32: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
0: cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
}
def get_num_math_warpgroups(block_m: int) -> int:
return 1 if block_m == 64 else 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuuint64_t, cuda.cuuint64_t], stride_in_bytes: cuda.cuuint64_t, smem_dim: Tuple[cuda.cuuint32_t, cuda.cuuint32_t], swizzle_type: cuda.CUtensorMapSwizzle) -> cuda.CUtensorMap:
tensor_dtype = typename_map[global_address.dtype]
res, tensor_map = cuda.cuTensorMapEncodeTiled(
tensor_dtype,
2, # tensor rank
global_address.data_ptr(),
gmem_dim,
(stride_in_bytes,), # global strides
smem_dim,
(cuda.cuuint32_t(1), cuda.cuuint32_t(1)), # element strides
cuda.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
cuda.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
cuda.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to encode tensor map: {res}")
return tensor_map
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout, gmem_rows: int, gmem_cols: int, smem_rows: int, smem_cols: int, swizzle_type: cuda.CUtensorMapSwizzle = cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cuda.CUtensorMap:
if layout == Layout.RowMajor:
gmem_dim = (cuda.cuuint64_t(gmem_cols), cuda.cuuint64_t(gmem_rows))
smem_dim = (cuda.cuuint32_t(smem_cols), cuda.cuuint32_t(smem_rows))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_cols * global_address.element_size()), smem_dim, swizzle_type)
else:
gmem_dim = (cuda.cuuint64_t(gmem_rows), cuda.cuuint64_t(gmem_cols))
smem_dim = (cuda.cuuint32_t(smem_rows), cuda.cuuint32_t(smem_cols))
return make_2d_tma_copy_desc(global_address, gmem_dim, cuda.cuuint64_t(gmem_rows * global_address.element_size()), smem_dim, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, block_m, block_k)
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_k: int, shape_n: int, block_k: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, swizzle_mode: int, global_address: torch.Tensor, shape_m: int, shape_n: int, block_m: int, block_n: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Swizzling requires the inner box dim less or equal than `kSwizzleDMode`
# bytes So `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
return make_2d_tma_desc(global_address, Layout.RowMajor, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(), swizzle_map[swizzle_mode])
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cuda.CUtensorMap:
# Make TMA aligned to 16 bytes
kAlignment = 16 / global_address.element_size()
shape_m = (shape_m + kAlignment - 1) // kAlignment * kAlignment
return make_2d_tma_desc(global_address, Layout.ColMajor, shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), block_m, 1, cuda.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor, grouped_layout: torch.Tensor, num_sms: int, smem_size: int, tensor_map_a: cuda.CUtensorMap, tensor_map_b: cuda.CUtensorMap, tensor_map_scales_a: cuda.CUtensorMap, tensor_map_d: cuda.CUtensorMap, stream: cuda.CUstream) -> cuda.CUresult:
num_tma_threads = 128
num_math_threads_per_group = 128
res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to set max dynamic shared memory size: {res}")
attr_val = cuda.CUlaunchAttributeValue()
attr_val.clusterDim.x = num_tma_multicast
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cuda.CUlaunchAttribute()
attr.id = cuda.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cuda.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = num_sms
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = smem_size
config.hStream = stream
kernelValues = (
gmem_d.data_ptr(),
scales_b.data_ptr(),
grouped_layout.data_ptr(),
shape_m,
tensor_map_a,
tensor_map_b,
tensor_map_scales_a,
tensor_map_d,
)
kernelTypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
None,
None,
None,
None,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)