mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Add initial support for Nvidia Blackwell (SM120)
This change introduces the compiler flags and CMake configuration needed to enable support for the Nvidia Blackwell SM120 architecture:
- Modified deep_gemm/jit/compiler.py to include the sm_120 and compute_120 flags for NVCC and NVRTC.
- Updated CMakeLists.txt to add the new architecture flags to the build process.
Further testing on Blackwell hardware is required to validate MMA instruction compatibility and overall performance.
This commit is contained in:
parent
e82c4139da
commit
93ea4797c0
@ -12,7 +12,7 @@ find_package(pybind11 REQUIRED)
|
||||
|
||||
file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
|
||||
execute_process(
|
||||
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
|
||||
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -gencode arch=compute_120,code=sm_120 -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
|
||||
RESULT_VARIABLE NVCC_RESULT
|
||||
OUTPUT_VARIABLE NVCC_OUTPUT
|
||||
ERROR_VARIABLE NVCC_ERROR_OUTPUT
|
||||
@ -27,8 +27,11 @@ else()
|
||||
endif()
|
||||
|
||||
if (NVCC_SUPPORTS_SM90)
|
||||
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
|
||||
set(TORCH_CUDA_ARCH_LIST "8.6;9.0a" CACHE STRING "Add arch tag 90a to NVCC" FORCE) # TODO: Check if 9.0a is correct for sm_90a, it might be just 9.0
|
||||
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
|
||||
# Add Blackwell support if NVCC supports it (determined by the test_cuda.cu compilation)
|
||||
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_120,code=sm_120")
|
||||
set(TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST};12.0" CACHE STRING "Add arch tag 120 to NVCC" FORCE) # TODO: Check if 12.0 is correct for sm_120
|
||||
endif()
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
|
||||
@ -187,6 +187,7 @@ class NVCCCompiler(Compiler):
|
||||
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-gencode=arch=compute_120,code=sm_120',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
f'--compiler-options={",".join(cxx_flags)}']
|
||||
|
||||
@ -230,7 +231,7 @@ class NVRTCCompiler(Compiler):
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'--gpu-architecture=sm_90a', '-default-device']
|
||||
'--gpu-architecture=compute_90a', '--gpu-architecture=compute_120', '-default-device']
|
||||
# NOTES: PCH is vital for compilation speed
|
||||
if cls.__version__() >= (12, 8):
|
||||
flags += ['--pch']
|
||||
|
||||
Loading…
Reference in New Issue
Block a user