mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Add DG_PRINT_CONFIGS
This commit is contained in:
@@ -5,7 +5,7 @@ import re
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Tuple, Type
|
||||
from typing import Any, Dict, List, Tuple, Type
|
||||
|
||||
import cuda.bindings
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
@@ -128,7 +128,7 @@ class Compiler:
|
||||
return [get_jit_include_dir()]
|
||||
|
||||
@classmethod
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||
# Compiler flags
|
||||
flags = cls.flags()
|
||||
|
||||
@@ -140,7 +140,7 @@ class Compiler:
|
||||
|
||||
# Check runtime cache or file system hit
|
||||
global runtime_cache
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls)
|
||||
cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||
if cached_runtime is not None:
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT runtime {name} during build')
|
||||
@@ -166,8 +166,8 @@ class Compiler:
|
||||
os.replace(tmp_cubin_path, cubin_path)
|
||||
|
||||
# Put cache and return
|
||||
runtime = runtime_cls(path)
|
||||
runtime_cache[path] = runtime
|
||||
runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
|
||||
assert runtime is not None
|
||||
return runtime
|
||||
|
||||
|
||||
@@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler):
|
||||
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
|
||||
|
||||
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
|
||||
def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
|
||||
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
|
||||
return compiler_cls.build(name, code, runtime_cls=runtime_cls)
|
||||
return compiler_cls.build(name, code, runtime_cls, kwargs)
|
||||
|
||||
Reference in New Issue
Block a user