mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
fix: windows compat
Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
parent
40c09fb883
commit
2d8c4f22d5
@ -2,6 +2,7 @@ import abc
|
||||
import functools
|
||||
import hashlib
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
@ -51,15 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]:
|
||||
paths = []
|
||||
if os.getenv('DG_NVCC_COMPILER'):
|
||||
paths.append(os.getenv('DG_NVCC_COMPILER'))
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
|
||||
|
||||
nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc'
|
||||
paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin))
|
||||
|
||||
# Try to find the first available NVCC compiler
|
||||
least_version_required = '12.3'
|
||||
version_pattern = re.compile(r'release (\d+\.\d+)')
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
match = version_pattern.search(
|
||||
os.popen(f'{path} --version').read())
|
||||
command = [path, '--version']
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
match = version_pattern.search(result.stdout)
|
||||
version = match.group(1)
|
||||
assert match, f'Cannot get the version of NVCC compiler {path}'
|
||||
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
|
||||
@ -184,8 +189,11 @@ class NvccCompiler(Compiler):
|
||||
|
||||
@classmethod
|
||||
def flags(cls) -> List[str]:
|
||||
cxx_flags = ['-fPIC', '-O3',
|
||||
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
|
||||
if platform.system() != 'Windows':
|
||||
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
|
||||
else:
|
||||
cxx_flags = ['/O2', '/std:c++20']
|
||||
|
||||
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
|
||||
'-gencode=arch=compute_90a,code=sm_90a',
|
||||
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
|
||||
@ -203,8 +211,13 @@ class NvccCompiler(Compiler):
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
|
||||
print(f'Compiling JIT runtime {name} with command {command}')
|
||||
|
||||
return_code = subprocess.check_call(command)
|
||||
assert return_code == 0, f'Failed to compile {src_path}'
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(result.stdout)
|
||||
print(result.stderr)
|
||||
|
||||
assert result.returncode == 0, f'Failed to compile {src_path}'
|
||||
|
||||
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
|
||||
return get_symbol(target_path, 'fp8_gemm_kernel')
|
||||
@ -213,7 +226,10 @@ class NvccCompiler(Compiler):
|
||||
class NvrtcCompiler(Compiler):
|
||||
@staticmethod
|
||||
def __version__() -> Tuple[int, int]:
|
||||
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
|
||||
res, major, minor = nvrtc.nvrtcVersion()
|
||||
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
|
||||
# Failed to get actual NVRTC version, use bindings version instead
|
||||
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
|
||||
return (major, minor)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
@ -12,7 +13,9 @@ from .utils import run_gemm
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise Exception("CUDA_HOME is not set")
|
||||
command = [os.path.join(CUDA_HOME, 'bin', 'cuobjdump'), '-symbols', file_path]
|
||||
|
||||
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user