Add initial support for Nvidia Blackwell (SM120)

This change introduces the necessary compiler flags and CMake configurations to enable support for the Nvidia Blackwell SM120 architecture.

- Modified deep_gemm/jit/compiler.py to include sm_120 and compute_120 flags for NVCC and NVRTC.
- Updated CMakeLists.txt to add the new architecture flags for the build process.

Further testing on Blackwell hardware is required to validate MMA instruction compatibility and overall performance.
This commit is contained in:
google-labs-jules[bot]
2025-06-24 00:30:35 +00:00
parent e82c4139da
commit 93ea4797c0
2 changed files with 7 additions and 3 deletions

View File

@@ -187,6 +187,7 @@ class NVCCCompiler(Compiler):
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-gencode=arch=compute_120,code=sm_120',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
f'--compiler-options={",".join(cxx_flags)}']
@@ -230,7 +231,7 @@ class NVRTCCompiler(Compiler):
@classmethod
def flags(cls) -> List[str]:
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
'--gpu-architecture=compute_90a', '--gpu-architecture=compute_120', '-default-device']
# NOTES: PCH is vital for compilation speed
if cls.__version__() >= (12, 8):
flags += ['--pch']