Refactor JIT compilation (+NVRTC support) (#94)

* [wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* fix: compiler version

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: compat for old drivers

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: save kernel name to file

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: fix win compat

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* fix: windows compat

Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: drop support for CUDA<12.3

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* doc: update README

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* Some lints and refactor

* Refactor runtime

* Several fixes

* Refactor environment variables

* Code format

* Add a TODO

* Compatible with CUDA 12.3

* Fix indent

* Fix typing

* Drop support for Windows

* Add a TODO

---------

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
This commit is contained in:
Gabriel Wu
2025-05-07 11:38:14 +08:00
committed by GitHub
parent d374456787
commit bfe983c4c2
19 changed files with 909 additions and 660 deletions

View File

@@ -1,3 +1,2 @@
from .compiler import get_nvcc_compiler, build
from .template import cpp_format, generate
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
from .runtime import Runtime

View File

@@ -1,15 +1,18 @@
import hashlib
import functools
import hashlib
import os
import re
import subprocess
import time
import uuid
from typing import List, Tuple, Type
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
@@ -22,21 +25,22 @@ def hash_to_hex(s: str) -> str:
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
# Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm'
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
# Update include directories
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f:
with open(os.path.join(include_dir, filename), 'rb') as f:
md5.update(f.read())
# Update `interleave_ffma.py`
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
md5.update(f.read())
return md5.hexdigest()[0:12]
@@ -44,16 +48,20 @@ def get_deep_gemm_version() -> str:
@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
paths = []
if os.getenv('DG_NVCC_COMPILER'):
paths.append(os.getenv('DG_NVCC_COMPILER'))
paths.append(f'{CUDA_HOME}/bin/nvcc')
if os.getenv('DG_JIT_NVCC_COMPILER'):
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
# Try to find the first available NVCC compiler
least_version_required = '12.3'
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(os.popen(f'{path} --version').read())
command = [path, '--version']
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
match = version_pattern.search(result.stdout)
version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -63,21 +71,21 @@ def get_nvcc_compiler() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def get_default_user_dir():
if 'DG_CACHE_DIR' in os.environ:
path = os.getenv('DG_CACHE_DIR')
if 'DG_JIT_CACHE_DIR' in os.environ:
path = os.getenv('DG_JIT_CACHE_DIR')
os.makedirs(path, exist_ok=True)
return path
return os.path.expanduser('~') + '/.deep_gemm'
return os.path.join(os.path.expanduser('~'), '.deep_gemm')
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
return f'{get_default_user_dir()}/tmp'
return os.path.join(get_default_user_dir(), 'tmp')
@functools.lru_cache(maxsize=None)
def get_cache_dir():
return f'{get_default_user_dir()}/cache'
return os.path.join(get_default_user_dir(), 'cache')
def make_tmp_dir():
@@ -86,67 +94,192 @@ def make_tmp_dir():
return tmp_dir
def put(path, data, is_binary=False):
def put(path, data):
# Write and do POSIX atomic replace
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
with open(tmp_file_path, 'wb' if isinstance(data, bytes) else 'w') as f:
f.write(data)
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
class Compiler:
@classmethod
def signature(cls) -> str:
pass
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
@staticmethod
def __version__() -> Tuple[int, int]:
pass
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
if os.getenv('DG_JIT_DEBUG', None):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> None:
pass
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
@staticmethod
def flags() -> List[str]:
cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_JIT_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,161,174,177,940']
# Compile into a temporary SO file
so_path = f'{path}/kernel.so'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
@staticmethod
def include_dirs() -> List[str]:
return [get_jit_include_dir()]
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}'
@classmethod
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
# Compiler flags
flags = cls.flags()
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
# Build signature
enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = os.path.join(get_cache_dir(), name)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Check runtime cache or file system hit
global runtime_cache
cached_runtime = runtime_cache.get(path, runtime_cls)
if cached_runtime is not None:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT runtime {name} during build')
return cached_runtime
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]
# Compile into a temporary CU file
os.makedirs(path, exist_ok=True)
cubin_path = os.path.join(path, 'kernel.cubin')
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
start_time = time.time()
cls.compile(name, code, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
# Atomic replace files
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime = runtime_cls(path)
runtime_cache[path] = runtime
return runtime
class NVCCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
_, version = get_nvcc_compiler()
major, minor = map(int, version.split('.'))
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvcc+{cls.__version__()}'
@classmethod
def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> None:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
put(src_path, code)
command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with command {command}')
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
assert False, f'Failed to compile {src_path}'
class NVRTCCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get the actual NVRTC version, use cuda-bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvrtc+{cls.__version__()}'
@staticmethod
def include_dirs() -> List[str]:
if CUDA_HOME is None:
raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]
@classmethod
def flags(cls) -> List[str]:
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
# NOTES: PCH is vital for compilation speed
if cls.__version__() >= (12, 8):
flags += ['--pch']
if int(os.getenv('DG_JIT_DEBUG', 0)):
flags += ['--pch-verbose=true']
return flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> None:
# Create program
code_bytes = bytes(code, 'utf-8')
result, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'
# Compile
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
print(f'Compiling JIT runtime {name} with options: {options}')
compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]
# Print compiler log
if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
log_bytes = bytes(log_size)
result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
print(f'Compiler log: {log_bytes.decode("utf-8")}')
# Exit if failed
assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'
# Create CUBIN
result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
cubin_bytes = bytes(cubin_size)
result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'
# Write into the file system
put(target_path, cubin_bytes)
# Destroy handler
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
return compiler_cls.build(name, code, runtime_cls=runtime_cls)

