From f6198492cb0c21ae3b75e0e324e230cfdc37bc37 Mon Sep 17 00:00:00 2001 From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> Date: Fri, 25 Apr 2025 18:56:40 -0700 Subject: [PATCH] feat: drop support for CUDA<12.3 Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com> --- deep_gemm/jit/compiler.py | 60 ++++++++++++--------------------------- deep_gemm/jit/runtime.py | 53 ++++++++++++++-------------------- deep_gemm/jit/template.py | 2 -- deep_gemm/jit/utils.py | 6 ++-- tests/test_jit.py | 10 +++---- 5 files changed, 46 insertions(+), 85 deletions(-) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 549be37..36a3361 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -14,7 +14,7 @@ import cuda.bindings.nvrtc as nvrtc from torch.utils.cpp_extension import CUDA_HOME from . import interleave_ffma -from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol +from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache runtime_cache = RuntimeCache() @@ -115,7 +115,7 @@ class Compiler(abc.ABC): @classmethod @abc.abstractmethod - def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str: + def compile(cls, name: str, code: str, target_path: str) -> str: pass @staticmethod @@ -132,10 +132,9 @@ class Compiler(abc.ABC): return [get_jit_include_dir()] @classmethod - def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: + def build(cls, name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: # Compiler flags flags = cls.flags() - include_dirs = cls.include_dirs() # Build signature enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and int( @@ -158,7 +157,7 @@ class Compiler(abc.ABC): tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin') start_time = time.time() - kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern) + cls.compile(name, code, tmp_cubin_path) end_time = time.time() elapsed_time = end_time - start_time if os.getenv('DG_JIT_DEBUG', None): @@ -169,15 +168,11 @@ class Compiler(abc.ABC): if enable_sass_opt: interleave_ffma.process(tmp_cubin_path) - # Store kernel name - put(f'{tmp_cubin_path}.name', kernel_name) - # Atomic replace files os.replace(tmp_cubin_path, cubin_path) - os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name') # Put cache and return - runtime = runtime_cls(path, kernel_name) + runtime = runtime_cls(path) runtime_cache[path] = runtime return runtime @@ -202,7 +197,7 @@ class NvccCompiler(Compiler): f'--compiler-options={",".join(cxx_flags)}'] @classmethod - def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str: + def compile(cls, name: str, code: str, target_path: str): # Write the code path = os.path.join(get_cache_dir(), name) src_path = os.path.join(path, 'kernel.cu') @@ -221,9 +216,6 @@ class NvccCompiler(Compiler): assert result.returncode == 0, f'Failed to compile {src_path}' - # NVCC needs to get the symbol name from the cubin file using `cuobjdump` - return get_symbol(target_path, kernel_name_pattern) - class NvrtcCompiler(Compiler): @staticmethod @@ -238,7 +230,7 @@ class NvrtcCompiler(Compiler): def include_dirs() -> List[str]: if CUDA_HOME is None: raise RuntimeError('CUDA_HOME is required for NVRTC compilation') - return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include'), os.path.join(CUDA_HOME, 'targets', 'x86_64-linux', 'include')] + return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')] @classmethod def flags(cls) -> List[str]: @@ -251,20 +243,12 @@ class NvrtcCompiler(Compiler): return base_flags @classmethod - def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str: + def compile(cls, name: str, code: str, target_path: str) -> str: code_bytes = bytes(code, 'utf-8') res, program = nvrtc.nvrtcCreateProgram( code_bytes, bytes(name, 'utf-8'), 0, [], []) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to create program: {res}") - - kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE) - kernel_name = kernel_regex.search(code).group( - 0).replace('\n', '').replace(' ', '') - res = nvrtc.nvrtcAddNameExpression( - program, bytes(kernel_name, 'utf-8'))[0] - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to add name expression: {res}") + raise Exception(f'Failed to create program: {res}') options = [bytes(flag, 'utf-8') for flag in cls.flags()] compile_res = nvrtc.nvrtcCompileProgram( @@ -273,43 +257,35 @@ class NvrtcCompiler(Compiler): if os.getenv('DG_JIT_DEBUG', None): res, log_size = nvrtc.nvrtcGetProgramLogSize(program) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to get program log size: {res}") + raise Exception(f'Failed to get program log size: {res}') log_bytes = bytes(log_size) res = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0] if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to get program log: {res}") + raise Exception(f'Failed to get program log: {res}') log_str = log_bytes.decode('utf-8') print(log_str) if compile_res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to compile program: {compile_res}") - - # NVRTC can directly get the lowered name - res, lowered_name = nvrtc.nvrtcGetLoweredName( - program, bytes(kernel_name, 'utf-8')) - if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to get lowered name: {res}") + raise Exception(f'Failed to compile program: {compile_res}') res, cubin_size = nvrtc.nvrtcGetCUBINSize(program) if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to get CUBIN size: {res}") + raise Exception(f'Failed to get CUBIN size: {res}') cubin_bytes = bytes(cubin_size) res = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0] if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to get CUBIN: {res}") + raise Exception(f'Failed to get CUBIN: {res}') put(target_path, cubin_bytes) res = nvrtc.nvrtcDestroyProgram(program)[0] if res != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise Exception(f"Failed to destroy program: {res}") - - return lowered_name.decode('utf-8') + raise Exception(f'Failed to destroy program: {res}') -def build(name: str, code: str) -> Runtime: +def build(name: str, code: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime: if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']: - return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>') + return NvrtcCompiler.build(name, code, runtime_cls=runtime_cls) else: - return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel') + return NvccCompiler.build(name, code, runtime_cls=runtime_cls) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 2588dae..e5f7bfb 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,31 +1,12 @@ import os -import platform import time -import subprocess from typing import Any, Callable, Dict, List, Optional, Type import cuda.bindings.driver as cuda -from torch.utils.cpp_extension import CUDA_HOME from .utils import run_gemm -def get_symbol(file_path: str, pattern: str) -> Optional[str]: - if CUDA_HOME is None: - raise Exception("CUDA_HOME is not set") - - cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump' - command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), - '-symbols', file_path] - result = subprocess.run(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, text=True) - assert result.returncode == 0 - for line in result.stdout.splitlines(): - if pattern in line: - return line.split()[-1] - return None - - class Runtime: def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None: self.path = path @@ -43,7 +24,7 @@ class Runtime: return False # Contains all necessary files - files = ['kernel.cubin', 'kernel.cubin.name'] + files = ['kernel.cubin'] return all(os.path.exists(os.path.join(path, file)) for file in files) def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult: @@ -53,18 +34,28 @@ class Runtime: res, lib = cuda.cuLibraryLoadFromFile( bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to load library: {res}") + raise Exception(f'Failed to load library: {res}') - res, kernel = cuda.cuLibraryGetKernel( - lib, bytes(self.kernel_name, encoding='utf-8')) + res, kernel_count = cuda.cuLibraryGetKernelCount(lib) if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to get kernel: {res}") + raise Exception(f'Failed to get kernel count: {res}') + + res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to enumerate kernels: {res}') + + for kernel in kernels: + res, kernel_name = cuda.cuKernelGetName(kernel) + if res != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f'Failed to get kernel name: {res}') + if bytes(self.kernel_name, encoding='utf-8') in kernel_name: + self.kernel = kernel + break - self.kernel = kernel if self.kernel is not None: self.lib = lib else: - raise Exception("Failed to find kernel") + raise Exception('Failed to find required kernel') end_time = time.time_ns() elapsed_time = (end_time - start_time) / 1000 @@ -81,12 +72,12 @@ class Runtime: if self.lib is not None: res = cuda.cuLibraryUnload(self.lib)[0] if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to unload library {self.path}: {res}") + raise Exception(f'Failed to unload library {self.path}: {res}') class Fp8GemmRuntime(Runtime): - def __init__(self, path: str, kernel_name: str) -> None: - super().__init__(path, kernel_name, run_gemm, [ + def __init__(self, path: str) -> None: + super().__init__(path, 'fp8_gemm', run_gemm, [ 'NUM_TMA_MULTICAST', 'M', 'BLOCK_M', @@ -117,9 +108,7 @@ class RuntimeCache: # Already compiled if os.path.exists(path) and Runtime.is_path_valid(path): - kernel_name = open(os.path.join( - path, 'kernel.cubin.name'), 'r').read() - runtime = runtime_cls(path, kernel_name) + runtime = runtime_cls(path) self.cache[path] = runtime return runtime return None \ No newline at end of file diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index 6fd1322..461691f 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -24,7 +24,6 @@ def generate(**kwargs: Dict[str, Any]) -> str: using namespace deep_gemm; -#ifndef NVRTC_JIT_COMPILATION __global__ void dummy_kernel() {{ void *ptr = (void *)&fp8_gemm_kernel< {kwargs['N']}, @@ -43,7 +42,6 @@ __global__ void dummy_kernel() {{ GemmType::{kwargs['GEMM_TYPE']} >; }} -#endif ''' # Debug print diff --git a/deep_gemm/jit/utils.py b/deep_gemm/jit/utils.py index f1feefa..1321f24 100644 --- a/deep_gemm/jit/utils.py +++ b/deep_gemm/jit/utils.py @@ -53,7 +53,7 @@ def get_num_math_warpgroups(block_m: int) -> int: return 1 if block_m == 64 else 2 def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int: - assert num_math_threads_per_group == 128, "Only support 128 threads per math group" + assert num_math_threads_per_group == 128, 'Only support 128 threads per math group' return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads @@ -74,7 +74,7 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor, gmem_dim: Tuple[cuda.cuu ) if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to encode tensor map: {res}") + raise Exception(f'Failed to encode tensor map: {res}') return tensor_map @@ -118,7 +118,7 @@ def run_gemm(kernel: cuda.CUkernel, num_tma_multicast: int, shape_m: int, block_ res = cuda.cuKernelSetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cuda.CUdevice(gmem_d.device.index))[0] if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f"Failed to set max dynamic shared memory size: {res}") + raise Exception(f'Failed to set max dynamic shared memory size: {res}') attr_val = cuda.CUlaunchAttributeValue() attr_val.clusterDim.x = num_tma_multicast diff --git a/tests/test_jit.py b/tests/test_jit.py index 00bea1a..cced2d6 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -62,17 +62,15 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{ }} }} -#ifndef NVRTC_JIT_COMPILATION __global__ void dummy_kernel() {{ void *ptr = (void *)&vector_add<{kwargs['T']}>; }} -#endif """ class VectorAddRuntime(jit.Runtime): - def __init__(self, path: str, kernel_name: str) -> None: - super().__init__(path, kernel_name, run_vector_add, [ + def __init__(self, path: str) -> None: + super().__init__(path, 'vector_add', run_vector_add, [ 'A', 'B', 'C', @@ -87,7 +85,7 @@ if __name__ == '__main__': code = generate_vector_add(T='float') print(code) print('Building ...') - func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime) + func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime) a = torch.randn((1024, ), dtype=torch.float32, device='cuda') b = torch.randn((1024, ), dtype=torch.float32, device='cuda') @@ -105,7 +103,7 @@ if __name__ == '__main__': code = generate_vector_add(T='__nv_bfloat16') print(code) print('Building ...') - func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime) + func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime) a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda') b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')