Refactor runtime

This commit is contained in:
Chenggang Zhao
2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -1,19 +1,18 @@
import os
import time
from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cbd
from typing import List, Optional, Type
import cuda.bindings.driver as cuda
class Runtime:
def __init__(self, path: str, kernel_name: str = None,
caller: Callable[..., cuda.CUresult] = None,
args: List[str] = None) -> None:
self.path = path
self.lib = None
self.kernel = None
self.kernel_name = kernel_name
self.caller = caller
self.args = args
assert self.is_path_valid(self.path)
@@ -27,6 +26,14 @@ class Runtime:
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
@staticmethod
def generate(**kwargs) -> str:
raise NotImplemented
@staticmethod
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
raise NotImplemented
def __call__(self, **kwargs) -> cuda.CUresult:
# Load CUBIN
if self.kernel is None:
@@ -62,7 +69,8 @@ class Runtime:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return self.caller(
# noinspection PyArgumentList
return self.launch(
self.kernel,
*[kwargs[arg] for arg in self.args]
)