Refactor launch-related structures

This commit is contained in:
Chenggang Zhao
2025-05-15 16:14:21 +08:00
parent e2d6a107ef
commit 816b39053a
9 changed files with 199 additions and 396 deletions

View File

@@ -2,6 +2,7 @@ import ctypes
import os
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict
from deep_gemm import jit
@@ -12,12 +13,7 @@ os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'A',
'B',
'C',
'STREAM',
])
super().__init__(path)
@staticmethod
def generate(**kwargs) -> str:
@@ -46,27 +42,25 @@ static void __instantiate_kernel() {{
# noinspection PyShadowingNames,PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cbd.CUstream) -> cbd.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
assert kwargs['A'].dim() == 1
config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128
config.gridDimX = (kwargs['A'].numel() + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
config.hStream = kwargs['STREAM']
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.numel(),
kwargs['A'].data_ptr(),
kwargs['B'].data_ptr(),
kwargs['C'].data_ptr(),
kwargs['A'].numel(),
)
arg_types = (
ctypes.c_void_p,