mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-07 06:35:01 +00:00
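"""JIT kernel auto-tuner.

`JITTuner.compile_and_tune` generates and builds one kernel per candidate in the tuning
space, launches each candidate to check validity, times the valid ones with CUDA events,
and caches the best (runtime, tuned keys) pair per kernel signature. Setting the
`DG_JIT_DEBUG` or `DG_PRINT_AUTOTUNE` environment variables prints tuning information.
"""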
import copy
import os
import torch
from typing import Any, Dict, Tuple

import cuda.bindings.driver as cuda

from ..jit import build, generate, Runtime


class JITTuner:
    def __init__(self) -> None:
        # Maps each (kernel name, keys) signature to the best (runtime, tuned keys) pair found
        self.tuned = {}
    def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Tuple[Runtime, Dict[str, Any]]:
        # NOTES: we always assume the tuning space and the template will not change for a given signature
        # We also assume the GPU device will not be changed
        # NOTES: the kernel must have no accumulated side effects, as candidates are launched repeatedly during tuning
        keys = {k: keys[k] for k in sorted(keys.keys())}
        signature = (name, f'{keys}')
        if signature in self.tuned:
            if os.getenv('DG_JIT_DEBUG', None):
                print(f'Using cached JIT kernel {name} with keys {keys}')
            return self.tuned[signature]

        if os.getenv('DG_JIT_DEBUG', None):
            print(f'Auto-tuning JIT kernel {name} with keys {keys}')

        assert signature not in self.tuned
        assert kwargs is not None
        space = (dict(), ) if len(space) == 0 else space
        # Generate and build one kernel per candidate configuration in the tuning space
        kernels = []
        for tuned_keys in space:
            assert isinstance(tuned_keys, dict)
            full_keys = copy.deepcopy(keys)
            full_keys.update(tuned_keys)
            code = generate(**kwargs, **full_keys)
            kernels.append((build(name, code), full_keys))
        best_runtime, best_time, best_keys = None, None, None
        for runtime, tuned_keys in kernels:
            if len(space) > 1:
                # Check kernel validity
                return_code = runtime(**tuned_keys, **kwargs)
                if return_code != cuda.CUresult.CUDA_SUCCESS:
                    # Skip illegal kernels, e.g. those requiring more shared memory than available
                    if os.getenv('DG_JIT_DEBUG', None):
                        print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
                    continue
                # Measure performance with an L2 cache flush and a large GEMM beforehand to reduce overhead between kernels
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
                torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
                start_event.record()
                for i in range(20):
                    assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
                end_event.record()
                end_event.synchronize()
                elapsed_time = start_event.elapsed_time(end_event)
            else:
                elapsed_time = 0
            # Compare and keep the fastest candidate
            if best_time is None or elapsed_time < best_time:
                best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
            if os.getenv('DG_JIT_DEBUG', None):
                print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
        assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'

        # Cache the best runtime and return
        if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
            print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
        self.tuned[signature] = (best_runtime, best_keys)
        return best_runtime, best_keys

jit_tuner = JITTuner()
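

# A minimal, standalone sketch of the timing scheme `compile_and_tune` uses above: flush the
# L2 cache with a large buffer write, run a large GEMM to reduce overhead between kernels,
# then time a fixed number of launches with CUDA events. `fn` and `_benchmark_sketch` are
# hypothetical names for illustration only; they are not part of the DeepGEMM API.
def _benchmark_sketch(fn, num_iters: int = 20) -> float:
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # Flush the L2 cache by zeroing a 256 MB buffer
    torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_()
    # Large GEMM before timing, as in compile_and_tune
    torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda')
    start_event.record()
    for _ in range(num_iters):
        fn()
    end_event.record()
    end_event.synchronize()
    # Average milliseconds per launch
    return start_event.elapsed_time(end_event) / num_iters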