Refactor runtime

This commit is contained in:
Chenggang Zhao
2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -1,7 +1,7 @@
import ctypes
import os
import torch
import cuda.bindings.driver as cuda
import cuda.bindings.driver as cbd
from deep_gemm import jit
@@ -10,43 +10,18 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
# noinspection PyShadowingNames
def launch_vector_add(kernel: cuda.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cuda.CUstream) -> cuda.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', [
'A',
'B',
'C',
'STREAM',
])
n = a.numel()
config = cuda.CUlaunchConfig()
config.gridDimX = (n + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
n,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
def generate_vector_add(**kwargs) -> str:
return f"""
@staticmethod
def generate(**kwargs) -> str:
return f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
@@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
}}
"""
# noinspection PyShadowingNames,PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cbd.CUstream) -> cbd.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', launch_vector_add, [
'A',
'B',
'C',
'STREAM',
])
config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.numel(),
)
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
if __name__ == '__main__':
print('Generated code:')
code = generate_vector_add(T='float')
code = VectorAddRuntime.generate(T='float')
print(code)
print()
@@ -100,6 +98,6 @@ if __name__ == '__main__':
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
assert ret == cbd.CUresult.CUDA_SUCCESS, ret
torch.testing.assert_close(c, a + b)
print(f'JIT test for {compiler_name} passed\n')