Fix 12.9 compatibility

This commit is contained in:
Chenggang Zhao 2025-05-07 13:23:40 +08:00
parent 085b4a1532
commit 8702f910e3
4 changed files with 24 additions and 20 deletions

View File

@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]:
paths = []
if os.getenv('DG_JIT_NVCC_COMPILER'):
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
# Try to find the first available NVCC compiler
@ -181,7 +180,7 @@ class NVCCCompiler(Compiler):
@classmethod
def signature(cls) -> str:
return f'nvcc+{cls.__version__()}'
return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'
@classmethod
def flags(cls) -> List[str]:

View File

@ -48,7 +48,8 @@ class Runtime:
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
# Load kernel from the library

View File

@ -175,22 +175,24 @@ class FP8GemmRuntime(Runtime):
using namespace deep_gemm;
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>);
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>);
}};
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 GEMM code:\n{code}')

View File

@ -39,7 +39,9 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
}}
}}
auto ptr = reinterpret_cast<void*>(&vector_add<float>);
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
# noinspection PyShadowingNames,PyMethodOverriding