Add DG_PRINT_CONFIGS

This commit is contained in:
Chenggang Zhao
2025-05-15 16:36:40 +08:00
parent 816b39053a
commit 4373af2e82
6 changed files with 26 additions and 14 deletions

View File

@@ -1,9 +1,11 @@
import copy
import os
import subprocess
import time
import torch
import cuda.bindings.driver as cbd
from typing import List, Optional, Type
from typing import Any, Dict, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME
@@ -79,13 +81,23 @@ class RuntimeCache:
def __setitem__(self, path: str, runtime: Runtime) -> None:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
def get(self, path: str, runtime_cls: Type[Runtime],
name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
# Print heuristic for the first time
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
simplified_kwargs = dict()
for key, value in kwargs.items():
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
simplified_kwargs[key] = value
print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')
runtime = runtime_cls(path)
self.cache[path] = runtime
return runtime