mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Merge 93ea4797c0 into e82c4139da
This commit is contained in:
commit
cc17efd000
@ -12,7 +12,7 @@ find_package(pybind11 REQUIRED)
|
|||||||
|
|
||||||
file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
|
file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
|
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -gencode arch=compute_120,code=sm_120 -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
|
||||||
RESULT_VARIABLE NVCC_RESULT
|
RESULT_VARIABLE NVCC_RESULT
|
||||||
OUTPUT_VARIABLE NVCC_OUTPUT
|
OUTPUT_VARIABLE NVCC_OUTPUT
|
||||||
ERROR_VARIABLE NVCC_ERROR_OUTPUT
|
ERROR_VARIABLE NVCC_ERROR_OUTPUT
|
||||||
@ -27,8 +27,11 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (NVCC_SUPPORTS_SM90)
|
if (NVCC_SUPPORTS_SM90)
|
||||||
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
|
set(TORCH_CUDA_ARCH_LIST "8.6;9.0a" CACHE STRING "Add arch tag 90a to NVCC" FORCE) # TODO: Check if 9.0a is correct for sm_90a, it might be just 9.0
|
||||||
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
|
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
|
||||||
|
# Add Blackwell support if NVCC supports it (determined by the test_cuda.cu compilation)
|
||||||
|
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_120,code=sm_120")
|
||||||
|
set(TORCH_CUDA_ARCH_LIST "${TORCH_CUDA_ARCH_LIST};12.0" CACHE STRING "Add arch tag 120 to NVCC" FORCE) # TODO: Check if 12.0 is correct for sm_120
|
||||||
endif()
|
endif()
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
|
|||||||
@ -187,6 +187,7 @@ class NVCCCompiler(Compiler):
|
|||||||
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||||
'-gencode=arch=compute_90a,code=sm_90a',
|
'-gencode=arch=compute_90a,code=sm_90a',
|
||||||
|
'-gencode=arch=compute_120,code=sm_120',
|
||||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||||
f'--compiler-options={",".join(cxx_flags)}']
|
f'--compiler-options={",".join(cxx_flags)}']
|
||||||
|
|
||||||
@ -230,7 +231,7 @@ class NVRTCCompiler(Compiler):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def flags(cls) -> List[str]:
|
def flags(cls) -> List[str]:
|
||||||
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||||
'--gpu-architecture=sm_90a', '-default-device']
|
'--gpu-architecture=compute_90a', '--gpu-architecture=compute_120', '-default-device']
|
||||||
# NOTES: PCH is vital for compilation speed
|
# NOTES: PCH is vital for compilation speed
|
||||||
if cls.__version__() >= (12, 8):
|
if cls.__version__() >= (12, 8):
|
||||||
flags += ['--pch']
|
flags += ['--pch']
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user