mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: make API more general
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
import platform
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
@@ -13,9 +13,10 @@ from .utils import run_gemm
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
if CUDA_HOME is None:
|
||||
raise Exception("CUDA_HOME is not set")
|
||||
|
||||
|
||||
cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
|
||||
command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
|
||||
'-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
|
||||
self.caller = caller
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@staticmethod
|
||||
@@ -62,7 +64,7 @@ class Runtime:
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
raise Exception("Failed to find fp8 gemm kernel")
|
||||
raise Exception("Failed to find kernel")
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
@@ -70,21 +72,9 @@ class Runtime:
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
return self.caller(
|
||||
self.kernel,
|
||||
kwargs['NUM_TMA_MULTICAST'],
|
||||
kwargs['M'],
|
||||
kwargs['BLOCK_M'],
|
||||
kwargs['GMEM_D'],
|
||||
kwargs['SCALES_B'],
|
||||
kwargs['GROUPED_LAYOUT'],
|
||||
kwargs['NUM_SMS'],
|
||||
kwargs['SMEM_SIZE'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
kwargs['STREAM'],
|
||||
*[kwargs[arg] for arg in self.args]
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
@@ -94,22 +84,42 @@ class Runtime:
|
||||
raise Exception(f"Failed to unload library {self.path}: {res}")
|
||||
|
||||
|
||||
class Fp8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
super().__init__(path, kernel_name, run_gemm, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
|
||||
def __getitem__(self, path: str) -> Optional[Runtime]:
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = Runtime(path, kernel_name)
|
||||
kernel_name = open(os.path.join(
|
||||
path, 'kernel.cubin.name'), 'r').read()
|
||||
runtime = runtime_cls(path, kernel_name)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
return None
|
||||
Reference in New Issue
Block a user