mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
[wip] refactor: compile to .cubin
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
import ctypes
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .template import map_ctype
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.args = None
|
||||
self.kernel = None
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@@ -21,29 +23,66 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.args', 'kernel.so']
|
||||
files = ['kernel.cu', 'kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, *args) -> int:
|
||||
# Load SO file
|
||||
if self.lib is None or self.args is None:
|
||||
self.lib = ctypes.CDLL(os.path.join(self.path, 'kernel.so'))
|
||||
with open(os.path.join(self.path, 'kernel.args'), 'r') as f:
|
||||
self.args = eval(f.read())
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.lib is None:
|
||||
start_time = time.time_ns()
|
||||
res, lib = cuda.cuLibraryLoadFromFile(
|
||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
|
||||
# Check args and launch
|
||||
assert len(args) == len(self.args), f'Expected {len(self.args)} arguments, got {len(args)}'
|
||||
cargs = []
|
||||
for arg, (name, dtype) in zip(args, self.args):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
assert arg.dtype == dtype, f'Expected tensor dtype `{dtype}` for `{name}`, got `{arg.dtype}`'
|
||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel count: {res}")
|
||||
|
||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to enumerate kernels: {res}")
|
||||
|
||||
for kernel in kernels:
|
||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel name: {res}")
|
||||
if b"fp8" in kernel_name:
|
||||
self.kernel = kernel
|
||||
break
|
||||
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
assert isinstance(arg, dtype), f'Expected built-in type `{dtype}` for `{name}`, got `{type(arg)}`'
|
||||
cargs.append(map_ctype(arg))
|
||||
raise Exception("Failed to find fp8 gemm kernel")
|
||||
|
||||
return_code = ctypes.c_int(0)
|
||||
self.lib.launch(*cargs, ctypes.byref(return_code))
|
||||
return return_code.value
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
print(
|
||||
f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return run_gemm(
|
||||
self.kernel,
|
||||
kwargs['NUM_TMA_MULTICAST'],
|
||||
kwargs['M'],
|
||||
kwargs['BLOCK_M'],
|
||||
kwargs['GMEM_D'],
|
||||
kwargs['SCALES_B'],
|
||||
kwargs['GROUPED_LAYOUT'],
|
||||
kwargs['NUM_SMS'],
|
||||
kwargs['SMEM_SIZE'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
kwargs['STREAM'],
|
||||
)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.lib is not None:
|
||||
res = cuda.cuLibraryUnload(self.lib)[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to unload library {self.path}: {res}")
|
||||
|
||||
|
||||
class RuntimeCache:
|
||||
|
||||
Reference in New Issue
Block a user