From ba349d9cf84938a9fd22c47afd902a48adbc9a46 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 11:15:19 +0800 Subject: [PATCH] Compatible with CUDA 12.3 --- deep_gemm/jit/runtime.py | 59 +++++++++++++------------------- deep_gemm/jit_kernels/runtime.py | 6 ++-- tests/test_jit.py | 12 +++---- 3 files changed, 30 insertions(+), 47 deletions(-) diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index 79ba6b8..5d17d86 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -1,18 +1,17 @@ import os +import subprocess import time import cuda.bindings.driver as cbd -from typing import List, Optional, Type -import cuda.bindings.driver as cuda +from typing import List, Optional, Type +from torch.utils.cpp_extension import CUDA_HOME class Runtime: - def __init__(self, path: str, kernel_name: str = None, - args: List[str] = None) -> None: + def __init__(self, path: str, args: List[str] = None) -> None: self.path = path self.lib = None self.kernel = None - self.kernel_name = kernel_name self.args = args assert self.is_path_valid(self.path) @@ -34,51 +33,39 @@ class Runtime: def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult: raise NotImplemented - def __call__(self, **kwargs) -> cuda.CUresult: + def __call__(self, **kwargs) -> cbd.CUresult: # Load CUBIN if self.kernel is None: start_time = time.time_ns() - res, lib = cuda.cuLibraryLoadFromFile( - bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to load library: {res}') - res, kernel_count = cuda.cuLibraryGetKernelCount(lib) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to get kernel count: {res}') - - res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to enumerate kernels: {res}') + # Load CUBIN + path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8') + result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0) + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}' - for kernel in kernels: - res, kernel_name = cuda.cuKernelGetName(kernel) - if res != cuda.CUresult.CUDA_SUCCESS: - raise Exception(f'Failed to get kernel name: {res}') - if bytes(self.kernel_name, encoding='utf-8') in kernel_name: - self.kernel = kernel - break + # Extract the kernel name + command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + assert result.returncode == 0 + kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')] + assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' - if self.kernel is not None: - self.lib = lib - else: - raise Exception('Failed to find required kernel') + # Load kernel from the library + result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8')) + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}' end_time = time.time_ns() - elapsed_time = (end_time - start_time) / 1000 + elapsed_time = (end_time - start_time) / 1e6 if int(os.getenv('DG_JIT_DEBUG', 0)): - print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.') + print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.') # noinspection PyArgumentList - return self.launch( - self.kernel, - *[kwargs[arg] for arg in self.args] - ) + return self.launch(self.kernel, *[kwargs[arg] for arg in self.args]) def __del__(self) -> None: if self.lib is not None: - res = cuda.cuLibraryUnload(self.lib)[0] - if res != cuda.CUresult.CUDA_SUCCESS: + res = cbd.cuLibraryUnload(self.lib)[0] + if res != cbd.CUresult.CUDA_SUCCESS: raise Exception(f'Failed to unload library {self.path}: {res}') diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index a5fe16b..5396601 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -142,7 +142,7 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, class FP8GemmRuntime(Runtime): def __init__(self, path: str) -> None: - super().__init__(path, 'fp8_gemm', [ + super().__init__(path, [ 'NUM_TMA_MULTICAST', 'M', 'BLOCK_M', @@ -175,8 +175,7 @@ class FP8GemmRuntime(Runtime): using namespace deep_gemm; -__global__ void dummy_kernel() {{ - auto ptr = reinterpret_cast(&fp8_gemm_kernel< +auto ptr = reinterpret_cast(&fp8_gemm_kernel< {kwargs['N']}, {kwargs['K']}, {kwargs['BLOCK_M']}, @@ -192,7 +191,6 @@ __global__ void dummy_kernel() {{ {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, GemmType::{kwargs['GEMM_TYPE']} >); -}} ''' if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Generated FP8 GEMM code:\n{code}') diff --git a/tests/test_jit.py b/tests/test_jit.py index 1ba0d16..37b8bc4 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -12,7 +12,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1') class VectorAddRuntime(jit.Runtime): def __init__(self, path: str) -> None: - super().__init__(path, 'vector_add', [ + super().__init__(path, [ 'A', 'B', 'C', @@ -31,17 +31,15 @@ class VectorAddRuntime(jit.Runtime): #include #include -template -__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{ +template +__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ uint32_t i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < N) {{ + if (i < n) {{ c[i] = a[i] + b[i]; }} }} -__global__ void dummy_kernel() {{ - auto ptr = reinterpret_cast(&vector_add<{kwargs['T']}>); -}} +auto ptr = reinterpret_cast(&vector_add); """ # noinspection PyShadowingNames,PyMethodOverriding