refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 10:17:52 +00:00
parent 27cd276e19
commit c14cad0c06
5 changed files with 237 additions and 284 deletions

View File

@@ -1,12 +1,16 @@
import hashlib
import abc
import functools
import hashlib
import os
import re
import subprocess
import time
import uuid
from typing import List, Tuple
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from typing import Tuple
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
@@ -29,7 +33,8 @@ def get_jit_include_dir() -> str:
def get_deep_gemm_version() -> str:
# Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm'
assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f:
@@ -53,7 +58,8 @@ def get_nvcc_compiler() -> Tuple[str, str]:
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(os.popen(f'{path} --version').read())
match = version_pattern.search(
os.popen(f'{path} --version').read())
version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -94,64 +100,173 @@ def put(path, data, is_binary=False):
os.replace(tmp_file_path, path)
def build(name: str, code: str) -> Runtime:
# Compiler flags
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
nvcc_flags = [f'-std=c++{cpp_standard}', '-shared', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'-gencode=arch=compute_90a,code=sm_90a', '-cubin',
'--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
cxx_flags = ['-fPIC', '-O3', '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
flags = [*nvcc_flags, f'--compiler-options={",".join(cxx_flags)}']
include_dirs = [get_jit_include_dir()]
class Compiler(abc.ABC):
# Abstract hook: each compiler backend reports its version as a `(major, minor)` tuple of ints.
@staticmethod
@abc.abstractmethod
def __version__() -> Tuple[int, int]:
pass
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
# Abstract hook: compile the source file at `src_path` and write the resulting binary to `target_path`.
# `name` identifies the kernel being built (backends use it in debug output).
@classmethod
@abc.abstractmethod
def compile(cls, name: str, src_path: str, target_path: str):
pass
# Check runtime cache or file system hit
global runtime_cache
if runtime_cache[path] is not None:
@staticmethod
def flags() -> List[str]:
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
@staticmethod
def include_dirs() -> List[str]:
    """Return the header search paths for compilation: just the JIT include directory."""
    jit_include_dir = get_jit_include_dir()
    return [jit_include_dir]
@classmethod
def build(cls, name: str, code: str) -> Runtime:
    """Compile `code` into a cached cubin and return the `Runtime` wrapping it.

    The kernel directory is keyed by a hash of the kernel name, the DeepGEMM
    header version, the source code, the NVCC compiler, the compiler flags and
    whether the FFMA interleaving pass is enabled. If `runtime_cache` already
    holds an entry for that directory, it is returned without compiling.

    Args:
        name: Human-readable kernel name, embedded in the cache directory name.
        code: Full CUDA source of the kernel.

    Returns:
        The `Runtime` created for the kernel's cache directory.
    """
    # Compiler flags
    flags = cls.flags()

    # Build signature
    # Compare the version numerically: as plain strings, '12.10' would sort before '12.8'
    nvcc_version = tuple(map(int, get_nvcc_compiler()[1].split('.')))
    enable_sass_opt = nvcc_version <= (12, 8) and int(
        os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
    signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
    name = f'kernel.{name}.{hash_to_hex(signature)}'
    path = f'{get_cache_dir()}/{name}'

    # Check runtime cache or file system hit
    global runtime_cache
    if runtime_cache[path] is not None:
        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Using cached JIT runtime {name} during build')
        return runtime_cache[path]

    # Write the code
    os.makedirs(path, exist_ok=True)
    src_path = f'{path}/kernel.cu'
    put(src_path, code)

    # Compile into a temporary cubin file, so a partially written cubin never sits at the final path
    cubin_path = f'{path}/kernel.cubin'
    tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'

    start_time = time.time()
    cls.compile(name, src_path, tmp_cubin_path)
    end_time = time.time()

    # Print elapsed time if debug is enabled
    elapsed_time = end_time - start_time
    if os.getenv('DG_JIT_DEBUG', None):
        print(
            f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')

    # Interleave FFMA reuse
    if enable_sass_opt:
        interleave_ffma.process(tmp_cubin_path)

    # Atomic replace cubin file
    os.replace(tmp_cubin_path, cubin_path)

    # Put cache and return
    runtime_cache[path] = Runtime(path)
    return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
class NvccCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
    """Return the `(major, minor)` release version of the NVCC executable in use."""
    # Report the version of the NVCC binary itself: the `cuda.bindings` package version
    # describes the Python bindings and need not match the compiler that is invoked
    _, version = get_nvcc_compiler()
    major, minor = map(int, version.split('.'))
    return (major, minor)
# Compile
command = [get_nvcc_compiler()[0],
src_path, '-o', tmp_cubin_path,
*flags,
*[f'-I{d}' for d in include_dirs]]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
start_time = time.time()
return_code = subprocess.check_call(command)
end_time = time.time()
assert return_code == 0, f'Failed to compile {src_path}'
@classmethod
def flags(cls) -> List[str]:
    """Return the full NVCC option list.

    The result is the shared base flags, followed by one `-I` entry per
    include directory, the SM90a code generation and optimization switches,
    and the host compiler options folded into one `--compiler-options` entry.
    """
    host_options = ','.join(['-fPIC', '-O3',
                             '-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'])

    nvcc_options = list(super().flags())
    nvcc_options.extend(f'-I{d}' for d in cls.include_dirs())
    nvcc_options += [
        '-gencode=arch=compute_90a,code=sm_90a',
        '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
        f'--compiler-options={host_options}',
    ]
    return nvcc_options
# Print elapsed time if debug is enabled
elapsed_time = end_time - start_time
print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
    """Compile `src_path` into a cubin at `target_path` by running NVCC.

    The command line is printed when `DG_JIT_DEBUG` or
    `DG_JIT_PRINT_NVCC_COMMAND` is set in the environment.

    Args:
        name: Kernel name, used only in the debug message.
        src_path: Path of the CUDA source file to compile.
        target_path: Path the compiled cubin is written to.

    Raises:
        subprocess.CalledProcessError: If NVCC exits with a non-zero status.
    """
    command = [get_nvcc_compiler()[0],
               src_path, '-o', target_path,
               *cls.flags()]
    if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
        print(f'Compiling JIT runtime {name} with command {command}')

    # `check_call` raises on a non-zero exit status, so no extra return code check is needed.
    # Post-processing of the cubin is not done here: this method only produces `target_path`.
    subprocess.check_call(command)
# Atomic replace CU file
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime_cache[path] = Runtime(path)
return runtime_cache[path]
class NvrtcCompiler(Compiler):
    """Compiler backend that builds kernels in-process through the NVRTC bindings."""

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the `(major, minor)` version of the NVRTC library.

        The NVRTC library is queried directly, since NVRTC-only options such as
        `--pch` depend on it, not on whichever NVCC binary is installed. If the
        query fails, the `cuda.bindings` package version is used instead.
        """
        res, major, minor = nvrtc.nvrtcVersion()
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            # Only the first two components are needed; later ones may not be plain integers
            major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
        return (major, minor)

    @staticmethod
    def include_dirs() -> List[str]:
        """Return the JIT include directory plus the CUDA toolkit include directories.

        Raises:
            RuntimeError: If `CUDA_HOME` is not set.
        """
        if CUDA_HOME is None:
            raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
        return [get_jit_include_dir(),
                os.path.join(CUDA_HOME, 'include'),
                os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]

    @classmethod
    def flags(cls) -> List[str]:
        """Return the NVRTC option list.

        The result is the shared base flags, the include paths and the SM90a
        target. Precompiled headers are enabled from version 12.8 onwards, with
        verbose PCH output when `DG_JIT_DEBUG` is set.
        """
        base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                      '--gpu-architecture=sm_90a', '-default-device']
        if cls.__version__() >= (12, 8):
            base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
            if os.getenv('DG_JIT_DEBUG', None):
                base_flags += ['--pch-verbose=true']
        return base_flags

    @staticmethod
    def _read_program_log(program) -> str:
        """Fetch the compilation log of `program` as text."""
        res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to get program log size: {res}")
        log_bytes = bytes(log_size)
        res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to get program log: {res}")
        # The reported size includes the trailing NUL terminator, which should not be printed
        return log_bytes.decode('utf-8').rstrip('\x00')

    @classmethod
    def compile(cls, name: str, src_path: str, target_path: str):
        """Compile `src_path` with NVRTC and write the cubin to `target_path`.

        The compilation log is printed when `DG_JIT_DEBUG` is set, and always
        when compilation fails, so the diagnostics are never lost.

        Args:
            name: Kernel name, passed to NVRTC as the program name.
            src_path: Path of the CUDA source file to compile.
            target_path: Path the compiled cubin is written to.

        Raises:
            Exception: If any NVRTC call reports a non-success result.
        """
        with open(src_path, 'rb') as f:
            code_bytes = f.read()

        res, program = nvrtc.nvrtcCreateProgram(
            code_bytes, bytes(name, 'utf-8'), 0, [], [])
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to create program: {res}")

        try:
            options = [bytes(flag, 'utf-8') for flag in cls.flags()]
            compile_res = nvrtc.nvrtcCompileProgram(
                program, len(options), options)[0]
            compile_failed = compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS

            if compile_failed or os.getenv('DG_JIT_DEBUG', None):
                print(cls._read_program_log(program))

            if compile_failed:
                raise Exception(f"Failed to compile program: {compile_res}")

            res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f"Failed to get CUBIN size: {res}")

            cubin_bytes = bytes(cubin_size)
            res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
            if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                raise Exception(f"Failed to get CUBIN: {res}")

            put(target_path, cubin_bytes, is_binary=True)
        except BaseException:
            # Release the program on every failure path without masking the original error
            nvrtc.nvrtcDestroyProgram(program)
            raise

        res = nvrtc.nvrtcDestroyProgram(program)[0]
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise Exception(f"Failed to destroy program: {res}")
def build(name: str, code: str) -> Runtime:
    """Build the kernel `name` from `code` with the configured compiler backend.

    NVRTC is used when the `DG_JIT_USE_NVRTC` environment variable is `1`,
    `true` or `True`; otherwise NVCC is used.
    """
    use_nvrtc = os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']
    compiler_cls = NvrtcCompiler if use_nvrtc else NvccCompiler
    return compiler_cls.build(name, code)