mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor launch-related structures
This commit is contained in:
@@ -8,11 +8,10 @@ from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, args: List[str] = None) -> None:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
@@ -48,8 +47,10 @@ class Runtime:
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
illegal_names = ['vprintf', '__instantiate_kernel', '__internal']
|
||||
check_illegal = lambda line: any([name in line for name in illegal_names])
|
||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
|
||||
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
|
||||
if line.startswith('STT_FUNC') and not check_illegal(line)]
|
||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||
|
||||
# Load kernel from the library
|
||||
@@ -62,7 +63,7 @@ class Runtime:
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
|
||||
|
||||
# noinspection PyArgumentList
|
||||
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
|
||||
return self.launch(self.kernel, kwargs)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.lib is not None:
|
||||
|
||||
Reference in New Issue
Block a user