View File

@@ -37,7 +37,7 @@ def extract_ffma(sass):
collected.append((f'{arch_name}::{func_name}', current))
current = []
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
print(f'Found {len(collected)} FFMA segments')
return collected
@@ -100,7 +100,7 @@ def modify_segment(m, name, ffma_lines):
dst_reg_set.add(dst_reg)
new_le_bytes.append(low_hex.to_bytes(8, 'little') + high_hex.to_bytes(8, 'little'))
last_reused, last_dst_reg = reused, dst_reg
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')
# Find the offset
@@ -118,7 +118,7 @@ def modify_segment(m, name, ffma_lines):
def process(path):
if os.getenv('DG_PRINT_REG_REUSE', None):
if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
print(f'Processing {path}')
output = run_cuobjdump(path)
segments = extract_ffma(output)

View File

@@ -1,17 +1,18 @@
import ctypes
import os
import torch
from typing import Optional
import subprocess
import time
import cuda.bindings.driver as cbd
from .template import map_ctype
from typing import List, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME
class Runtime:
def __init__(self, path: str) -> None:
def __init__(self, path: str, args: List[str] = None) -> None:
self.path = path
self.lib = None
self.args = None
self.kernel = None
self.args = args
assert self.is_path_valid(self.path)
@staticmethod
@@ -21,46 +22,69 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.args', 'kernel.so']
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, *args) -> int:
# Load SO file
if self.lib is None or self.args is None:
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
self.args = eval(f.read())
@staticmethod
def generate(**kwargs) -> str:
raise NotImplemented
# Check args and launch
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
cargs = []
for arg, (name, dtype) in zip(args, self.args):
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
else:
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
cargs.append(map_ctype(arg))
@staticmethod
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
raise NotImplemented
return_code = ctypes.c_int(0)
self.lib.launch(*cargs, ctypes.byref(return_code))
return return_code.value
def __call__(self, **kwargs) -> cbd.CUresult:
# Load CUBIN
if self.kernel is None:
start_time = time.time_ns()
# Load CUBIN
path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'
# Extract the kernel name
# TODO: use `cuda-bindings` API to do this (requires at least 12.8)
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
# Load kernel from the library
result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'
end_time = time.time_ns()
elapsed_time = (end_time - start_time) / 1e6
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
# noinspection PyArgumentList
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
def __del__(self) -> None:
if self.lib is not None:
res = cbd.cuLibraryUnload(self.lib)[0]
if res != cbd.CUresult.CUDA_SUCCESS:
raise Exception(f'Failed to unload library {self.path}: {res}')
class RuntimeCache:
def __init__(self) -> None:
self.cache = {}
def __getitem__(self, path: str) -> Optional[Runtime]:
def __setitem__(self, path: str, runtime: Runtime) -> None:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path)
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
runtime = runtime_cls(path)
self.cache[path] = runtime
return runtime
return None
def __setitem__(self, path, runtime) -> None:
self.cache[path] = runtime

