mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
feat: compat for old drivers
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
This commit is contained in:
@@ -1,16 +1,31 @@
|
||||
import os
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
from .utils import run_gemm
|
||||
|
||||
|
||||
def get_symbol(file_path: str, pattern: str) -> Optional[str]:
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', file_path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
for line in result.stdout.splitlines():
|
||||
if pattern in line:
|
||||
return line.split()[-1]
|
||||
return None
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str) -> None:
|
||||
def __init__(self, path: str, kernel_name: str) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@@ -21,34 +36,25 @@ class Runtime:
|
||||
return False
|
||||
|
||||
# Contains all necessary files
|
||||
files = ['kernel.cu', 'kernel.cubin']
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
def __call__(self, **kwargs: Dict[str, Any]) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.lib is None:
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
res, lib = cuda.cuLibraryLoadFromFile(
|
||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to load library: {res}")
|
||||
|
||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
||||
print(f"Kernel name: {self.kernel_name}")
|
||||
res, kernel = cuda.cuLibraryGetKernel(
|
||||
lib, bytes(self.kernel_name, encoding='utf-8'))
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel count: {res}")
|
||||
|
||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to enumerate kernels: {res}")
|
||||
|
||||
for kernel in kernels:
|
||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f"Failed to get kernel name: {res}")
|
||||
if b"fp8" in kernel_name:
|
||||
self.kernel = kernel
|
||||
break
|
||||
raise Exception(f"Failed to get kernel: {res}")
|
||||
|
||||
self.kernel = kernel
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
@@ -95,7 +101,9 @@ class RuntimeCache:
|
||||
|
||||
# Already compiled
|
||||
if os.path.exists(path) and Runtime.is_path_valid(path):
|
||||
runtime = Runtime(path)
|
||||
kernel_name = get_symbol(os.path.join(
|
||||
path, 'kernel.cubin'), 'fp8_gemm_kernel')
|
||||
runtime = Runtime(path, kernel_name)
|
||||
self.cache[path] = runtime
|
||||
return runtime
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user