DeepGEMM/deep_gemm/jit_kernels/tuner.py
Zihua Wu 27cd276e19 [wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
2025-04-22 08:08:40 +00:00

86 lines
3.6 KiB
Python

import copy
import os
from typing import Any, Dict, Tuple

import cuda.bindings.driver as cuda
import torch

from ..jit import build, generate, Runtime
class JITTuner:
    """Build JIT kernel candidates, pick the fastest one, and memoize the choice.

    Results are cached per process, keyed by the kernel name plus the sorted
    ``keys`` mapping. The tuning ``space`` and ``kwargs`` are deliberately not
    part of the cache key (see the NOTES in ``compile_and_tune``).
    """

    # Number of back-to-back launches timed per candidate during tuning.
    _NUM_BENCH_ITERS = 20

    def __init__(self) -> None:
        # Maps (name, repr of sorted keys) -> (best runtime, its full keys).
        self.tuned = {}

    @staticmethod
    def _benchmark(runtime: Runtime, full_keys: Dict[str, Any], kwargs: Dict[str, Any], num_iters: int) -> float:
        """Return the total elapsed time (ms, per ``torch.cuda.Event.elapsed_time``) of ``num_iters`` launches.

        Raises:
            RuntimeError: if any timed launch does not return ``CUDA_SUCCESS``.
        """
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        # Measure performance with L2 flush (256e6 bytes of int32 zeros) and a large GEMM
        # kernel before, to reduce overhead between kernels
        torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
        torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn(
            (8192, 8192), dtype=torch.float, device='cuda')
        start_event.record()
        for _ in range(num_iters):
            # The launch must NOT live inside an `assert`: `python -O` strips asserts,
            # which would leave nothing in the timed region.
            return_code = runtime(**full_keys, **kwargs)
            if return_code != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(
                    f'JIT kernel launch failed during benchmarking with keys {full_keys}: error code {return_code}')
        end_event.record()
        end_event.synchronize()
        return start_event.elapsed_time(end_event)

    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
                         kwargs: Dict[str, Any]) -> Tuple[Runtime, Dict[str, Any]]:
        """Return the best ``(runtime, full_keys)`` for kernel ``name``, tuning on first use.

        Args:
            name: kernel name, passed to ``build``.
            keys: fixed template parameters; these (sorted) plus ``name`` form the cache key.
            space: candidate dicts, each merged over ``keys`` to form one kernel's full keys.
                An empty space means a single candidate built from ``keys`` alone.
            kwargs: extra arguments passed to ``generate`` and to every runtime launch.

        Returns:
            ``(runtime, full_keys)`` of the fastest legal candidate. With exactly one
            candidate it is returned without being launched or timed.

        Raises:
            RuntimeError: if no candidate launches successfully, or a timed launch fails.
        """
        # NOTES: we always assume the space and template will not change
        # We also assume the GPU device will not be changed
        # NOTES: the function must have no accumulated side effects
        debug = os.getenv('DG_JIT_DEBUG', None)

        keys = {k: keys[k] for k in sorted(keys)}
        signature = (name, f'{keys}')
        if signature in self.tuned:
            if debug:
                print(f'Using cached JIT kernel {name} with keys {keys}')
            return self.tuned[signature]

        if debug:
            print(f'Auto-tuning JIT kernel {name} with keys {keys}')

        assert kwargs is not None
        space = (dict(), ) if len(space) == 0 else space

        # Generate and build one kernel per candidate in the search space
        kernels = []
        for tuned_keys in space:
            assert isinstance(tuned_keys, dict)
            full_keys = copy.deepcopy(keys)
            full_keys.update(tuned_keys)
            code = generate(**kwargs, **full_keys)
            kernels.append((build(name, code), full_keys))

        best_runtime, best_time, best_keys = None, None, None
        for runtime, full_keys in kernels:
            if len(space) > 1:
                # Check kernel validity
                return_code = runtime(**full_keys, **kwargs)
                if return_code != cuda.CUresult.CUDA_SUCCESS:
                    # Pass illegal kernels, e.g. insufficient shared memory capacity
                    if debug:
                        print(
                            f'Illegal JIT kernel {name} with keys {keys} and tuned keys {full_keys}: error code {return_code}')
                    continue
                elapsed_time = self._benchmark(runtime, full_keys, kwargs, self._NUM_BENCH_ITERS)
            else:
                # A single candidate needs no comparison
                elapsed_time = 0

            # Compare if better
            if best_time is None or elapsed_time < best_time:
                best_runtime, best_time, best_keys = runtime, elapsed_time, full_keys
            if debug:
                print(
                    f'Tuned JIT kernel {name} with keys {keys} and tuned keys {full_keys} has time {elapsed_time}')

        # Explicit check (not `assert`) so `python -O` cannot cache a `(None, None)` result
        if best_runtime is None:
            raise RuntimeError(f'Failed to tune JIT kernel {name} with keys {keys}')

        # Cache the best runtime and return
        if debug or os.getenv('DG_PRINT_AUTOTUNE', None):
            print(
                f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
        self.tuned[signature] = (best_runtime, best_keys)
        return best_runtime, best_keys


# Process-wide tuner instance shared by the JIT kernel wrappers
jit_tuner = JITTuner()