mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor launch-related structures
This commit is contained in:
@@ -2,6 +2,7 @@ import ctypes
|
||||
import os
|
||||
import torch
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Dict
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
@@ -12,12 +13,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
'STREAM',
|
||||
])
|
||||
super().__init__(path)
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
@@ -46,27 +42,25 @@ static void __instantiate_kernel() {{
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
|
||||
assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
|
||||
assert kwargs['A'].dim() == 1
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.gridDimX = (a.numel() + 127) // 128
|
||||
config.gridDimX = (kwargs['A'].numel() + 127) // 128
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = 128
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
config.hStream = kwargs['STREAM']
|
||||
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
a.numel(),
|
||||
kwargs['A'].data_ptr(),
|
||||
kwargs['B'].data_ptr(),
|
||||
kwargs['C'].data_ptr(),
|
||||
kwargs['A'].numel(),
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
|
||||
Reference in New Issue
Block a user