Fix JIT tests

This commit is contained in:
Chenggang Zhao
2025-05-16 14:39:58 +08:00
parent 78d8362e7a
commit 391755ada0
6 changed files with 14 additions and 11 deletions

View File

@@ -33,7 +33,7 @@ class Runtime:
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
raise NotImplemented
def __call__(self, kwargs: Dict[str, Any]) -> cbd.CUresult:
def __call__(self, **kwargs) -> cbd.CUresult:
# Load CUBIN
if self.kernel is None:
start_time = time.time_ns()
@@ -81,17 +81,19 @@ class RuntimeCache:
self.cache[path] = runtime
def get(self, path: str, runtime_cls: Type[Runtime],
name: str = '', kwargs: Dict[str, Any] = None) -> Optional[Runtime]:
name: str = '', kwargs: Dict[str, Any] = None,
force_enable_cache: bool = False) -> Optional[Runtime]:
# In Python runtime
if path in self.cache:
return self.cache[path]
# Already compiled
if not int(os.getenv('DG_JIT_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0))
if use_cache and os.path.exists(path) and Runtime.is_path_valid(path):
# Print heuristic for the first time
if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
simplified_kwargs = dict()
for key, value in kwargs.items():
for key, value in kwargs.items() if kwargs is not None else dict().items():
value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
value = f'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
simplified_kwargs[key] = value