View File

@@ -1,114 +0,0 @@
import copy
import ctypes
import os
import torch
from typing import Any, Dict, Iterable, Tuple
# Name map for Python `eval`
typename_map: Dict[Any, str] = {
**{t: t.__name__ for t in (bool, int, float)},
torch.int: 'torch.int',
torch.float: 'torch.float',
torch.bfloat16: 'torch.bfloat16',
torch.float8_e4m3fn: 'torch.float8_e4m3fn',
torch.cuda.Stream: 'torch.cuda.Stream',
}
# `ctype` map for Python casting
ctype_map: Dict[Any, Any] = {
**{t: getattr(ctypes, f'c_{t.__name__}') for t in (bool, int, float)},
**{t: ctypes.c_void_p for t in (torch.int, torch.float, torch.bfloat16, torch.float8_e4m3fn, torch.cuda.Stream)},
}
# Type map for both Python API and source code usages
genc_map = {
bool: ('bool', 'bool'),
int: ('int', 'int'),
float: ('float', 'float'),
torch.int: ('void*', 'int*'),
torch.float: ('void*', 'float*'),
torch.bfloat16: ('void*', '__nv_bfloat16*'),
torch.float8_e4m3fn: ('void*', '__nv_fp8_e4m3*'),
torch.cuda.Stream: ('void*', 'cudaStream_t'),
}
def map_ctype(value: Any) -> Any:
if hasattr(value, 'data_ptr'):
if value.dtype == torch.int:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.bfloat16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float16:
return ctypes.c_void_p(value.data_ptr())
elif value.dtype == torch.float8_e4m3fn:
return ctypes.c_void_p(value.data_ptr())
else:
return ctypes.c_void_p(value.data_ptr())
if hasattr(value, 'cuda_stream'):
return ctypes.c_void_p(value.cuda_stream)
if isinstance(value, bool):
return ctypes.c_bool(value)
elif isinstance(value, int):
return ctypes.c_int(value)
elif isinstance(value, float):
return ctypes.c_float(value)
return ctype_map[type(value)](value)
def cpp_format(template: str, keys: Dict[str, Any]) -> str:
# We don't use `str.format` because it's not safe for C++ {} braces
new_template = copy.deepcopy(template)
for key, value in keys.items():
value_str = str(value)
if isinstance(value, bool):
value_str = value_str.lower()
new_template = new_template.replace(f'{{{key}}}', f'{value_str}')
return new_template
def generate(includes: Iterable[str], arg_defs: Iterable[Tuple], body: str) -> str:
# Common prefix
code = '// DeepGEMM auto-generated JIT CUDA source file\n\n'
# Includes
preload_sys_includes = ['<cuda.h>', '<cuda_fp8.h>', '<cuda_runtime.h>', '<iostream>']
preload_package_includes = ['"cutlass/cutlass.h"']
assert isinstance(includes, list) or isinstance(includes, tuple)
sys_includes = sorted(list(set(preload_sys_includes + [include for include in includes if include.startswith('<')])))
package_includes = sorted(list(set(preload_package_includes + [include for include in includes if include.startswith('"')])))
code += '\n'.join(f'#include {include}' for include in sys_includes) + '\n\n'
code += '\n'.join(f'#include {include}' for include in package_includes) + '\n\n'
# Function signature
raw = '__raw_'
get_def = lambda n, t: f'{genc_map[t][0]} ' + (raw if genc_map[t][0] != genc_map[t][1] else '') + n
code += f'extern "C" void launch('
code += ', '.join([get_def(*arg_def) for arg_def in arg_defs] + ['int& __return_code', ])
code += ') {\n'
# Cast raw types
code += ' // Cast raw types (if needed)\n'
for arg_name, arg_type in arg_defs:
if genc_map[arg_type][0] != genc_map[arg_type][1]:
code += f' auto {arg_name} = reinterpret_cast<{genc_map[arg_type][1]}>({raw}{arg_name});\n'
# Function body
code += '\n'.join([((' ' if line else '') + line) for line in body.split('\n')])
# End the function
code += '}\n\n'
# Debug print
if os.getenv('DG_JIT_DEBUG', None):
print(f'Generated code:\n{code}')
return code