# Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00
import os
import subprocess
import time
import torch
import cuda.bindings.driver as cbd

from typing import Any, Dict, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME
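
# Environment variables read in this module (all parsed with `int(...)`):
#   DG_JIT_DEBUG         - when non-zero, print JIT load timing and cache hits
#   DG_JIT_DISABLE_CACHE - when non-zero, skip the on-disk runtime cache
#   DG_PRINT_CONFIGS     - when non-zero, print kernel configs when first cached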


class Runtime:
    def __init__(self, path: str) -> None:
        self.path = path
        self.lib = None
        self.kernel = None
        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        # Exists and is a directory
        if not os.path.exists(path) or not os.path.isdir(path):
            return False

        # Contains all necessary files
        files = ['kernel.cubin']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        raise NotImplementedError

    @staticmethod
    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
        raise NotImplementedError

    def __call__(self, **kwargs) -> cbd.CUresult:
        # Load the CUBIN and kernel lazily on the first call
        if self.kernel is None:
            start_time = time.time_ns()

            # Load CUBIN
            path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
            result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
            assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'

            # Extract the kernel name from the `cuobjdump` symbol table
            # TODO: use `cuda-bindings` API to do this (requires at least 12.8)
            command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            assert result.returncode == 0, f'`cuobjdump` failed: {result.stderr}'
            illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']
            check_illegal = lambda line: any(name in line for name in illegal_names)
            kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
                            if line.startswith('STT_FUNC') and not check_illegal(line)]
            assert len(kernel_names) == 1, f'Expected exactly one kernel in the library, got: {kernel_names}'

            # Load kernel from the library
            result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
            assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'

            end_time = time.time_ns()
            elapsed_time = (end_time - start_time) / 1e6
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')

        # noinspection PyArgumentList
        return self.launch(self.kernel, kwargs)

    def __del__(self) -> None:
        if self.lib is not None:
            res = cbd.cuLibraryUnload(self.lib)[0]
            if res != cbd.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to unload library {self.path}: {res}')
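

# A minimal sketch of a concrete subclass (hypothetical, for illustration only):
# `generate` is expected to return the source used to build `kernel.cubin`, and
# `launch` unpacks `kwargs` into a CUDA driver-API launch. The kwarg names below
# (`NUM_BLOCKS`, `NUM_THREADS`, `STREAM`, `X`) are assumptions, not part of this
# file; real subclasses define their own argument layout.
#
#   import ctypes
#
#   class ExampleRuntime(Runtime):
#       @staticmethod
#       def generate(kwargs: Dict[str, Any]) -> str:
#           return '// generated CUDA C++ source for this kernel'
#
#       @staticmethod
#       def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
#           config = cbd.CUlaunchConfig()
#           config.gridDimX, config.gridDimY, config.gridDimZ = kwargs['NUM_BLOCKS'], 1, 1
#           config.blockDimX, config.blockDimY, config.blockDimZ = kwargs['NUM_THREADS'], 1, 1
#           config.hStream = kwargs['STREAM']
#           arg_values = (kwargs['X'].data_ptr(), )
#           arg_types = (ctypes.c_void_p, )
#           return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]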


class RuntimeCache:
    def __init__(self) -> None:
        self.cache = {}

    def __setitem__(self, path: str, runtime: Runtime) -> None:
        self.cache[path] = runtime

    def get(self, path: str, runtime_cls: Type[Runtime],
            name: str = '', kwargs: Optional[Dict[str, Any]] = None,
            force_enable_cache: bool = False) -> Optional[Runtime]:
        # Already cached in the Python runtime
        if path in self.cache:
            return self.cache[path]

        # Already compiled on disk
        use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0))
        if use_cache and os.path.exists(path) and Runtime.is_path_valid(path):
            # Print the heuristic config the first time this kernel is seen
            if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
                simplified_kwargs = dict()
                for key, value in (kwargs or {}).items():
                    value = f'torch.Tensor<{value.dtype}>' if isinstance(value, torch.Tensor) else value
                    value = 'cuda.bindings.driver.CUtensorMap' if isinstance(value, cbd.CUtensorMap) else value
                    simplified_kwargs[key] = value
                print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')

            runtime = runtime_cls(path)
            self.cache[path] = runtime
            return runtime
        return None
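

# Hedged usage sketch (hypothetical; assumes a concrete subclass like the
# `ExampleRuntime` sketched above, an already-compiled cache directory, and
# caller-provided `stream` and `x` objects):
#
#   cache = RuntimeCache()
#   runtime = cache.get('/path/to/kernel_dir', ExampleRuntime, name='example')
#   if runtime is not None:
#       ret = runtime(NUM_BLOCKS=132, NUM_THREADS=256, STREAM=stream, X=x)
#       assert ret == cbd.CUresult.CUDA_SUCCESS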