Refactor runtime

This commit is contained in:
Chenggang Zhao
2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -12,7 +12,7 @@ class JITTuner:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
@@ -34,7 +34,7 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generator(**kwargs, **full_keys)
code = runtime_cls.generate(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1