mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-05-11 15:00:45 +00:00
Fix 12.9 compatibility
This commit is contained in:
parent
085b4a1532
commit
8702f910e3
@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_JIT_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
|
||||
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
@ -181,7 +180,7 @@ class NVCCCompiler(Compiler):
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
return f'nvcc+{cls.__version__()}'
|
||||
return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
|
@ -48,7 +48,8 @@ class Runtime:
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
|
||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
||||
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
|
||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||
|
||||
# Load kernel from the library
|
||||
|
@ -175,22 +175,24 @@ class FP8GemmRuntime(Runtime):
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>);
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>);
|
||||
}};
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 GEMM code:\n{code}')
|
||||
|
@ -39,7 +39,9 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
|
||||
}}
|
||||
}}
|
||||
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<float>);
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
|
Loading…
Reference in New Issue
Block a user