feat: save kernel name to file

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
Zihua Wu
2025-04-22 22:23:47 -07:00
parent 767793bf95
commit a3210ac850
2 changed files with 14 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ class Runtime:
return False
# Contains all necessary files
files = ['kernel.cubin']
files = ['kernel.cubin', 'kernel.cubin.name']
return all(os.path.exists(os.path.join(path, file)) for file in files)
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
@@ -48,7 +48,6 @@ class Runtime:
if res != cuda.CUresult.CUDA_SUCCESS:
raise Exception(f"Failed to load library: {res}")
print(f"Kernel name: {self.kernel_name}")
res, kernel = cuda.cuLibraryGetKernel(
lib, bytes(self.kernel_name, encoding='utf-8'))
if res != cuda.CUresult.CUDA_SUCCESS:
@@ -101,8 +100,7 @@ class RuntimeCache:
# Already compiled
if os.path.exists(path) and Runtime.is_path_valid(path):
kernel_name = get_symbol(os.path.join(
path, 'kernel.cubin'), 'fp8_gemm_kernel')
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
runtime = Runtime(path, kernel_name)
self.cache[path] = runtime
return runtime