Several fixes

This commit is contained in:
Chenggang Zhao
2025-05-07 09:57:39 +08:00
parent 317e83581d
commit 83f6e9537e
4 changed files with 30 additions and 11 deletions

View File

@@ -105,6 +105,10 @@ def put(path, data):
class Compiler:
@classmethod
def signature(cls) -> str:
pass
@staticmethod
def __version__() -> Tuple[int, int]:
pass
@@ -115,7 +119,7 @@ class Compiler:
@staticmethod
def flags() -> List[str]:
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
@@ -132,8 +136,8 @@ class Compiler:
flags = cls.flags()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = os.path.join(get_cache_dir(), name)
@@ -177,6 +181,10 @@ class NVCCCompiler(Compiler):
major, minor = map(int, version.split('.'))
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvcc+{cls.__version__()}'
@classmethod
def flags(cls) -> List[str]:
if platform.system() != 'Windows':
@@ -216,6 +224,10 @@ class NVRTCCompiler(Compiler):
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvrtc+{cls.__version__()}'
@staticmethod
def include_dirs() -> List[str]:
if CUDA_HOME is None:
@@ -224,13 +236,14 @@ class NVRTCCompiler(Compiler):
@classmethod
def flags(cls) -> List[str]:
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
# NOTES: PCH is vital for compilation speed
if cls.__version__() >= (12, 8):
base_flags += ['--pch']
flags += ['--pch']
if int(os.getenv('DG_JIT_DEBUG', 0)):
base_flags += ['--pch-verbose=true']
return base_flags
flags += ['--pch-verbose=true']
return flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> None: