mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Compatible with CUDA 12.3
This commit is contained in:
@@ -1,18 +1,17 @@
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
from typing import List, Optional, Type
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str = None,
|
||||
args: List[str] = None) -> None:
|
||||
def __init__(self, path: str, args: List[str] = None) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@@ -34,51 +33,39 @@ class Runtime:
|
||||
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
|
||||
raise NotImplemented
|
||||
|
||||
def __call__(self, **kwargs) -> cuda.CUresult:
|
||||
def __call__(self, **kwargs) -> cbd.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
start_time = time.time_ns()
|
||||
res, lib = cuda.cuLibraryLoadFromFile(
|
||||
bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8'), [], [], 0, [], [], 0)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to load library: {res}')
|
||||
|
||||
res, kernel_count = cuda.cuLibraryGetKernelCount(lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to get kernel count: {res}')
|
||||
|
||||
res, kernels = cuda.cuLibraryEnumerateKernels(kernel_count, lib)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to enumerate kernels: {res}')
|
||||
# Load CUBIN
|
||||
path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
|
||||
result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'
|
||||
|
||||
for kernel in kernels:
|
||||
res, kernel_name = cuda.cuKernelGetName(kernel)
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to get kernel name: {res}')
|
||||
if bytes(self.kernel_name, encoding='utf-8') in kernel_name:
|
||||
self.kernel = kernel
|
||||
break
|
||||
# Extract the kernel name
|
||||
command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
assert result.returncode == 0
|
||||
kernel_names = [line.split()[-1] for line in result.stdout.splitlines() if line.startswith('STT_FUNC')]
|
||||
assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'
|
||||
|
||||
if self.kernel is not None:
|
||||
self.lib = lib
|
||||
else:
|
||||
raise Exception('Failed to find required kernel')
|
||||
# Load kernel from the library
|
||||
result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'
|
||||
|
||||
end_time = time.time_ns()
|
||||
elapsed_time = (end_time - start_time) / 1000
|
||||
elapsed_time = (end_time - start_time) / 1e6
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')
|
||||
|
||||
# noinspection PyArgumentList
|
||||
return self.launch(
|
||||
self.kernel,
|
||||
*[kwargs[arg] for arg in self.args]
|
||||
)
|
||||
return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.lib is not None:
|
||||
res = cuda.cuLibraryUnload(self.lib)[0]
|
||||
if res != cuda.CUresult.CUDA_SUCCESS:
|
||||
res = cbd.cuLibraryUnload(self.lib)[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to unload library {self.path}: {res}')
|
||||
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', [
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
@@ -175,8 +175,7 @@ class FP8GemmRuntime(Runtime):
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
@@ -192,7 +191,6 @@ __global__ void dummy_kernel() {{
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>);
|
||||
}}
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 GEMM code:\n{code}')
|
||||
|
||||
@@ -12,7 +12,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', [
|
||||
super().__init__(path, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
@@ -31,17 +31,15 @@ class VectorAddRuntime(jit.Runtime):
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
|
||||
template <typename T>
|
||||
__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
|
||||
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i < N) {{
|
||||
if (i < n) {{
|
||||
c[i] = a[i] + b[i];
|
||||
}}
|
||||
}}
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<float>);
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
|
||||
Reference in New Issue
Block a user