mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor runtime
This commit is contained in:
@@ -1,19 +1,18 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str = None,
|
||||
caller: Callable[..., cuda.CUresult] = None,
|
||||
args: List[str] = None) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
self.caller = caller
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@@ -27,6 +26,14 @@ class Runtime:
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
raise NotImplemented
|
||||
|
||||
@staticmethod
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
    """Placeholder hook that always raises ``NotImplementedError``.

    Takes a ``cbd.CUkernel`` plus arbitrary keyword arguments and is
    annotated as returning a ``cbd.CUresult``. Presumably concrete runtimes
    override this to launch the loaded kernel -- TODO confirm against the
    subclasses.

    Raises:
        NotImplementedError: always; this base version has no implementation.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception; raising
    # it yields ``TypeError``. ``NotImplementedError`` is the intended signal.
    raise NotImplementedError
|
||||
|
||||
def __call__(self, **kwargs) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
@@ -62,7 +69,8 @@ class Runtime:
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return self.caller(
|
||||
# noinspection PyArgumentList
|
||||
return self.launch(
|
||||
self.kernel,
|
||||
*[kwargs[arg] for arg in self.args]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user