DeepGEMM/deep_gemm/jit/runtime.py
Zihua Wu f6198492cb feat: drop support for CUDA<12.3
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
2025-04-25 18:56:40 -07:00

114 lines
3.8 KiB
Python

import os
import time
from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cuda
from .utils import run_gemm
class Runtime:
    """Lazily loads a compiled ``kernel.cubin`` and launches one kernel from it.

    The CUDA library is loaded on the first call. The first kernel whose name
    contains ``kernel_name`` is selected. Every call then invokes
    ``caller(kernel, *values)``, where ``values`` are the keyword arguments named
    in ``args``, taken in that order.
    """

    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
        """
        Args:
            path: Directory that must contain ``kernel.cubin``.
            kernel_name: Substring matched against the names of the library's kernels.
            caller: Launch function invoked as ``caller(kernel, *values)``.
            args: Names of the keyword arguments forwarded to ``caller``, in order.
        """
        self.path = path
        self.lib = None
        self.kernel = None
        self.kernel_name = kernel_name
        self.caller = caller
        self.args = args
        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        """Return True if ``path`` is a directory containing every required file."""
        # ``isdir`` is already False for a non-existent path
        if not os.path.isdir(path):
            return False
        # Contains all necessary files
        files = ['kernel.cubin']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    def _find_kernel(self, lib: Any) -> Any:
        """Return the first kernel in ``lib`` whose name contains ``self.kernel_name``.

        Raises:
            Exception: if a driver call fails or no kernel name matches.
        """
        res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
        if res != cuda.CUresult.CUDA_SUCCESS:
            raise Exception(f'Failed to get kernel count: {res}')

        res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
        if res != cuda.CUresult.CUDA_SUCCESS:
            raise Exception(f'Failed to enumerate kernels: {res}')

        # Encode once; driver kernel names are compared as bytes
        target = bytes(self.kernel_name, encoding='utf-8')
        for kernel in kernels:
            res, kernel_name = cuda.cuKernelGetName(kernel)
            if res != cuda.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to get kernel name: {res}')
            if target in kernel_name:
                return kernel
        raise Exception('Failed to find required kernel')

    def _load(self) -> None:
        """Load ``kernel.cubin`` and set ``self.lib`` / ``self.kernel``.

        If kernel lookup fails after the library was loaded, the library is
        unloaded before the exception propagates, so the handle is not leaked.
        """
        start_time = time.time_ns()
        res, lib = cuda.cuLibraryLoadFromFile(
            bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
        if res != cuda.CUresult.CUDA_SUCCESS:
            raise Exception(f'Failed to load library: {res}')

        try:
            kernel = self._find_kernel(lib)
        except BaseException:
            # Release the library we just loaded; nothing else references it
            cuda.cuLibraryUnload(lib)
            raise

        self.kernel = kernel
        self.lib = lib

        end_time = time.time_ns()
        elapsed_time = (end_time - start_time) / 1000  # ns -> us
        if os.getenv('DG_JIT_DEBUG', None):
            print(
                f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

    def __call__(self, **kwargs: Any) -> cuda.CUresult:
        """Launch the kernel, loading the CUBIN first if this is the first call.

        ``kwargs`` must contain every name listed in ``self.args``. A missing name
        raises ``KeyError``.
        """
        if self.kernel is None:
            self._load()
        return self.caller(
            self.kernel,
            *[kwargs[arg] for arg in self.args]
        )

    def __del__(self) -> None:
        """Unload the CUDA library if one was loaded."""
        # getattr tolerates an instance whose __init__ did not run to completion
        lib = getattr(self, 'lib', None)
        if lib is not None:
            # Clear first so the handle can never be unloaded twice
            self.lib = None
            res = cuda.cuLibraryUnload(lib)[0]
            if res != cuda.CUresult.CUDA_SUCCESS:
                # Python reports exceptions raised in __del__ via
                # sys.unraisablehook; they do not propagate to a caller.
                raise Exception(f'Failed to unload library {self.path}: {res}')
class Fp8GemmRuntime(Runtime):
    """Runtime bound to the ``fp8_gemm`` kernel and launched via ``run_gemm``."""

    def __init__(self, path: str) -> None:
        # Keyword arguments handed to ``run_gemm`` after the kernel handle, in order
        launch_arg_names = [
            'NUM_TMA_MULTICAST',
            'M',
            'BLOCK_M',
            'GMEM_D',
            'SCALES_B',
            'GROUPED_LAYOUT',
            'NUM_SMS',
            'SMEM_SIZE',
            'TENSOR_MAP_A',
            'TENSOR_MAP_B',
            'TENSOR_MAP_SCALES_A',
            'TENSOR_MAP_D',
            'STREAM',
        ]
        super().__init__(
            path,
            kernel_name='fp8_gemm',
            caller=run_gemm,
            args=launch_arg_names,
        )
class RuntimeCache:
    """In-process cache of runtimes, keyed by their build directory path."""

    def __init__(self) -> None:
        self.cache = {}

    def __setitem__(self, path: str, runtime: Runtime) -> None:
        """Register ``runtime`` as the cached runtime for ``path``."""
        self.cache[path] = runtime

    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
        """Return the runtime for ``path``, building it from disk if necessary.

        Lookup order:
          1. An entry already stored in this cache, returned as-is.
          2. A valid build directory on disk, wrapped in ``runtime_cls``,
             cached, and returned.

        Returns None if neither is available.
        """
        # In Python runtime
        if path in self.cache:
            return self.cache[path]
        # Already compiled. Validate with ``runtime_cls``, the same check
        # ``Runtime.__init__`` asserts via ``self.is_path_valid``. A subclass
        # that overrides ``is_path_valid`` is then checked consistently.
        if os.path.exists(path) and runtime_cls.is_path_valid(path):
            runtime = runtime_cls(path)
            self.cache[path] = runtime
            return runtime
        return None