[wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 08:08:40 +00:00
parent 891f35adf5
commit 27cd276e19
13 changed files with 581 additions and 323 deletions

View File

@@ -3,13 +3,13 @@ import functools
import os
import re
import subprocess
import time
import uuid
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .template import typename_map
runtime_cache = RuntimeCache()
@@ -94,11 +94,11 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, arg_defs: tuple, code: str) -> Runtime:
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@@ -121,31 +121,36 @@ def build(name: str, arg_defs: tuple, code: str) -> Runtime:
# Write the code
os.makedirs(path, exist_ok=True)
args_path = f'{path}/kernel.args'
src_path = f'{path}/kernel.cu'
put(args_path, ', '.join([f"('{arg_def[0]}', {typename_map[arg_def[1]]})" for arg_def in arg_defs]))
put(src_path, code)
# Compile into a temporary SO file
so_path = f'{path}/kernel.so'
tmp_so_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(so_path)}.so'
# Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_so_path,
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_so_path)
interleave_ffma.process(tmp_cubin_path)
# Atomic replace SO file
os.replace(tmp_so_path, so_path)
# Atomic replace CU file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)