import copy
import os
import torch
from typing import Any, Dict

from ..jit import build, cpp_format, generate, Runtime


class JITTuner:
    def __init__(self) -> None:
        self.tuned = {}

    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
                         includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime:
        # NOTES: we assume that the tuning space and the template never change,
        # that the GPU device stays the same, and that the kernel has no
        # accumulated side effects
        keys = {k: keys[k] for k in sorted(keys.keys())}
        signature = (name, f'{keys}')
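        # Sorting the keys makes the signature deterministic: with hypothetical
        # arguments name='gemm' and keys={'N': 4096, 'M': 128}, the signature
        # would be ('gemm', "{'M': 128, 'N': 4096}")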
        if signature in self.tuned:
            if os.getenv('DG_JIT_DEBUG', None):
                print(f'Using cached JIT kernel {name} with keys {keys}')
            return self.tuned[signature]

        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Auto-tuning JIT kernel {name} with keys {keys}')

        assert signature not in self.tuned
        assert args is not None
        # An empty space means a single compilation with no extra tuned keys
        space = (dict(), ) if len(space) == 0 else space

        kernels = []
        for tuned_keys in space:
            assert isinstance(tuned_keys, dict)
            full_keys = copy.deepcopy(keys)
            full_keys.update(tuned_keys)
            code = generate(includes, arg_defs, cpp_format(template, full_keys))

            # An illegal build must raise an error
            kernels.append((build(name, arg_defs, code), tuned_keys))
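
        # All candidates are compiled before any timing, so build latency does
        # not pollute the measurements below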
        best_runtime, best_time, best_keys = None, None, None
        for runtime, tuned_keys in kernels:
            if len(space) > 1:
                # Check kernel validity
                return_code = runtime(*args)
                if return_code != 0:
                    # Skip illegal kernels, e.g. those requesting more shared memory than the device provides
                    if os.getenv('DG_JIT_DEBUG', None):
                        print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
                    continue

                # Measure performance, flushing the L2 cache and running a large
                # GEMM beforehand to reduce the overhead between kernels
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
                torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                start_event.record()
                for _ in range(20):
                    assert runtime(*args) == 0
                end_event.record()
                end_event.synchronize()
                elapsed_time = start_event.elapsed_time(end_event)
            else:
                # A single candidate needs no measurement
                elapsed_time = 0
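
            # NOTES: CUDA events report milliseconds, so `elapsed_time` is the
            # total time for all 20 launches; only relative ordering matters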
            # Keep the fastest candidate
            if best_time is None or elapsed_time < best_time:
                best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
                if os.getenv('DG_JIT_DEBUG', None):
                    print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
        assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'

        # Cache the best runtime and return it
        if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
            print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
        self.tuned[signature] = best_runtime
        return best_runtime


jit_tuner = JITTuner()
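

# A minimal, self-contained sketch (not part of the DeepGEMM API) of the timing
# pattern `compile_and_tune` uses: flush the L2 cache and run a large GEMM to
# stabilize the GPU state, then time a candidate with CUDA events. The name
# `_benchmark_sketch` and its `candidate` callable are hypothetical stand-ins
# for `runtime(*args)` above.
def _benchmark_sketch(candidate, num_launches: int = 20) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # Overwrite a 256 MB buffer to flush the L2 cache
    torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
    # A large GEMM reduces the relative overhead between the timed launches
    torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    start_event.record()
    for _ in range(num_launches):
        candidate()
    end_event.record()
    end_event.synchronize()
    # Total milliseconds for all launches combined
    return start_event.elapsed_time(end_event)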