Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00

Add DG_PRINT_CONFIGS

This commit is contained in:
parent 816b39053a
commit 4373af2e82

@@ -123,7 +123,7 @@ The library also provides some environment variables, which may be useful:
 - Post optimization
   - `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
 - Heuristic selection
-  - `DG_PRINT_HEURISTIC`: `0` or `1`, print selected configs for each shape, `0` by default
+  - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
 - Testing
   - `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default
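
As context for the renamed flag, a minimal way to turn it on from Python (this snippet is not part of the commit): the flag is read with os.getenv() each time the JIT runtime cache is consulted, so setting it before the first kernel launch is enough, and each selected config is printed once per process.

import os

# Illustrative only: enable config printing before any DeepGEMM kernel is
# built or loaded; the value is re-read from the environment at lookup time.
os.environ['DG_PRINT_CONFIGS'] = '1'
# ... then call any of the kernels touched below (e.g. gemm_fp8_fp8_bf16_nt)
# as usual.
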
@@ -5,7 +5,7 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
 
 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc

@@ -128,7 +128,7 @@ class Compiler:
         return [get_jit_include_dir()]
 
     @classmethod
-    def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
+    def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
         # Compiler flags
         flags = cls.flags()

@@ -140,7 +140,7 @@ class Compiler:
         # Check runtime cache or file system hit
         global runtime_cache
-        cached_runtime = runtime_cache.get(path, runtime_cls)
+        cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
         if cached_runtime is not None:
             if int(os.getenv('DG_JIT_DEBUG', 0)):
                 print(f'Using cached JIT runtime {name} during build')

@@ -166,8 +166,8 @@ class Compiler:
         os.replace(tmp_cubin_path, cubin_path)
 
         # Put cache and return
-        runtime = runtime_cls(path)
-        runtime_cache[path] = runtime
+        runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
+        assert runtime is not None
         return runtime
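
The hunk above routes freshly compiled kernels through the same RuntimeCache.get call that serves cache hits, so the config printing added in runtime.py (see the later hunks) covers both paths. Below is a minimal, self-contained sketch of that get-or-load shape; the class name and the string standing in for a Runtime object are illustrative, not the repository's code.

import os
from typing import Any, Dict, Optional


class TinyRuntimeCache:
    # Sketch only: one `get` call serves both in-process hits and on-disk or
    # just-compiled kernels, so optional config printing lives in one place.
    def __init__(self) -> None:
        self.cache: Dict[str, str] = {}

    def get(self, path: str, name: str = '',
            kwargs: Optional[Dict[str, Any]] = None) -> Optional[str]:
        if path in self.cache:
            # In-process hit: nothing is printed again
            return self.cache[path]
        if os.path.exists(path):
            # On-disk hit, or a kernel compiled just before this call
            if name and int(os.getenv('DG_PRINT_CONFIGS', 0)):
                print(f'Put kernel {name} with {kwargs} into runtime cache')
            self.cache[path] = f'runtime:{path}'  # stand-in for a Runtime object
            return self.cache[path]
        return None
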
@@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler):
         assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
 
 
-def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
+def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
     compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
-    return compiler_cls.build(name, code, runtime_cls=runtime_cls)
+    return compiler_cls.build(name, code, runtime_cls, kwargs)

@@ -1,9 +1,11 @@
 import copy
 import os
 import subprocess
 import time
+import torch
+import cuda.bindings.driver as cbd
 
-from typing import List, Optional, Type
+from typing import Any, Dict, Optional, Type
 from torch.utils.cpp_extension import CUDA_HOME

@@ -79,13 +81,23 @@ class RuntimeCache:
     def __setitem__(self, path: str, runtime: Runtime) -> None:
         self.cache[path] = runtime
 
-    def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
+    def get(self, path: str, runtime_cls: Type[Runtime],
+            name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]
 
         # Already compiled
         if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
+            # Print heuristic for the first time
+            if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
+                simplified_kwargs = dict()
+                for key, value in kwargs.items():
+                    value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
+                    value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
+                    simplified_kwargs[key] = value
+                print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')
+
             runtime = runtime_cls(path)
             self.cache[path] = runtime
             return runtime
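
To make the printed line concrete, here is a small standalone sketch of the same value simplification; the config keys are illustrative and the CUtensorMap branch is omitted so the snippet runs with only PyTorch installed.

import torch


def simplify_kwargs(kwargs: dict) -> dict:
    # Replace tensor values with a short dtype tag so the printed config stays
    # readable; scalar values (tile sizes, stage counts, ...) are kept as-is.
    simplified = {}
    for key, value in kwargs.items():
        if isinstance(value, torch.Tensor):
            value = f'torch.Tensor<{value.dtype}>'
        simplified[key] = value
    return simplified


example = {'BLOCK_M': 128, 'NUM_STAGES': 5,
           'OUT': torch.empty(16, 16, dtype=torch.bfloat16)}
print(simplify_kwargs(example))
# {'BLOCK_M': 128, 'NUM_STAGES': 5, 'OUT': 'torch.Tensor<torch.bfloat16>'}
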
@@ -238,5 +238,5 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     # Generate, build and run the kernel
     code = FP8GemmRuntime.generate(**kwargs)
-    runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
     runtime(**kwargs)

@@ -102,7 +102,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     # Generate, build and run the kernel
     code = FP8GemmRuntime.generate(**kwargs)
-    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
     runtime(**kwargs)

@@ -201,5 +201,5 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     # Generate, build and run the kernel
     code = FP8GemmRuntime.generate(**kwargs)
-    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime)
+    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
     runtime(**kwargs)

@@ -111,7 +111,7 @@ def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
 
     # Generate, build and run the kernel
     code = FP8WGradGemmRuntime.generate(**kwargs)
-    runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime)
+    runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
     runtime(**kwargs)