Merge pull request #1 from lucifer1004/nvrtc-compat

feat: add compat for older drivers and Windows
This commit is contained in:
Gabriel Wu
2025-04-23 15:06:10 +08:00
committed by GitHub
3 changed files with 110 additions and 56 deletions

View File

@@ -2,6 +2,7 @@ import abc
import functools import functools
import hashlib import hashlib
import os import os
import platform
import re import re
import subprocess import subprocess
import time import time
@@ -13,7 +14,7 @@ import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma from . import interleave_ffma
from .runtime import Runtime, RuntimeCache from .runtime import Runtime, RuntimeCache, get_symbol
runtime_cache = RuntimeCache() runtime_cache = RuntimeCache()
@@ -26,22 +27,22 @@ def hash_to_hex(s: str) -> str:
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str: def get_jit_include_dir() -> str:
return f'{os.path.dirname(os.path.abspath(__file__))}/../include' return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'include')
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str: def get_deep_gemm_version() -> str:
# Update include directories # Update include directories
include_dir = f'{get_jit_include_dir()}/deep_gemm' include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
assert os.path.exists( assert os.path.exists(
include_dir), f'Cannot find GEMM include directory {include_dir}' include_dir), f'Cannot find GEMM include directory {include_dir}'
md5 = hashlib.md5() md5 = hashlib.md5()
for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))): for filename in filter(lambda x: x.endswith('.cuh'), sorted(os.listdir(include_dir))):
with open(f'{include_dir}/{filename}', 'rb') as f: with open(os.path.join(include_dir, filename), 'rb') as f:
md5.update(f.read()) md5.update(f.read())
# Update `interleave_ffma.py` # Update `interleave_ffma.py`
with open(f'{os.path.dirname(os.path.realpath(__file__))}/interleave_ffma.py', 'rb') as f: with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'interleave_ffma.py'), 'rb') as f:
md5.update(f.read()) md5.update(f.read())
return md5.hexdigest()[0:12] return md5.hexdigest()[0:12]
@@ -51,15 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]:
paths = [] paths = []
if os.getenv('DG_NVCC_COMPILER'): if os.getenv('DG_NVCC_COMPILER'):
paths.append(os.getenv('DG_NVCC_COMPILER')) paths.append(os.getenv('DG_NVCC_COMPILER'))
paths.append(f'{CUDA_HOME}/bin/nvcc')
nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc'
paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin))
# Try to find the first available NVCC compiler # Try to find the first available NVCC compiler
least_version_required = '12.3' least_version_required = '12.3'
version_pattern = re.compile(r'release (\d+\.\d+)') version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths: for path in paths:
if os.path.exists(path): if os.path.exists(path):
match = version_pattern.search( command = [path, '--version']
os.popen(f'{path} --version').read()) result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
match = version_pattern.search(result.stdout)
version = match.group(1) version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}' assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}' assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@@ -73,17 +78,17 @@ def get_default_user_dir():
path = os.getenv('DG_CACHE_DIR') path = os.getenv('DG_CACHE_DIR')
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
return os.path.expanduser('~') + '/.deep_gemm' return os.path.join(os.path.expanduser('~'), '.deep_gemm')
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_tmp_dir(): def get_tmp_dir():
return f'{get_default_user_dir()}/tmp' return os.path.join(get_default_user_dir(), 'tmp')
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def get_cache_dir(): def get_cache_dir():
return f'{get_default_user_dir()}/cache' return os.path.join(get_default_user_dir(), 'cache')
def make_tmp_dir(): def make_tmp_dir():
@@ -92,9 +97,11 @@ def make_tmp_dir():
return tmp_dir return tmp_dir
def put(path, data, is_binary=False): def put(path, data):
is_binary = isinstance(data, bytes)
# Write and do POSIX atomic replace # Write and do POSIX atomic replace
tmp_file_path = f'{make_tmp_dir()}/file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}' tmp_file_path = os.path.join(make_tmp_dir(), f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
with open(tmp_file_path, 'wb' if is_binary else 'w') as f: with open(tmp_file_path, 'wb' if is_binary else 'w') as f:
f.write(data) f.write(data)
os.replace(tmp_file_path, path) os.replace(tmp_file_path, path)
@@ -108,7 +115,7 @@ class Compiler(abc.ABC):
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
pass pass
@staticmethod @staticmethod
@@ -118,7 +125,7 @@ class Compiler(abc.ABC):
'--ptxas-options=--register-usage-level=10' + '--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
# Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
'--diag-suppress=39,174,177,940'] '--diag-suppress=39,161,174,177,940']
@staticmethod @staticmethod
def include_dirs() -> List[str]: def include_dirs() -> List[str]:
@@ -135,7 +142,7 @@ class Compiler(abc.ABC):
os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0 os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) == 0
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
name = f'kernel.{name}.{hash_to_hex(signature)}' name = f'kernel.{name}.{hash_to_hex(signature)}'
path = f'{get_cache_dir()}/{name}' path = os.path.join(get_cache_dir(), name)
# Check runtime cache or file system hit # Check runtime cache or file system hit
global runtime_cache global runtime_cache
@@ -144,17 +151,13 @@ class Compiler(abc.ABC):
print(f'Using cached JIT runtime {name} during build') print(f'Using cached JIT runtime {name} during build')
return runtime_cache[path] return runtime_cache[path]
# Write the code
os.makedirs(path, exist_ok=True)
src_path = f'{path}/kernel.cu'
put(src_path, code)
# Compile into a temporary CU file # Compile into a temporary CU file
cubin_path = f'{path}/kernel.cubin' os.makedirs(path, exist_ok=True)
tmp_cubin_path = f'{make_tmp_dir()}/nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin' cubin_path = os.path.join(path, 'kernel.cubin')
tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
start_time = time.time() start_time = time.time()
cls.compile(name, src_path, tmp_cubin_path) kernel_name = cls.compile(name, code, tmp_cubin_path)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
if os.getenv('DG_JIT_DEBUG', None): if os.getenv('DG_JIT_DEBUG', None):
@@ -165,11 +168,15 @@ class Compiler(abc.ABC):
if enable_sass_opt: if enable_sass_opt:
interleave_ffma.process(tmp_cubin_path) interleave_ffma.process(tmp_cubin_path)
# Atomic replace CU file # Store kernel name
put(f'{tmp_cubin_path}.name', kernel_name)
# Atomic replace files
os.replace(tmp_cubin_path, cubin_path) os.replace(tmp_cubin_path, cubin_path)
os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
# Put cache and return # Put cache and return
runtime_cache[path] = Runtime(path) runtime_cache[path] = Runtime(path, kernel_name)
return runtime_cache[path] return runtime_cache[path]
@@ -182,29 +189,47 @@ class NvccCompiler(Compiler):
@classmethod @classmethod
def flags(cls) -> List[str]: def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3', if platform.system() != 'Windows':
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts'] cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
else:
cxx_flags = ['/O2', '/std:c++20']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a', '-gencode=arch=compute_90a,code=sm_90a',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}'] f'--compiler-options={",".join(cxx_flags)}']
@classmethod @classmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
# Write the code
path = os.path.join(get_cache_dir(), name)
src_path = os.path.join(path, 'kernel.cu')
put(src_path, code)
command = [get_nvcc_compiler()[0], command = [get_nvcc_compiler()[0],
src_path, '-o', target_path, src_path, '-o', target_path,
*cls.flags()] *cls.flags()]
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False): if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}') print(f'Compiling JIT runtime {name} with command {command}')
return_code = subprocess.check_call(command) result = subprocess.run(command, stdout=subprocess.PIPE,
assert return_code == 0, f'Failed to compile {src_path}' stderr=subprocess.PIPE, text=True)
if os.getenv('DG_JIT_DEBUG', None):
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f'Failed to compile {src_path}'
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
return get_symbol(target_path, 'fp8_gemm_kernel')
class NvrtcCompiler(Compiler): class NvrtcCompiler(Compiler):
@staticmethod @staticmethod
def __version__() -> Tuple[int, int]: def __version__() -> Tuple[int, int]:
major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get actual NVRTC version, use bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor) return (major, minor)
@staticmethod @staticmethod
@@ -218,19 +243,27 @@ class NvrtcCompiler(Compiler):
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device'] '--gpu-architecture=sm_90a', '-default-device']
if cls.__version__() >= (12, 8): if cls.__version__() >= (12, 8):
base_flags += ['--pch', f'--pch-dir={get_cache_dir()}'] base_flags += ['--pch']
if os.getenv('DG_JIT_DEBUG', None): if os.getenv('DG_JIT_DEBUG', None):
base_flags += ['--pch-verbose=true'] base_flags += ['--pch-verbose=true']
return base_flags return base_flags
@classmethod @classmethod
def compile(cls, name: str, src_path: str, target_path: str): def compile(cls, name: str, code: str, target_path: str) -> str:
code_bytes = open(src_path, 'rb').read() code_bytes = bytes(code, 'utf-8')
res, program = nvrtc.nvrtcCreateProgram( res, program = nvrtc.nvrtcCreateProgram(
code_bytes, bytes(name, 'utf-8'), 0, [], []) code_bytes, bytes(name, 'utf-8'), 0, [], [])
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to create program: {res}") raise Exception(f"Failed to create program: {res}")
kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
kernel_name = kernel_regex.search(code).group(
0).replace('\n', '').replace(' ', '')
res = nvrtc.nvrtcAddNameExpression(
program, bytes(kernel_name, 'utf-8'))[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to add name expression: {res}")
options = [bytes(flag, 'utf-8') for flag in cls.flags()] options = [bytes(flag, 'utf-8') for flag in cls.flags()]
compile_res = nvrtc.nvrtcCompileProgram( compile_res = nvrtc.nvrtcCompileProgram(
program, len(options), options)[0] program, len(options), options)[0]
@@ -249,6 +282,12 @@ class NvrtcCompiler(Compiler):
if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to compile program: {compile_res}") raise Exception(f"Failed to compile program: {compile_res}")
# NVRTC can directly get the lowered name
res, lowered_name = nvrtc.nvrtcGetLoweredName(
program, bytes(kernel_name, 'utf-8'))
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get lowered name: {res}")
res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN size: {res}") raise Exception(f"Failed to get CUBIN size: {res}")
@@ -258,12 +297,14 @@ class NvrtcCompiler(Compiler):
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to get CUBIN: {res}") raise Exception(f"Failed to get CUBIN: {res}")
put(target_path, cubin_bytes, is_binary=True) put(target_path, cubin_bytes)
res = nvrtc.nvrtcDestroyProgram(program)[0] res = nvrtc.nvrtcDestroyProgram(program)[0]
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
raise Exception(f"Failed to destroy program: {res}") raise Exception(f"Failed to destroy program: {res}")
return lowered_name.decode('utf-8')
def build(name: str, code: str) -> Runtime: def build(name: str, code: str) -> Runtime:
if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:

View File

@@ -1,16 +1,36 @@
import os import os
import platform
import time import time
import subprocess
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import cuda.bindings.driver as cuda import cuda.bindings.driver as cuda
from torch.utils.cpp_extension import CUDA_HOME
from .utils import run_gemm from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
if CUDA_HOME is None:
raise Exception("CUDA_HOME is not set")
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
for line in result.stdout.splitlines():
if pattern in line:
return line.split()[-1]
return None
class Runtime: class Runtime:
def __init__(self, path: str) -> None: def __init__(self, path: str, kernel_name: str) -> None:
self.path = path self.path = path
self.lib = None self.lib = None
self.kernel = None self.kernel = None
self.kernel_name = kernel_name
assert self.is_path_valid(self.path) assert self.is_path_valid(self.path)
@@ -21,34 +41,24 @@ class Runtime:
return False return False
# Contains all necessary files # Contains all necessary files
files = ['kernel.cu', 'kernel.cubin'] files = ['kernel.cubin', 'kernel.cubin.name']
return all(os.path.exists(os.path.join(path, file)) for file in files) return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
# Load CUBIN # Load CUBIN
if self.lib is None: if self.kernel is None:
start_time = time.time_ns() start_time = time.time_ns()
res, lib = cuda.cuLibraryLoadFromFile( res, lib = cuda.cuLibraryLoadFromFile(
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
if res != cuda.CUresult.CUDA_SUCCESS: if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}") raise Exception(f"Failed to load library: {res}")
res, kernel_count = cuda.cuLibraryGetKernelCount(lib) res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
if res != cuda.CUresult.CUDA_SUCCESS: if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel count: {res}") raise Exception(f"Failed to get kernel: {res}")
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to enumerate kernels: {res}")
for kernel in kernels:
res, kernel_name = cuda.cuKernelGetName(kernel)
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to get kernel name: {res}")
if b"fp8" in kernel_name:
self.kernel = kernel
break
self.kernel = kernel
if self.kernel is not None: if self.kernel is not None:
self.lib = lib self.lib = lib
else: else:
@@ -95,7 +105,8 @@ class RuntimeCache:
# Already compiled # Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path): if os.path.exists(path) and Runtime.is_path_valid(path):
runtime = Runtime(path) kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
runtime = Runtime(path, kernel_name)
self.cache[path] = runtime self.cache[path] = runtime
return runtime return runtime
return None return None

View File

@@ -22,7 +22,9 @@ def generate(**kwargs: Dict[str, Any]) -> str:
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh> #include <deep_gemm/fp8_gemm.cuh>
namespace deep_gemm {{ using namespace deep_gemm;
#ifndef NVRTC_JIT_COMPILATION
__global__ void dummy_kernel() {{ __global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel< void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']}, {kwargs['N']},
@@ -41,7 +43,7 @@ __global__ void dummy_kernel() {{
GemmType::{kwargs['GEMM_TYPE']} GemmType::{kwargs['GEMM_TYPE']}
>; >;
}} }}
}} #endif
''' '''
# Debug print # Debug print