From 93ea4797c071aa61a42053b78d554e8031dd5928 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 24 Jun 2025 00:30:35 +0000 Subject: [PATCH] Add initial support for Nvidia Blackwell (SM120) This change introduces the necessary compiler flags and CMake configurations to enable support for the Nvidia Blackwell SM120 architecture. - Modified deep_gemm/jit/compiler.py to include sm_120 and compute_120 flags for NVCC and NVRTC. - Updated CMakeLists.txt to add the new architecture flags for the build process. Further testing on Blackwell hardware is required to validate MMA instruction compatibility and overall performance. --- CMakeLists.txt | 7 +++++-- deep_gemm/jit/compiler.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 658aa7b..621acb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ find_package(pybind11 REQUIRED) file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }") execute_process( - COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu + COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -gencode arch=compute_120,code=sm_120 -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu RESULT_VARIABLE NVCC_RESULT OUTPUT_VARIABLE NVCC_OUTPUT ERROR_VARIABLE NVCC_ERROR_OUTPUT @@ -27,8 +27,11 @@ else() endif() if (NVCC_SUPPORTS_SM90) - set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE) + set(TORCH_CUDA_ARCH_LIST "8.6;9.0a" CACHE STRING "Add arch tag 90a to NVCC" FORCE) # TODO: Check if 9.0a is correct for sm_90a, it might be just 9.0 list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a") + # Add Blackwell support if NVCC supports it (determined by the test_cuda.cu compilation) + list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_120,code=sm_120") + set(TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST};12.0" CACHE STRING "Add arch tag 120 to NVCC" FORCE) # TODO: Check if 12.0 is correct for sm_120 endif() find_package(Torch REQUIRED) diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index d3f1f76..2cd8346 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -187,6 +187,7 @@ class NVCCCompiler(Compiler): cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi'] return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], '-gencode=arch=compute_90a,code=sm_90a', + '-gencode=arch=compute_120,code=sm_120', '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda', f'--compiler-options={",".join(cxx_flags)}'] @@ -230,7 +231,7 @@ class NVRTCCompiler(Compiler): @classmethod def flags(cls) -> List[str]: flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '--gpu-architecture=sm_90a', '-default-device'] + '--gpu-architecture=compute_90a', '--gpu-architecture=compute_120', '-default-device'] # NOTES: PCH is vital for compilation speed if cls.__version__() >= (12, 8): flags += ['--pch']