From 83f6e9537ea5df8704a35109f91f3c53012f52f6 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 09:57:39 +0800 Subject: [PATCH] Several fixes --- README.md | 3 ++- deep_gemm/jit/compiler.py | 29 +++++++++++++++++++++-------- deep_gemm/jit_kernels/runtime.py | 4 ++-- tests/test_core.py | 5 +++++ 4 files changed, 30 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 6d0689f..a9d7d1f 100644 --- a/README.md +++ b/README.md @@ -108,12 +108,13 @@ The library also provides some environment variables, which may be useful: - `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default - `DG_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `from torch.utils.cpp_extension.CUDA_HOME` by default -- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler +- `DG_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler - `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization - `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output - `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details - `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print NVCC compilation command - `DG_JIT_DEBUG`: 0 or 1, print more debugging information +- `DG_JIT_USE_NVRTC`: 0 or 1, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, 0 by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py index 238c0d1..cf07889 100644 --- a/deep_gemm/jit/compiler.py +++ b/deep_gemm/jit/compiler.py @@ -105,6 +105,10 @@ def put(path, data): class Compiler: + @classmethod + def signature(cls) -> str: + pass + @staticmethod def __version__() -> Tuple[int, int]: pass @@ -115,7 +119,7 @@ class Compiler: @staticmethod def flags() -> List[str]: - cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20)) + cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20)) return [f'-std=c++{cpp_standard}', '--ptxas-options=--register-usage-level=10' + (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''), @@ -132,8 +136,8 @@ class Compiler: flags = cls.flags() # Build signature - enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) - signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}' + enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0)) + signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}' name = f'kernel.{name}.{hash_to_hex(signature)}' path = os.path.join(get_cache_dir(), name) @@ -177,6 +181,10 @@ class NVCCCompiler(Compiler): major, minor = map(int, version.split('.')) return major, minor + @classmethod + def signature(cls) -> str: + return f'nvcc+{cls.__version__()}' + @classmethod def flags(cls) -> List[str]: if platform.system() != 'Windows': @@ -216,6 +224,10 @@ class NVRTCCompiler(Compiler): major, minor = map(int, cuda.bindings.__version__.split('.')[:2]) return major, minor + @classmethod + def signature(cls) -> str: + return f'nvrtc+{cls.__version__()}' + @staticmethod def include_dirs() -> List[str]: if CUDA_HOME is None: @@ -224,13 +236,14 @@ class NVRTCCompiler(Compiler): @classmethod def flags(cls) -> List[str]: - base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], - '--gpu-architecture=sm_90a', '-default-device'] + flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()], + '--gpu-architecture=sm_90a', '-default-device'] + # NOTES: PCH is vital for compilation speed if cls.__version__() >= (12, 8): - base_flags += ['--pch'] + flags += ['--pch'] if int(os.getenv('DG_JIT_DEBUG', 0)): - base_flags += ['--pch-verbose=true'] - return base_flags + flags += ['--pch-verbose=true'] + return flags @classmethod def compile(cls, name: str, code: str, target_path: str) -> None: diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 8860847..a5fe16b 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -176,7 +176,7 @@ class FP8GemmRuntime(Runtime): using namespace deep_gemm; __global__ void dummy_kernel() {{ - void *ptr = (void *)&fp8_gemm_kernel< + auto ptr = reinterpret_cast(&fp8_gemm_kernel< {kwargs['N']}, {kwargs['K']}, {kwargs['BLOCK_M']}, @@ -191,7 +191,7 @@ __global__ void dummy_kernel() {{ {kwargs['NUM_TMA_MULTICAST']}, {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}, GemmType::{kwargs['GEMM_TYPE']} - >; + >); }} ''' if int(os.getenv('DG_JIT_DEBUG', 0)): diff --git a/tests/test_core.py b/tests/test_core.py index bdc1841..de544c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,3 +1,8 @@ +# PyTorch has its own NVRTC, which may have a lower version than the system +# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch +import cuda.bindings.nvrtc as nvrtc +print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}') + import random import torch from typing import Tuple