mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-05 09:08:58 +00:00
67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
import ctypes
|
|
import os
|
|
import torch
|
|
from typing import Optional
|
|
|
|
from .template import map_ctype
|
|
|
|
|
|
class Runtime:
    """Callable handle to one compiled kernel directory.

    The directory at ``path`` must contain ``kernel.cu``, ``kernel.args`` and
    ``kernel.so``. The shared object and the argument specification are not
    loaded at construction time; they are loaded on the first call.
    """

    def __init__(self, path: str) -> None:
        """Bind the runtime to the kernel directory ``path``.

        Raises:
            AssertionError: if ``path`` does not pass ``is_path_valid``.
        """
        self.path = path
        # Both are filled in lazily by `__call__`
        self.lib = None
        self.args = None

        assert self.is_path_valid(self.path), f'Invalid kernel directory: `{self.path}`'

    @staticmethod
    def is_path_valid(path: str) -> bool:
        """Return whether ``path`` is a directory holding all kernel files."""
        # Exists and is a directory (`isdir` is already false for a missing path)
        if not os.path.isdir(path):
            return False

        # Contains all necessary files
        files = ['kernel.cu', 'kernel.args', 'kernel.so']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    def __call__(self, *args) -> int:
        """Type-check ``args`` against ``kernel.args`` and call ``launch``.

        Each positional argument is matched, in order, with a ``(name, dtype)``
        pair from the argument specification, converted with ``map_ctype`` and
        forwarded to the ``launch`` symbol of ``kernel.so``.

        Returns:
            The integer that ``launch`` stored through the trailing
            by-reference return-code argument.

        Raises:
            AssertionError: if the number of arguments or any argument's
                type/dtype does not match the specification.
        """
        # Load SO file
        if self.lib is None or self.args is None:
            self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
            with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
                # NOTE(security): `eval` executes whatever `kernel.args` contains,
                # so the kernel directory must only ever hold trusted files
                self.args = eval(f.read())

        # Check args and launch
        assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
        cargs = []
        for arg, (name, dtype) in zip(args, self.args):
            if isinstance(arg, torch.Tensor):
                assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
            else:
                # A `torch.dtype` is not a class, so `isinstance(arg, dtype)` would
                # raise a bare `TypeError`; report the mismatch like the others
                assert not isinstance(dtype, torch.dtype), \
                    f'Expected tensor dtype `{dtype}` for `{name}`, got `{type(arg)}`'
                assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
            cargs.append(map_ctype(arg))

        # `launch` reports its status through the extra by-reference argument
        return_code = ctypes.c_int(0)
        self.lib.launch(*cargs, ctypes.byref(return_code))
        return return_code.value
|
|
|
|
|
|
class RuntimeCache:
    """In-process mapping from kernel directory paths to `Runtime` objects."""

    def __init__(self) -> None:
        """Start with no runtimes cached."""
        self.cache = dict()

    def __getitem__(self, path: str) -> Optional[Runtime]:
        """Look up the runtime for ``path``.

        Returns the cached entry when there is one. Otherwise, if ``path``
        exists and passes ``Runtime.is_path_valid``, a new `Runtime` is
        created, cached and returned. In every other case ``None`` is returned.
        """
        # Fast path: already held in this Python process
        try:
            return self.cache[path]
        except KeyError:
            pass

        # Nothing usable on disk either
        if not (os.path.exists(path) and Runtime.is_path_valid(path)):
            return None

        # Compiled earlier: wrap the directory and remember it
        loaded = Runtime(path)
        self.cache[path] = loaded
        return loaded

    def __setitem__(self, path, runtime) -> None:
        """Store ``runtime`` under ``path``, replacing any existing entry."""
        self.cache[path] = runtime
|