mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Some lints and refactor
This commit is contained in:
@@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
|
||||
def __init__(self, path: str, kernel_name: str = None,
|
||||
caller: Callable[..., cuda.CUresult] = None,
|
||||
args: List[str] = None) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
@@ -27,7 +27,7 @@ class Runtime:
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
def __call__(self, **kwargs) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
@@ -59,9 +59,8 @@ class Runtime:
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
if os.getenv('DG_JIT_DEBUG', None):
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return self.caller(
|
||||
self.kernel,
|
||||
@@ -75,25 +74,6 @@ class Runtime:
|
||||
raise Exception(f'Failed to unload library {self.path}: {res}')
|
||||
|
||||
|
||||
class Fp8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', run_gemm, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
def __init__(self) -> None:
|
||||
self.cache = {}
|
||||
@@ -101,14 +81,14 @@ class RuntimeCache:
|
||||
def __setitem__(self, path, runtime) -> None:
|
||||
self.cache[path] = runtime
|
||||
|
||||
def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
|
||||
def get(self, path: str, runtime_cls: Type[Runtime]) -> Optional[Runtime]:
|
||||
# In Python runtime
|
||||
if path in self.cache:
|
||||
return self.cache[path]
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
if not int(os.getenv('DG_DISABLE_CACHE', 0)) and os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = runtime_cls(path)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user