Several fixes

Chenggang Zhao 2025-05-07 09:57:39 +08:00
parent 317e83581d
commit 83f6e9537e
4 changed files with 30 additions and 11 deletions

@@ -108,12 +108,13 @@ The library also provides some environment variables, which may be useful:
- `DG_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- `DG_DISABLE_CACHE`: 0 or 1, disable the use of cache directory, 0 by default
- `DG_NVCC_COMPILER`: string, the NVCC compiler path to use; by default it is located via `torch.utils.cpp_extension.CUDA_HOME`
- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), override the C++ standard used for compilation; useful for older GCC versions that lack full C++20 support
- `DG_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), override the C++ standard used for compilation; useful for older GCC versions that lack full C++20 support
- `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable FFMA-interleaving optimization
- `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
- `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
- `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print NVCC compilation command
- `DG_JIT_DEBUG`: 0 or 1, print more debugging information
- `DG_JIT_USE_NVRTC`: 0 or 1, use NVRTC instead of NVCC; compilation is faster, but the resulting kernels may be slower in some cases, 0 by default
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
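
A minimal usage sketch of these variables (not part of the commit; the values and the `deep_gemm` import name are illustrative assumptions):

```python
import os

# Set these before the first kernel is JIT-compiled; values are examples only.
os.environ['DG_OVERRIDE_CPP_STANDARD'] = '17'  # e.g., for an older GCC without full C++20 support
os.environ['DG_JIT_USE_NVRTC'] = '1'           # prefer NVRTC for faster JIT compilation
os.environ['DG_JIT_DEBUG'] = '1'               # print extra debugging information

import deep_gemm  # the variables above are read when kernels are compiled
```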

@@ -105,6 +105,10 @@ def put(path, data):
class Compiler:
@classmethod
def signature(cls) -> str:
pass
@staticmethod
def __version__() -> Tuple[int, int]:
pass
@@ -115,7 +119,7 @@ class Compiler:
@staticmethod
def flags() -> List[str]:
cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20))
return [f'-std=c++{cpp_standard}',
'--ptxas-options=--register-usage-level=10' +
(',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
@@ -132,8 +136,8 @@ class Compiler:
flags = cls.flags()
# Build signature
enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
name = f'kernel.{name}.{hash_to_hex(signature)}'
path = os.path.join(get_cache_dir(), name)
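
Two details in the hunk above are worth noting: the cache signature now embeds `cls.signature()` (compiler name plus version) instead of the NVCC compiler path, so kernels built by NVCC and NVRTC are keyed separately, and the SASS-optimization check compares the version as an integer tuple rather than as a string. A small illustration of why the tuple comparison is the safer check (an example, not code from the commit):

```python
# Lexicographic string comparison misorders two-digit minor versions:
assert ('12.10' <= '12.8') is True     # wrong as a version check
# Element-wise tuple comparison orders versions numerically:
assert ((12, 10) <= (12, 8)) is False  # correct
```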
@@ -177,6 +181,10 @@ class NVCCCompiler(Compiler):
major, minor = map(int, version.split('.'))
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvcc+{cls.__version__()}'
@classmethod
def flags(cls) -> List[str]:
if platform.system() != 'Windows':
@@ -216,6 +224,10 @@ class NVRTCCompiler(Compiler):
major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
return major, minor
@classmethod
def signature(cls) -> str:
return f'nvrtc+{cls.__version__()}'
@staticmethod
def include_dirs() -> List[str]:
if CUDA_HOME is None:
@@ -224,13 +236,14 @@ class NVRTCCompiler(Compiler):
@classmethod
def flags(cls) -> List[str]:
base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
'--gpu-architecture=sm_90a', '-default-device']
# NOTES: PCH is vital for compilation speed
if cls.__version__() >= (12, 8):
base_flags += ['--pch']
flags += ['--pch']
if int(os.getenv('DG_JIT_DEBUG', 0)):
base_flags += ['--pch-verbose=true']
return base_flags
flags += ['--pch-verbose=true']
return flags
@classmethod
def compile(cls, name: str, code: str, target_path: str) -> None:

@@ -176,7 +176,7 @@ class FP8GemmRuntime(Runtime):
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
@@ -191,7 +191,7 @@ __global__ void dummy_kernel() {{
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
>);
}}
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):

@@ -1,3 +1,8 @@
# PyTorch has its own NVRTC, which may be older than the system one
# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch
import cuda.bindings.nvrtc as nvrtc
print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
import random
import torch
from typing import Tuple