Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-12 09:20:33 +00:00)
Fix 12.9 compatibility

commit 8702f910e3 (parent 085b4a1532)
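In summary, the diff touches the JIT layer in three places: the NVCC compiler signature now embeds the resolved nvcc path instead of the literal 'nvcc', the cuobjdump symbol scan skips __instantiate_kernel helper symbols so that exactly one kernel name is found, and the generated FP8 GEMM and vector_add sources wrap their kernel-pointer instantiation in a static void __instantiate_kernel() function (with the test template's element type now taken from kwargs['T']).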
@@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]:
     paths = []
     if os.getenv('DG_JIT_NVCC_COMPILER'):
         paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
     paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
 
     # Try to find the first available NVCC compiler

@@ -181,7 +180,7 @@ class NVCCCompiler(Compiler):
 
     @classmethod
     def signature(cls) -> str:
-        return f'nvcc+{cls.__version__()}'
+        return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'
 
     @classmethod
     def flags(cls) -> List[str]:

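The signature string identifies the compiler configuration (and, presumably, feeds the JIT cache key). Below is a minimal sketch of the scheme after this change, with a simplified stand-in for get_nvcc_compiler() based on the hunk above; the CUDA_HOME default and the 'unknown' version placeholder are assumptions of the sketch, not DeepGEMM's behavior.

import os
from typing import Tuple

# Assumed default for this sketch only; DeepGEMM resolves CUDA_HOME elsewhere.
CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda')


def get_nvcc_compiler() -> Tuple[str, str]:
    # Simplified stand-in: prefer an explicit override, then fall back to CUDA_HOME.
    # The real function also probes the nvcc version; that part is omitted here.
    paths = []
    if os.getenv('DG_JIT_NVCC_COMPILER'):
        paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
    paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
    for path in paths:
        if os.path.exists(path):
            return path, 'unknown'  # version probing omitted in this sketch
    raise RuntimeError('Cannot find any available NVCC compiler')


def signature(version: str) -> str:
    # After the change, the resolved compiler path (not the literal 'nvcc') is
    # baked into the signature, presumably so that builds produced by different
    # toolchain installations do not collide on the same cache entry.
    return f'{get_nvcc_compiler()[0]}+{version}'
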
@@ -48,7 +48,8 @@ class Runtime:
         command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
         result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
         assert result.returncode == 0
-        kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
+        kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
+                        if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
        assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
 
         # Load kernel from the library

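The added clause matters presumably because the generated source now also defines an __instantiate_kernel helper (see the next two hunks), whose symbol could otherwise be picked up by this scan and break the one-kernel assertion. A small illustration of the filter on made-up symbol-table lines, shaped only so that line.split()[-1] yields the symbol name as the parser expects; the exact format is not taken from real cuobjdump output.

# Illustrative only: fabricated symbol-table lines, not real cuobjdump output.
sample_stdout = '\n'.join([
    'STT_FUNC   STB_GLOBAL  _Z15fp8_gemm_kernelILj4096ELj7168E...',
    'STT_FUNC   STB_LOCAL   _ZL20__instantiate_kernelv',
    'STT_OBJECT STB_LOCAL   some_device_constant',
])

# Old filter: both STT_FUNC symbols survive, so the subsequent
# `assert len(kernel_names) == 1` would fail.
old_names = [line.split()[-1] for line in sample_stdout.splitlines()
             if line.startswith('STT_FUNC')]

# New filter: the __instantiate_kernel helper is skipped, leaving only the kernel.
new_names = [line.split()[-1] for line in sample_stdout.splitlines()
             if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]

assert len(old_names) == 2 and len(new_names) == 1
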
@@ -175,22 +175,24 @@ class FP8GemmRuntime(Runtime):
 
 using namespace deep_gemm;
 
-auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
-    {kwargs['N']},
-    {kwargs['K']},
-    {kwargs['BLOCK_M']},
-    {kwargs['BLOCK_N']},
-    {kwargs['BLOCK_K']},
-    {kwargs['BLOCK_N_PADDING']},
-    {kwargs['SWIZZLE_D_MODE']},
-    {kwargs['NUM_GROUPS']},
-    {kwargs['NUM_STAGES']},
-    {kwargs['NUM_TMA_THREADS']},
-    {kwargs['NUM_MATH_THREADS_PER_GROUP']},
-    {kwargs['NUM_TMA_MULTICAST']},
-    {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
-    GemmType::{kwargs['GEMM_TYPE']}
->);
+static void __instantiate_kernel() {{
+    auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
+        {kwargs['N']},
+        {kwargs['K']},
+        {kwargs['BLOCK_M']},
+        {kwargs['BLOCK_N']},
+        {kwargs['BLOCK_K']},
+        {kwargs['BLOCK_N_PADDING']},
+        {kwargs['SWIZZLE_D_MODE']},
+        {kwargs['NUM_GROUPS']},
+        {kwargs['NUM_STAGES']},
+        {kwargs['NUM_TMA_THREADS']},
+        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
+        {kwargs['NUM_TMA_MULTICAST']},
+        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
+        GemmType::{kwargs['GEMM_TYPE']}
+    >);
+}};
 '''
         if int(os.getenv('DG_JIT_DEBUG', 0)):
             print(f'Generated FP8 GEMM code:\n{code}')

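For readers less familiar with the template above: the doubled braces are Python f-string escapes ({{ and }} become literal { and }), while the single-brace fields are substituted from kwargs. A reduced sketch of the rendering step follows, with a made-up kwargs dict and a truncated parameter list.

# Reduced illustration of how the template renders; the parameter list is
# truncated and the values are made up for this sketch.
kwargs = {'N': 4096, 'K': 7168, 'BLOCK_M': 64, 'GEMM_TYPE': 'Normal'}

code = f'''
using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
        {kwargs['N']},
        {kwargs['K']},
        {kwargs['BLOCK_M']},
        GemmType::{kwargs['GEMM_TYPE']}
    >);
}};
'''
print(code)

Taking the kernel's address inside the wrapper is enough to force the template instantiation to be emitted; nothing is launched.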
@@ -39,7 +39,9 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
     }}
 }}
 
-auto ptr = reinterpret_cast<void*>(&vector_add<float>);
+static void __instantiate_kernel() {{
+    auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
+}}
 """
 
 # noinspection PyShadowingNames,PyMethodOverriding

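The vector_add test template gets the same wrapper, and its element type is now taken from kwargs instead of being hardcoded to float. A minimal rendering sketch; the kwargs value here is an assumption for illustration.

# Passing 'float' reproduces the previously hardcoded instantiation;
# another type name (e.g. 'double') would instantiate that specialization instead.
kwargs = {'T': 'float'}

snippet = f"""
static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
print(snippet)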