mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Some lints and refactor
This commit is contained in:
@@ -1,29 +1,28 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Callable, Dict, Type, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from ..jit import build, generate, Runtime
|
||||
from ..jit import build, Runtime
|
||||
|
||||
|
||||
class JITTuner:
|
||||
def __init__(self) -> None:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, kwargs: Dict[str, Any]) -> Runtime:
|
||||
# NOTES: we always assume the space and template will not change
|
||||
# We also assume the GPU device will not be changed
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
|
||||
# NOTES: we always assume the space, template and GPU devices will not change
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
signature = (name, f'{keys}')
|
||||
if signature in self.tuned:
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Using cached JIT kernel {name} with keys {keys}')
|
||||
return self.tuned[signature]
|
||||
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Auto-tuning JIT kernel {name} with keys {keys}')
|
||||
|
||||
assert signature not in self.tuned
|
||||
@@ -35,19 +34,19 @@ class JITTuner:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generate(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code), full_keys))
|
||||
code = generator(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code, runtime_cls), full_keys))
|
||||
|
||||
# TODO: fix tuning with space > 1
|
||||
best_runtime, best_time, best_keys = None, None, None
|
||||
for runtime, tuned_keys in kernels:
|
||||
if len(space) > 1:
|
||||
# Check kernel validity
|
||||
return_code = runtime(**tuned_keys, **kwargs)
|
||||
if return_code != cuda.CUresult.CUDA_SUCCESS:
|
||||
# Pass illegal kernels, e.g. insufficient shared memory capacity
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
if return_code != cbd.CUresult.CUDA_SUCCESS:
|
||||
# Pass illegal kernels, e.g., insufficient shared memory capacity
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}')
|
||||
continue
|
||||
|
||||
# Measure performance with L2 flush and a large GEMM kernel before to reduce overhead between kernels
|
||||
@@ -59,7 +58,7 @@ class JITTuner:
|
||||
(8192, 8192), dtype=torch.float, device='cuda')
|
||||
start_event.record()
|
||||
for i in range(20):
|
||||
assert runtime(**tuned_keys, **kwargs) == cuda.CUresult.CUDA_SUCCESS
|
||||
assert runtime(**tuned_keys, **kwargs) == cbd.CUresult.CUDA_SUCCESS
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
elapsed_time = start_event.elapsed_time(end_event)
|
||||
@@ -69,13 +68,12 @@ class JITTuner:
|
||||
# Compare if better
|
||||
if best_time is None or elapsed_time < best_time:
|
||||
best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}')
|
||||
assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}'
|
||||
|
||||
# Cache the best runtime and return
|
||||
if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None):
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_AUTOTUNE', 0)):
|
||||
print(
|
||||
f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}')
|
||||
self.tuned[signature] = (best_runtime, best_keys)
|
||||
|
||||
Reference in New Issue
Block a user