fix: windows compat

Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Gabriel Wu 2025-04-23 14:47:15 +08:00
parent 40c09fb883
commit 2d8c4f22d5
2 changed files with 28 additions and 9 deletions

View File

@ -2,6 +2,7 @@ import abc
import functools
import hashlib
import os
import platform
import re
import subprocess
import time
@ -51,15 +52,19 @@ def get_nvcc_compiler() -> Tuple[str, str]:
paths = []
if os.getenv('DG_NVCC_COMPILER'):
paths.append(os.getenv('DG_NVCC_COMPILER'))
paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))
nvcc_bin = 'nvcc.exe' if platform.system() == 'Windows' else 'nvcc'
paths.append(os.path.join(CUDA_HOME, 'bin', nvcc_bin))
# Try to find the first available NVCC compiler
least_version_required = '12.3'
version_pattern = re.compile(r'release (\d+\.\d+)')
for path in paths:
if os.path.exists(path):
match = version_pattern.search(
os.popen(f'{path} --version').read())
command = [path, '--version']
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
match = version_pattern.search(result.stdout)
version = match.group(1)
assert match, f'Cannot get the version of NVCC compiler {path}'
assert version >= least_version_required, f'NVCC {path} version {version} is lower than {least_version_required}'
@ -184,8 +189,11 @@ class NvccCompiler(Compiler):
@classmethod
def flags(cls) -> List[str]:
cxx_flags = ['-fPIC', '-O3',
'-Wno-deprecated-declarations', '-Wno-abi', '-fconcepts']
if platform.system() != 'Windows':
cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
else:
cxx_flags = ['/O2', '/std:c++20']
return [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'-gencode=arch=compute_90a,code=sm_90a',
'-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
@ -203,8 +211,13 @@ class NvccCompiler(Compiler):
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_JIT_PRINT_NVCC_COMMAND', False):
print(f'Compiling JIT runtime {name} with command {command}')
return_code = subprocess.check_call(command)
assert return_code == 0, f'Failed to compile {src_path}'
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
if os.getenv('DG_JIT_DEBUG', None):
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f'Failed to compile {src_path}'
# NVCC needs to get the symbol name from the cubin file using `cuobjdump`
return get_symbol(target_path, 'fp8_gemm_kernel')
@ -213,7 +226,10 @@ class NvccCompiler(Compiler):
class NvrtcCompiler(Compiler):
@staticmethod
def __version__() -> Tuple[int, int]:
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
res, major, minor = nvrtc.nvrtcVersion()
if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
# Failed to get actual NVRTC version, use bindings version instead
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return (major, minor)
@staticmethod

View File

@ -1,4 +1,5 @@
import os
import platform
import time
import subprocess
from typing import Any, Dict, Optional
@ -12,7 +13,9 @@ from .utils import run_gemm
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
if CUDA_HOME is None:
raise Exception("CUDA_HOME is not set")
command = [os.path.join(CUDA_HOME, 'bin', 'cuobjdump'), '-symbols', file_path]
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
result = subprocess.run(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, text=True)
assert result.returncode == 0