[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@@ -1,111 +1,48 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
#include <deep_gemm/nvrtc_std.cuh>
#else
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
#include <string>
#include <cuda.h>
#endif
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
namespace deep_gemm {{
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
}}
'''
# Debug print
if os.getenv('DG_JIT_DEBUG', None):