Refactor launch-related structures

This commit is contained in:
Chenggang Zhao
2025-05-15 16:14:21 +08:00
parent e2d6a107ef
commit 816b39053a
9 changed files with 199 additions and 396 deletions

View File

@@ -8,11 +8,10 @@ from torch.utils.cpp_extension import CUDA_HOME
class Runtime:
def __init__(self, path: str, args: List[str] = None) -> None:
def __init__(self, path: str) -> None:
self.path = path
self.lib = None
self.kernel = None
self.args = args
assert self.is_path_valid(self.path)
@staticmethod
@@ -48,8 +47,10 @@ class Runtime:
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
assert result.returncode == 0
illegal_names = ['vprintf', '__instantiate_kernel', '__internal']
check_illegal = lambda line: any([name in line for name in illegal_names])
kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
if line.startswith('STT_FUNC') and '__instantiate_kernel' not in line]
if line.startswith('STT_FUNC') and not check_illegal(line)]
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
# Load kernel from the library
@@ -62,7 +63,7 @@ class Runtime:
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
# noinspection PyArgumentList
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
return self.launch(self.kernel, kwargs)
def __del__(self) -> None:
if self.lib is not None: