mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Fix 12.9 compatibility
This commit is contained in:
@@ -50,7 +50,6 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_JIT_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
|
||||
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
@@ -181,7 +180,7 @@ class NVCCCompiler(Compiler):
|
||||
|
||||
@classmethod
|
||||
def signature(cls) -> str:
|
||||
- return f'nvcc+{cls.__version__()}'
|
||||
+ return f'{get_nvcc_compiler()[0]}+{cls.__version__()}'
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
|
||||
@@ -48,7 +48,8 @@ class Runtime:
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
- kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
|
||||
+ kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
||||
+ if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
|
||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||
|
||||
# Load kernel from the library
|
||||
|
||||
Reference in New Issue
Block a user