Add DG_PRINT_CONFIGS

This commit is contained in:
Chenggang Zhao
2025-05-15 16:36:40 +08:00
parent 816b39053a
commit 4373af2e82
6 changed files with 26 additions and 14 deletions

View File

@@ -5,7 +5,7 @@ import re
import subprocess
import time
import uuid
from typing import List, Tuple, Type
from typing import Any, Dict, List, Tuple, Type
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
@@ -128,7 +128,7 @@ class Compiler:
return [get_jit_include_dir()]
@classmethod
def build(cls, name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
# Compiler flags
flags = cls.flags()
@@ -140,7 +140,7 @@ class Compiler:
# Check runtime cache or file system hit
global runtime_cache
cached_runtime = runtime_cache.get(path, runtime_cls)
cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
if cached_runtime is not None:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT runtime {name} during build')
@@ -166,8 +166,8 @@ class Compiler:
os.replace(tmp_cubin_path, cubin_path)
# Put cache and return
runtime = runtime_cls(path)
runtime_cache[path] = runtime
runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
assert runtime is not None
return runtime
@@ -279,6 +279,6 @@ class NVRTCCompiler(Compiler):
assert nvrtc.nvrtcDestroyProgram(program)[0] == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {result}'
def build(name: str, code: str, runtime_cls: Type[Runtime]) -> Runtime:
def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
compiler_cls = NVRTCCompiler if int(os.getenv('DG_JIT_USE_NVRTC', 0)) else NVCCCompiler
return compiler_cls.build(name, code, runtime_cls=runtime_cls)
return compiler_cls.build(name, code, runtime_cls, kwargs)