Merge pull request #1 from lucifer1004/nvrtc-compat

feat: add compat for older drivers and Windows
This commit is contained in:
Gabriel Wu 2025-04-23 15:06:10 +08:00 committed by GitHub
commit 6f0a17cb10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 110 additions and 56 deletions

View File

@ -2,6 +2,7 @@ import abc
import functools
import hashlib
import os
import platform
import re
import subprocess
import time
@ -13,7 +14,7 @@ import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
from .runtime import Runtime, RuntimeCache, get_symbol
runtime_cache = RuntimeCache()
@ -26,22 +27,22 @@ def hash_to_hex(s: str) -> str:
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
return f'{os.path.dirname(os.path.abspath(__file__))}/../include'
return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
# Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm'
include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f:
with open(os.path.join(include_dir, filename), 'rb') as f:
md5.update(f.read())
# Update `interleave_ffma.py`
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f:
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
md5.update(f.read())
return md5.hexdigest()[0:12]
@ -51,15 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]:
paths = []
if os.getenv('DG_NVCC_COMPILER'):
paths.append(os.getenv('DG_NVCC_COMPILER'))
paths.append(f'{CUDA_HOME}/bin/nvcc')
nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc'
paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin))
# Try to find the first available NVCC compiler
least_version_required = '12.3'
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(
os.popen(f'{path} --version').read())
command = [path, '--version']
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
match = version_pattern.search(result.stdout)
version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@ -73,17 +78,17 @@ def get_default_user_dir():
path = os.getenv('DG_CACHE_DIR')
os.makedirs(path, exist_ok=True)
return path
return os.path.expanduser('~') + '/.deep_gemm'
return os.path.join(os.path.expanduser('~'), '.deep_gemm')
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
return f'{get_default_user_dir()}/tmp'
return os.path.join(get_default_user_dir(), 'tmp')
@functools.lru_cache(maxsize=None)
def get_cache_dir():
return f'{get_default_user_dir()}/cache'
return os.path.join(get_default_user_dir(), 'cache')
def make_tmp_dir():
@ -92,9 +97,11 @@ def make_tmp_dir():
return tmp_dir
def put(path, data, is_binary=False):
def put(path, data):
is_binary = isinstance(data, bytes)
# Write and do POSIX atomic replace
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}'
tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
f.write(data)
os.replace(tmp_file_path, path)
@ -108,7 +115,7 @@ class Compiler(abc.ABC):
@classmethod
@abc.abstractmethod
def compile(cls, name: str, src_path: str, target_path: str):
def compile(cls, name: str, code: str, target_path: str) -> str:
pass
@staticmethod
@ -118,7 +125,7 @@ class Compiler(abc.ABC):
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940']
'--diag-suppress=39,161,174,177,940']
@staticmethod
def include_dirs() -> List[str]:
@ -135,7 +142,7 @@ class Compiler(abc.ABC):
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}'
path = os.path.join(get_cache_dir(), name)
# Check runtime cache or file system hit
global runtime_cache
@ -144,17 +151,13 @@ class Compiler(abc.ABC):
print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin'
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin'
os.makedirs(path, exist_ok=True)
cubin_path = os.path.join(path, 'kernel.cubin')
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
start_time = time.time()
cls.compile(name, src_path, tmp_cubin_path)
kernel_name = cls.compile(name, code, tmp_cubin_path)
end_time = time.time()
elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None):
@ -164,12 +167,16 @@ class Compiler(abc.ABC):
# Interleave FFMA reuse
if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path)
# Store kernel name
put(f'{tmp_cubin_path}.name', kernel_name)
# Atomic replace CU file
# Atomic replace files
os.replace(tmp_cubin_path, cubin_path)
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
# Put cache and return
runtime_cache[path] = Runtime(path)
runtime_cache[path] = Runtime(path, kernel_name)
return runtime_cache[path]
@ -182,29 +189,47 @@ class NvccCompiler(Compiler):
@classmethod
def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3',
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
if platform.system() != 'Windows':
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
else:
cxx_flags = ['/O2', '/std:c++20']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
def compile(cls, name: str, code: str, target_path: str) -> str:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
put(src_path, code)
command = [get_nvcc_compiler()[0],
src_path, '-o', target_path,
*cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}'
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
if os.getenv('DG_JIT_DEBUG', None):
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f'Failed to compile {src_path}'
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
return get_symbol(target_path, 'fp8_gemm_kernel')
class NvrtcCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get actual NVRTC version, use bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor)
@staticmethod
@ -218,19 +243,27 @@ class NvrtcCompiler(Compiler):
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8):
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}']
base_flags += ['--pch']
if os.getenv('DG_JIT_DEBUG', None):
base_flags += ['--pch-verbose=true']
return base_flags
@classmethod
def compile(cls, name: str, src_path: str, target_path: str):
code_bytes = open(src_path, 'rb').read()
def compile(cls, name: str, code: str, target_path: str) -> str:
code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to create program: {res}")
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
kernel_name = kernel_regex.search(code).group(
0).replace('\n', '').replace(' ', '')
res = nvrtc.nvrtcAddNameExpression(
program, bytes(kernel_name, 'utf-8'))[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to add name expression: {res}")
options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0]
@ -249,6 +282,12 @@ class NvrtcCompiler(Compiler):
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to compile program: {compile_res}")
# NVRTC can directly get the lowered name
res, lowered_name = nvrtc.nvrtcGetLoweredName(
program, bytes(kernel_name, 'utf-8'))
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get lowered name: {res}")
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN size: {res}")
@ -258,12 +297,14 @@ class NvrtcCompiler(Compiler):
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN: {res}")
put(target_path, cubin_bytes, is_binary=True)
put(target_path, cubin_bytes)
res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to destroy program: {res}")
return lowered_name.decode('utf-8')
def build(name: str, code: str) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:

View File

@ -1,16 +1,36 @@
import os
import platform
import time
import subprocess
from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME
from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
if CUDA_HOME is None:
raise Exception("CUDA_HOME is not set")
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
for line in result.stdout.splitlines():
if pattern in line:
return line.split()[-1]
return None
class Runtime:
def __init__(self, path: str) -> None:
def __init__(self, path: str, kernel_name: str) -> None:
self.path = path
self.lib = None
self.kernel = None
self.kernel_name = kernel_name
assert self.is_path_valid(self.path)
@ -21,34 +41,24 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cu', 'kernel.cubin']
files = ['kernel.cubin', 'kernel.cubin.name']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN
if self.lib is None:
if self.kernel is None:
start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
raise Exception(f"Failed to get kernel: {res}")
self.kernel = kernel
if self.kernel is not None:
self.lib = lib
else:
@ -95,7 +105,8 @@ class RuntimeCache:
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path)
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
runtime = Runtime(path, kernel_name)
self.cache[path] = runtime
return runtime
return None

View File

@ -22,7 +22,9 @@ def generate(**kwargs: Dict[str, Any]) -> str:
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
namespace deep_gemm {{
using namespace deep_gemm;
#ifndef NVRTC_JIT_COMPILATION
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
@ -41,7 +43,7 @@ __global__ void dummy_kernel() {{
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
}}
#endif
'''
# Debug print