mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: save kernel name to file
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -36,7 +36,7 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cubin']
|
||||
files = ['kernel.cubin', 'kernel.cubin.name']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
@@ -48,7 +48,6 @@ class Runtime:
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
|
||||
print(f"Kernel name: {self.kernel_name}")
|
||||
res, kernel = cuda.cuLibraryGetKernel(
|
||||
lib, bytes(self.kernel_name, encoding='utf-8'))
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
@@ -101,8 +100,7 @@ class RuntimeCache:
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
kernel_name = get_symbol(os.path.join(
|
||||
path, 'kernel.cubin'), 'fp8_gemm_kernel')
|
||||
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = Runtime(path, kernel_name)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
|
||||
Reference in New Issue
Block a user