mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)

Several fixes

commit 83f6e9537e, parent 317e83581d
@@ -108,12 +108,13 @@ The library also provides some environment variables, which may be useful:
 - `DG_CACHE_DIR`: string, the cache directory for storing compiled kernels, `$HOME/.deep_gemm` by default
 - `DG_DISABLE_CACHE`: 0 or 1, disable the use of the cache directory, 0 by default
 - `DG_NVCC_COMPILER`: string, the NVCC compiler path to use; found via `torch.utils.cpp_extension.CUDA_HOME` by default
-- `DG_NVCC_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), for compatibility with some older GCC compilers
+- `DG_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), for compatibility with some older GCC compilers
 - `DG_DISABLE_FFMA_INTERLEAVE`: 0 or 1, disable the FFMA-interleaving optimization
 - `DG_PTXAS_VERBOSE`: 0 or 1, show detailed PTXAS compiler output
 - `DG_PRINT_REG_REUSE`: 0 or 1, print FFMA-interleaving details
 - `DG_JIT_PRINT_COMPILER_COMMAND`: 0 or 1, print the NVCC compilation command
 - `DG_JIT_DEBUG`: 0 or 1, print more debugging information
 - `DG_JIT_USE_NVRTC`: 0 or 1, use NVRTC instead of NVCC; compilation is faster but performance may be lower in some cases, 0 by default
 
 For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
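For readers who want to try these switches, a minimal sketch of how they are typically applied: set them in the environment before the library compiles its first kernel. The specific values and the cache path below are illustrative assumptions, not part of this commit.

```python
import os

# Illustrative settings only; any of the variables listed above can be used.
os.environ['DG_CACHE_DIR'] = '/tmp/deep_gemm_cache'    # hypothetical cache path
os.environ['DG_JIT_USE_NVRTC'] = '1'                   # prefer NVRTC over NVCC
os.environ['DG_JIT_PRINT_COMPILER_COMMAND'] = '1'      # log the compile command
os.environ['DG_OVERRIDE_CPP_STANDARD'] = '17'          # for older GCC toolchains

import deep_gemm  # noqa: E402 (import after the environment is configured)
```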
@@ -105,6 +105,10 @@ def put(path, data):
 
 
 class Compiler:
+    @classmethod
+    def signature(cls) -> str:
+        pass
+
     @staticmethod
     def __version__() -> Tuple[int, int]:
         pass
@@ -115,7 +119,7 @@ class Compiler:
 
     @staticmethod
     def flags() -> List[str]:
-        cpp_standard = int(os.getenv('DG_NVCC_OVERRIDE_CPP_STANDARD', 20))
+        cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20))
         return [f'-std=c++{cpp_standard}',
                 '--ptxas-options=--register-usage-level=10' +
                 (',--verbose' if 'DG_PTXAS_VERBOSE' in os.environ else ''),
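As a quick illustration of the renamed variable: the override only changes the `-std` entry of the flag list, with `20` as the default (a sketch; the remaining flags continue beyond this hunk, and the `17` value is a hypothetical choice for an older GCC toolchain).

```python
import os

# With no override, the default C++ standard is 20.
print(int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20)))     # 20

# A hypothetical override for an older GCC toolchain:
os.environ['DG_OVERRIDE_CPP_STANDARD'] = '17'
cpp_standard = int(os.getenv('DG_OVERRIDE_CPP_STANDARD', 20))
print(f'-std=c++{cpp_standard}')                          # -std=c++17
```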
@@ -132,8 +136,8 @@ class Compiler:
         flags = cls.flags()
 
         # Build signature
-        enable_sass_opt = get_nvcc_compiler()[1] <= '12.8' and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
-        signature = f'{name}$${get_deep_gemm_version()}$${code}$${get_nvcc_compiler()}$${flags}$${enable_sass_opt}'
+        enable_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
+        signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${enable_sass_opt}$${code}'
         name = f'kernel.{name}.{hash_to_hex(signature)}'
         path = os.path.join(get_cache_dir(), name)
 
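The switch from a string comparison to `cls.__version__() <= (12, 8)` matters once CUDA minor versions reach two digits, because strings compare lexicographically while tuples compare element-wise. A small sketch of the difference (the `12.10` version is hypothetical):

```python
import os

# Lexicographic string comparison: '1' < '8', so '12.10' wrongly passes.
print('12.10' <= '12.8')    # True  (misleading)

# Element-wise tuple comparison gives the intended ordering.
print((12, 10) <= (12, 8))  # False (correct)

# The FFMA-interleave SASS optimization gate from the diff, with a
# hypothetical version tuple and the env switch left at its default:
version = (12, 10)
enable_sass_opt = version <= (12, 8) and not int(os.getenv('DG_DISABLE_FFMA_INTERLEAVE', 0))
print(enable_sass_opt)      # False
```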
@@ -177,6 +181,10 @@ class NVCCCompiler(Compiler):
         major, minor = map(int, version.split('.'))
         return major, minor
 
+    @classmethod
+    def signature(cls) -> str:
+        return f'nvcc+{cls.__version__()}'
+
     @classmethod
     def flags(cls) -> List[str]:
         if platform.system() != 'Windows':
@@ -216,6 +224,10 @@ class NVRTCCompiler(Compiler):
         major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
         return major, minor
 
+    @classmethod
+    def signature(cls) -> str:
+        return f'nvrtc+{cls.__version__()}'
+
     @staticmethod
     def include_dirs() -> List[str]:
         if CUDA_HOME is None:
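With both backends now exposing `signature()`, the toolchain identity is folded into the cache key, so kernels built by NVCC and NVRTC (or by different toolkit versions) land in different cache directories. A rough sketch of the idea, using `hashlib` as a stand-in for the library's `hash_to_hex` helper and a made-up library version string:

```python
import hashlib
import os

def hash_to_hex_sketch(s: str) -> str:
    # Stand-in for DeepGEMM's hash_to_hex; any stable digest works for the sketch.
    return hashlib.sha256(s.encode()).hexdigest()[:16]

name, code, flags = 'fp8_gemm', '/* generated CUDA source */', ['-std=c++20']
for compiler_signature in ('nvcc+(12, 4)', 'nvrtc+(12, 8)'):
    signature = f'{name}$$0.0.0$${compiler_signature}$${flags}$$True$${code}'
    kernel_dir = f'kernel.{name}.{hash_to_hex_sketch(signature)}'
    print(os.path.join(os.path.expanduser('~/.deep_gemm'), kernel_dir))
# The two backends produce two distinct cache paths for the same kernel source.
```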
@@ -224,13 +236,14 @@ class NVRTCCompiler(Compiler):
 
     @classmethod
     def flags(cls) -> List[str]:
-        base_flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
+        flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                  '--gpu-architecture=sm_90a', '-default-device']
         # NOTES: PCH is vital for compilation speed
         if cls.__version__() >= (12, 8):
-            base_flags += ['--pch']
+            flags += ['--pch']
             if int(os.getenv('DG_JIT_DEBUG', 0)):
-                base_flags += ['--pch-verbose=true']
-        return base_flags
+                flags += ['--pch-verbose=true']
+        return flags
 
     @classmethod
     def compile(cls, name: str, code: str, target_path: str) -> None:
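For context on where the NVRTC flag list ends up, the snippet below follows the standard `cuda.bindings.nvrtc` calling pattern (a sketch independent of this commit: the kernel source and the option subset are illustrative, and error handling is omitted).

```python
from cuda.bindings import nvrtc

source = b'__global__ void noop() {}'
err, prog = nvrtc.nvrtcCreateProgram(source, b'noop.cu', 0, [], [])

# An illustrative subset of what NVRTCCompiler.flags() assembles.
opts = [b'-std=c++20', b'--gpu-architecture=sm_90a', b'-default-device']
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

err, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
print(f'compile log size: {log_size} bytes')
```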
@@ -176,7 +176,7 @@ class FP8GemmRuntime(Runtime):
 using namespace deep_gemm;
 
 __global__ void dummy_kernel() {{
-    void *ptr = (void *)&fp8_gemm_kernel<
+    auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
         {kwargs['N']},
         {kwargs['K']},
         {kwargs['BLOCK_M']},
@@ -191,7 +191,7 @@ __global__ void dummy_kernel() {{
         {kwargs['NUM_TMA_MULTICAST']},
         {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
         GemmType::{kwargs['GEMM_TYPE']}
-    >;
+    >);
 }}
 '''
         if int(os.getenv('DG_JIT_DEBUG', 0)):
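The stub exists only to force the chosen `fp8_gemm_kernel` specialization to be instantiated and compiled; the fix replaces the C-style cast with `reinterpret_cast`. A sketch of how the runtime might render it for one hypothetical configuration (the values are made up, and only a few of the template parameters are shown):

```python
# Hypothetical subset of the launch configuration used to render the stub.
kwargs = {'N': 4096, 'K': 7168, 'BLOCK_M': 64, 'GEMM_TYPE': 'Normal'}

stub = f'''
__global__ void dummy_kernel() {{
    // Taking the kernel's address instantiates this template specialization,
    // even though dummy_kernel itself is never launched.
    auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
        {kwargs['N']}, {kwargs['K']}, {kwargs['BLOCK_M']},
        GemmType::{kwargs['GEMM_TYPE']}>);
}}
'''
print(stub)
```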
@@ -1,3 +1,8 @@
+# PyTorch has its own NVRTC, which may have a lower version than the system's
+# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch
+import cuda.bindings.nvrtc as nvrtc
+print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
+
 import random
 import torch
 from typing import Tuple
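The comment added at the top of the test guards against a real pitfall: importing `torch` first may load PyTorch's bundled NVRTC, which can be older than the system one. A small sketch of how to check which versions are in play (assuming both `cuda-python` and `torch` are installed):

```python
# Load the system NVRTC before torch so the newer library wins.
import cuda.bindings.nvrtc as nvrtc
err, major, minor = nvrtc.nvrtcVersion()
print(f'NVRTC version loaded: {major}.{minor}')

import torch  # noqa: E402
print(f'PyTorch was built against CUDA {torch.version.cuda}')
```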