DeepGEMM/deep_gemm/jit_kernels/tuner.py
2025-02-25 22:52:41 +08:00

82 lines
3.5 KiB
Python

import copy
import os
import torch
from typing import Any, Dict
from ..jit import build, cpp_format, generate, Runtime
class JITTuner:
    """Build every candidate configuration of a JIT kernel and keep the fastest.

    The selected runtime is memoized per ``(name, keys)`` signature in
    ``self.tuned``, so a signature that was already tuned is returned
    straight from the cache without rebuilding or re-benchmarking.
    """

    def __init__(self) -> None:
        # Maps ``(name, string form of the key-sorted keys dict)`` to the
        # runtime that was selected for that signature.
        self.tuned = {}

    @staticmethod
    def _measure(runtime: 'Runtime', args: tuple) -> float:
        """Time 20 back-to-back launches of ``runtime`` with CUDA events.

        Args:
            runtime: callable kernel runtime; invoked as ``runtime(*args)``
                and expected to return ``0`` on success.
            args: positional arguments forwarded to every launch.

        Returns:
            The value of ``start_event.elapsed_time(end_event)`` spanning
            all 20 launches.

        Raises:
            AssertionError: if any timed launch returns a non-zero code.
        """
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        # Measure performance with L2 flush and a large GEMM kernel before
        # to reduce overhead between kernels
        torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
        torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
        start_event.record()
        for _ in range(20):
            # Launch outside of the ``assert`` statement: asserts are removed
            # under ``python -O``, which would otherwise skip the launch and
            # make the measurement cover an empty interval.
            return_code = runtime(*args)
            assert return_code == 0
        end_event.record()
        end_event.synchronize()
        return start_event.elapsed_time(end_event)

    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
                         includes: tuple, arg_defs: tuple, template: str, args: tuple) -> 'Runtime':
        """Return the best runtime for ``name``/``keys``, tuning it on first use.

        Args:
            name: kernel name; forwarded to ``build`` and part of the cache
                signature.
            keys: base template keys. They are re-ordered by key name so the
                cache signature does not depend on insertion order.
            space: tuple of dicts; each dict is merged over ``keys`` to form
                one candidate. An empty tuple means a single candidate that
                uses ``keys`` unchanged.
            includes: forwarded to ``generate``.
            arg_defs: forwarded to ``generate`` and ``build``.
            template: source template rendered by ``cpp_format`` with the
                merged keys of each candidate.
            args: positional arguments used to launch every candidate during
                tuning; must not be ``None`` when tuning actually runs.

        Returns:
            The cached runtime if the signature was tuned before, otherwise
            the candidate with the smallest measured time.

        Raises:
            AssertionError: if ``args`` is ``None``, an entry of ``space`` is
                not a dict, a timed launch fails, or no candidate is legal.

        Any error raised by ``generate``/``build`` propagates unchanged.
        """
        # NOTES: we always assume the space and template will not change
        # (the cache signature only contains ``name`` and ``keys``)
        # We also assume the GPU device will not be changed
        # NOTES: the function must have no accumulated side effects
        # (each candidate is launched repeatedly with the same ``args``)
        debug = bool(os.getenv('DG_JIT_DEBUG', None))

        keys = {k: keys[k] for k in sorted(keys.keys())}
        signature = (name, f'{keys}')
        if signature in self.tuned:
            if debug:
                print(f'Using cached JIT kernel {name} with keys {keys}')
            return self.tuned[signature]

        if debug:
            print(f'Auto-tuning JIT kernel {name} with keys {keys}')

        assert args is not None
        space = (dict(), ) if len(space) == 0 else space

        # Build all candidates up front
        kernels = []
        for tuned_keys in space:
            assert isinstance(tuned_keys, dict)
            full_keys = copy.deepcopy(keys)
            full_keys.update(tuned_keys)
            code = generate(includes, arg_defs, cpp_format(template, full_keys))

            # Illegal build must raise errors
            kernels.append((build(name, arg_defs, code), tuned_keys))

        best_runtime, best_time, best_keys = None, None, None
        for runtime, tuned_keys in kernels:
            if len(space) > 1:
                # Check kernel validity
                return_code = runtime(*args)
                if return_code != 0:
                    # Pass illegal kernels, e.g. insufficient shared memory capacity
                    if debug:
                        print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
                    continue
                elapsed_time = self._measure(runtime, args)
            else:
                # A single candidate is selected without being launched
                elapsed_time = 0

            # Compare if better (strict ``<`` keeps the earliest candidate on ties)
            if best_time is None or elapsed_time < best_time:
                best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
            if debug:
                print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
        assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'

        # Cache the best runtime and return
        if debug or os.getenv('DG_PRINT_AUTOTUNE', None):
            print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
        self.tuned[signature] = best_runtime
        return best_runtime
# Module-level tuner instance created at import time; its `tuned` cache
# persists for as long as this object is alive.
jit_tuner = JITTuner()