feat: drop support for CUDA<12.3

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Zihua Wu 2025-04-25 18:56:40 -07:00
parent 46762b6903
commit f6198492cb
5 changed files with 46 additions and 85 deletions

View File

@@ -14,7 +14,7 @@ import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME
 from . import interleave_ffma
-from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache
 runtime_cache = RuntimeCache()
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str) -> str:
         pass
     @staticmethod
@@ -132,10 +132,9 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]
     @classmethod
-    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
+    def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
-        include_dirs = cls.include_dirs()
         # Build signature
         enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int(
@@ -158,7 +157,7 @@ class Compiler(abc.ABC):
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
+        cls.compile(name, code, tmp_cubin_path)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):
@@ -169,15 +168,11 @@ class Compiler(abc.ABC):
         if enable_sass_opt:
             interleave_ffma.process(tmp_cubin_path)
-        # Store kernel name
-        put(f'{tmp_cubin_path}.name', kernel_name)
         # Atomic replace files
         os.replace(tmp_cubin_path, cubin_path)
-        os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
         # Put cache and return
-        runtime = runtime_cls(path, kernel_name)
+        runtime = runtime_cls(path)
         runtime_cache[path] = runtime
         return runtime
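
For context, a minimal sketch of the compile-to-a-temporary-path-then-`os.replace` idiom that `Compiler.build` keeps using above. This is not the project's code: `publish_artifact` and its arguments are illustrative, and it assumes the temporary file is created next to the target so the rename stays atomic.

import os
import uuid

def publish_artifact(data: bytes, final_path: str) -> None:
    # Stage the bytes under a unique sibling name, then rename into place;
    # os.replace() is atomic when source and target are on the same filesystem,
    # so readers either see the old artifact or the complete new one.
    tmp_path = f'{final_path}.tmp.{uuid.uuid4()}'
    with open(tmp_path, 'wb') as f:
        f.write(data)
    os.replace(tmp_path, final_path)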
@@ -202,7 +197,7 @@ class NvccCompiler(Compiler):
                 f'--compiler-options={",".join(cxx_flags)}']
     @classmethod
-    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str):
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')
@@ -221,9 +216,6 @@ class NvccCompiler(Compiler):
         assert result.returncode == 0, f'Failed to compile {src_path}'
-        # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, kernel_name_pattern)
 class NvrtcCompiler(Compiler):
     @staticmethod
@@ -238,7 +230,7 @@ class NvrtcCompiler(Compiler):
     def include_dirs() -> List[str]:
         if CUDA_HOME is None:
             raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
-        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')]
+        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]
     @classmethod
     def flags(cls) -> List[str]:
@@ -251,20 +243,12 @@ class NvrtcCompiler(Compiler):
         return base_flags
     @classmethod
-    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to create program: {res}")
-        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
-        kernel_name = kernel_regex.search(code).group(
-            0).replace('\n', '').replace(' ', '')
-        res = nvrtc.nvrtcAddNameExpression(
-            program, bytes(kernel_name, 'utf-8'))[0]
-        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to add name expression: {res}")
+            raise Exception(f'Failed to create program: {res}')
         options = [bytes(flag, 'utf-8') for flag in cls.flags()]
         compile_res = nvrtc.nvrtcCompileProgram(
@@ -273,43 +257,35 @@ class NvrtcCompiler(Compiler):
         if os.getenv('DG_JIT_DEBUG', None):
             res, log_size = nvrtc.nvrtcGetProgramLogSize(program)
             if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-                raise Exception(f"Failed to get program log size: {res}")
+                raise Exception(f'Failed to get program log size: {res}')
             log_bytes = bytes(log_size)
             res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
             if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-                raise Exception(f"Failed to get program log: {res}")
+                raise Exception(f'Failed to get program log: {res}')
             log_str = log_bytes.decode('utf-8')
             print(log_str)
         if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to compile program: {compile_res}")
-        # NVRTC can directly get the lowered name
-        res, lowered_name = nvrtc.nvrtcGetLoweredName(
-            program, bytes(kernel_name, 'utf-8'))
-        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to get lowered name: {res}")
+            raise Exception(f'Failed to compile program: {compile_res}')
         res, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to get CUBIN size: {res}")
+            raise Exception(f'Failed to get CUBIN size: {res}')
         cubin_bytes = bytes(cubin_size)
         res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to get CUBIN: {res}")
+            raise Exception(f'Failed to get CUBIN: {res}')
         put(target_path, cubin_bytes)
         res = nvrtc.nvrtcDestroyProgram(program)[0]
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
-            raise Exception(f"Failed to destroy program: {res}")
-        return lowered_name.decode('utf-8')
+            raise Exception(f'Failed to destroy program: {res}')
-def build(name: str, code: str) -> Runtime:
+def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
+        return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls)
     else:
-        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
+        return NvccCompiler.build(name, code, runtime_cls=runtime_cls)
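
Below is a minimal, self-contained sketch of the NVRTC compile-to-cubin flow that `NvrtcCompiler.compile` implements above. It is a sketch, not the project's code: `compile_to_cubin` and the `-arch=sm_90a` option are illustrative assumptions (NVRTC needs a concrete sm_* architecture to emit a cubin rather than PTX), and the `nvrtcCompileProgram(prog, num_options, options)` argument order follows the standard cuda-python binding.

from typing import List

import cuda.bindings.nvrtc as nvrtc

def compile_to_cubin(name: str, code: str, options: List[str]) -> bytes:
    # Create the program from in-memory CUDA source
    res, prog = nvrtc.nvrtcCreateProgram(bytes(code, 'utf-8'), bytes(name, 'utf-8'), 0, [], [])
    if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise Exception(f'Failed to create program: {res}')
    # Compile, surfacing the build log on failure
    opts = [bytes(o, 'utf-8') for o in options]
    if nvrtc.nvrtcCompileProgram(prog, len(opts), opts)[0] != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        _, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
        log_bytes = bytes(log_size)
        nvrtc.nvrtcGetProgramLog(prog, log_bytes)
        raise Exception(f'Failed to compile program:\n{log_bytes.decode("utf-8")}')
    # Extract the cubin and clean up
    res, cubin_size = nvrtc.nvrtcGetCUBINSize(prog)
    if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise Exception(f'Failed to get CUBIN size: {res}')
    cubin = bytes(cubin_size)
    if nvrtc.nvrtcGetCUBIN(prog, cubin)[0] != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        raise Exception('Failed to get CUBIN')
    nvrtc.nvrtcDestroyProgram(prog)
    return cubin

# Hypothetical usage: compile a trivial kernel for a Hopper target
# cubin = compile_to_cubin('demo', '__global__ void k() {}', ['-arch=sm_90a'])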

View File

@@ -1,31 +1,12 @@
 import os
-import platform
 import time
-import subprocess
 from typing import Any, Callable, Dict, List, Optional, Type
 import cuda.bindings.driver as cuda
-from torch.utils.cpp_extension import CUDA_HOME
 from .utils import run_gemm
-def get_symbol(file_path: str, pattern: str) -> Optional[str]:
-    if CUDA_HOME is None:
-        raise Exception("CUDA_HOME is not set")
-    cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
-               '-symbols', file_path]
-    result = subprocess.run(command, stdout=subprocess.PIPE,
-                            stderr=subprocess.PIPE, text=True)
-    assert result.returncode == 0
-    for line in result.stdout.splitlines():
-        if pattern in line:
-            return line.split()[-1]
-    return None
 class Runtime:
     def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
@@ -43,7 +24,7 @@ class Runtime:
             return False
         # Contains all necessary files
-        files = ['kernel.cubin', 'kernel.cubin.name']
+        files = ['kernel.cubin']
         return all(os.path.exists(os.path.join(path, file)) for file in files)
     def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
@@ -53,18 +34,28 @@ class Runtime:
         res, lib = cuda.cuLibraryLoadFromFile(
             bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
         if res != cuda.CUresult.CUDA_SUCCESS:
-            raise Exception(f"Failed to load library: {res}")
-        res, kernel = cuda.cuLibraryGetKernel(
-            lib, bytes(self.kernel_name, encoding='utf-8'))
-        if res != cuda.CUresult.CUDA_SUCCESS:
-            raise Exception(f"Failed to get kernel: {res}")
-        self.kernel = kernel
+            raise Exception(f'Failed to load library: {res}')
+        res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
+        if res != cuda.CUresult.CUDA_SUCCESS:
+            raise Exception(f'Failed to get kernel count: {res}')
+        res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
+        if res != cuda.CUresult.CUDA_SUCCESS:
+            raise Exception(f'Failed to enumerate kernels: {res}')
+        for kernel in kernels:
+            res, kernel_name = cuda.cuKernelGetName(kernel)
+            if res != cuda.CUresult.CUDA_SUCCESS:
+                raise Exception(f'Failed to get kernel name: {res}')
+            if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
+                self.kernel = kernel
+                break
         if self.kernel is not None:
             self.lib = lib
         else:
-            raise Exception("Failed to find kernel")
+            raise Exception('Failed to find required kernel')
         end_time = time.time_ns()
         elapsed_time = (end_time - start_time) / 1000
@@ -81,12 +72,12 @@ class Runtime:
         if self.lib is not None:
             res = cuda.cuLibraryUnload(self.lib)[0]
             if res != cuda.CUresult.CUDA_SUCCESS:
-                raise Exception(f"Failed to unload library {self.path}: {res}")
+                raise Exception(f'Failed to unload library {self.path}: {res}')
 class Fp8GemmRuntime(Runtime):
-    def __init__(self, path: str, kernel_name: str) -> None:
-        super().__init__(path, kernel_name, run_gemm, [
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'fp8_gemm', run_gemm, [
             'NUM_TMA_MULTICAST',
             'M',
             'BLOCK_M',
@@ -117,9 +108,7 @@ class RuntimeCache:
         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(
-                path, 'kernel.cubin.name'), 'r').read()
-            runtime = runtime_cls(path, kernel_name)
+            runtime = runtime_cls(path)
             self.cache[path] = runtime
             return runtime
         return None
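
For reference, a minimal standalone sketch of the enumeration-based kernel lookup that `Runtime.__call__` now performs (the `find_kernel` helper and its arguments are illustrative, not part of this commit). It assumes a current CUDA context already exists, for example one created by PyTorch, and uses the same library-management driver calls shown in the hunk above.

import cuda.bindings.driver as cuda

def find_kernel(cubin_path: str, name_fragment: str) -> cuda.CUkernel:
    # Load the cubin as a library, then walk its kernels and match on the mangled name
    res, lib = cuda.cuLibraryLoadFromFile(bytes(cubin_path, 'utf-8'), [], [], 0, [], [], 0)
    if res != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to load library: {res}')
    res, count = cuda.cuLibraryGetKernelCount(lib)
    if res != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to get kernel count: {res}')
    res, kernels = cuda.cuLibraryEnumerateKernels(count, lib)
    if res != cuda.CUresult.CUDA_SUCCESS:
        raise Exception(f'Failed to enumerate kernels: {res}')
    for kernel in kernels:
        res, mangled = cuda.cuKernelGetName(kernel)
        if res == cuda.CUresult.CUDA_SUCCESS and bytes(name_fragment, 'utf-8') in mangled:
            return kernel
    raise Exception(f'No kernel matching {name_fragment!r} in {cubin_path}')

Substring matching is enough here because the mangled name of an instantiated template kernel still contains the plain kernel identifier (e.g. fp8_gemm or vector_add).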

View File

@@ -24,7 +24,6 @@ def generate(**kwargs: Dict[str, Any]) -> str:
 using namespace deep_gemm;
-#ifndef NVRTC_JIT_COMPILATION
 __global__ void dummy_kernel() {{
     void *ptr = (void *)&fp8_gemm_kernel<
         {kwargs['N']},
@@ -43,7 +42,6 @@ __global__ void dummy_kernel() {{
         GemmType::{kwargs['GEMM_TYPE']}
     >;
 }}
-#endif
 '''
     # Debug print

View File

@@ -53,7 +53,7 @@ def get_num_math_warpgroups(block_m: int) -> int:
     return 1 if block_m == 64 else 2
 def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
-    assert num_math_threads_per_group == 128, "Only support 128 threads per math group"
+    assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
     return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
@@ -74,7 +74,7 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuu
     )
     if res != cuda.CUresult.CUDA_SUCCESS:
-        raise Exception(f"Failed to encode tensor map: {res}")
+        raise Exception(f'Failed to encode tensor map: {res}')
     return tensor_map
@@ -118,7 +118,7 @@ def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_
     res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0]
     if res != cuda.CUresult.CUDA_SUCCESS:
-        raise Exception(f"Failed to set max dynamic shared memory size: {res}")
+        raise Exception(f'Failed to set max dynamic shared memory size: {res}')
     attr_val = cuda.CUlaunchAttributeValue()
     attr_val.clusterDim.x = num_tma_multicast

View File

@@ -62,17 +62,15 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
     }}
 }}
-#ifndef NVRTC_JIT_COMPILATION
 __global__ void dummy_kernel() {{
     void *ptr = (void *)&vector_add<{kwargs['T']}>;
 }}
-#endif
 """
 class VectorAddRuntime(jit.Runtime):
-    def __init__(self, path: str, kernel_name: str) -> None:
-        super().__init__(path, kernel_name, run_vector_add, [
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'vector_add', run_vector_add, [
             'A',
             'B',
             'C',
@@ -87,7 +85,7 @@ if __name__ == '__main__':
     code = generate_vector_add(T='float')
     print(code)
     print('Building ...')
-    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
+    func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
     a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
     b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
@@ -105,7 +103,7 @@ if __name__ == '__main__':
     code = generate_vector_add(T='__nv_bfloat16')
     print(code)
     print('Building ...')
-    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+    func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
     a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
     b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')