From 8702f910e37cb95c7e822b49b26f3b27b8f79b44 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 13:23:40 +0800 Subject: [PATCH] Fix 12.9 compatibility --- deep_gemm/jit/compiler.py | 3 +-- deep_gemm/jit/runtime.py | 3 ++- deep_gemm/jit_kernels/runtime.py | 34 +++++++++++++++++--------------- tests/test_jit.py | 4 +++- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 559a2f6..2ab6b25 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]: paths = [] if os.getenv('DG_JIT_NVCC_COMPILER'): paths.append(os.getenv('DG_JIT_NVCC_COMPILER')) - paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc')) # Try to find the first available NVCC compiler @@ -181,7 +180,7 @@ class NVCCCompiler(Compiler): @classmethod def signature(cls) -> str: - return f'nvcc+{cls.__version__()}' + return f'{get_nvcc_compiler()[0]}+{cls.__version__()}' @classmethod def flags(cls) -> List[str]: diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py index b7c2f95..74ceff5 100644 --- a/deep_gemm/jit/runtime.py +++ b/deep_gemm/jit/runtime.py @@ -48,7 +48,8 @@ class Runtime: command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) assert result.returncode == 0 - kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')] + kernel_names = [line.split()[-1] for line in result.stdout.splitlines() + if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line] assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}' # Load kernel from the library diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 5396601..fa0a61d 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -175,22 +175,24 @@ class 
FP8GemmRuntime(Runtime): using namespace deep_gemm; -auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel< - {kwargs['N']}, - {kwargs['K']}, - {kwargs['BLOCK_M']}, - {kwargs['BLOCK_N']}, - {kwargs['BLOCK_K']}, - {kwargs['BLOCK_N_PADDING']}, - {kwargs['SWIZZLE_D_MODE']}, - {kwargs['NUM_GROUPS']}, - {kwargs['NUM_STAGES']}, - {kwargs['NUM_TMA_THREADS']}, - {kwargs['NUM_MATH_THREADS_PER_GROUP']}, - {kwargs['NUM_TMA_MULTICAST']}, - {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, - GemmType::{kwargs['GEMM_TYPE']} - >); +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['BLOCK_N_PADDING']}, + {kwargs['SWIZZLE_D_MODE']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, + GemmType::{kwargs['GEMM_TYPE']} + >); +}}; ''' if int(os.getenv('DG_JIT_DEBUG', 0)): print(f'Generated FP8 GEMM code:\n{code}') diff --git a/tests/test_jit.py b/tests/test_jit.py index 37b8bc4..fbd84e1 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -39,7 +39,9 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{ }} }} -auto ptr = reinterpret_cast<void*>(&vector_add); +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>); +}} """ # noinspection PyShadowingNames,PyMethodOverriding