Some lints and refactor

This commit is contained in:
Chenggang Zhao
2025-05-06 17:23:35 +08:00
parent 8aff6309d4
commit 981cc58932
18 changed files with 421 additions and 449 deletions

View File

@@ -1,29 +1,28 @@
import copy
import os
import torch
from typing import Any, Dict
import cuda.bindings.driver as cbd
from typing import Any, Callable, Dict, Type, Tuple
import cuda.bindings.driver as cuda
from ..jit import build, generate, Runtime
from ..jit import build, Runtime
class JITTuner:
def __init__(self) -> None:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
# NOTES: we always assume the space and template will not change
# We also assume the GPU device will not be changed
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
signature = (name, f'{keys}')
if signature in self.tuned:
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Using cached JIT kernel {name} with keys {keys}')
return self.tuned[signature]
if os.getenv('DG_JIT_DEBUG', None):
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
assert signature not in self.tuned
@@ -35,19 +34,19 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generate(**kwargs, **full_keys)
kernels.append((build(name, code), full_keys))
code = generator(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1
best_runtime, best_time, best_keys = None, None, None
for runtime, tuned_keys in kernels:
if len(space) > 1:
# Check kernel validity
return_code = runtime(**tuned_keys, **kwargs)
if return_code != cuda.CUresult.CUDA_SUCCESS:
# Pass illegal kernels, e.g. insufficient shared memory capacity
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
if return_code != cbd.CUresult.CUDA_SUCCESS:
# Skip illegal kernels, e.g., those with insufficient shared memory capacity
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
continue
# Measure performance, flushing L2 and running a large GEMM kernel beforehand to reduce overhead between kernels
@@ -59,7 +58,7 @@ class JITTuner:
(8192, 8192), dtype=torch.float, device='cuda')
start_event.record()
for i in range(20):
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
end_event.record()
end_event.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
@@ -69,13 +68,12 @@ class JITTuner:
# Compare if better
if best_time is None or elapsed_time < best_time:
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
if os.getenv('DG_JIT_DEBUG', None):
print(
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
# Cache the best runtime and return
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
print(
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
self.tuned[signature] = (best_runtime, best_keys)