mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Some lints and refactor
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Any, Dict
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
# Essential debugging staffs
|
||||
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
|
||||
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
|
||||
|
||||
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def launch_vector_add(kernel: cuda.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cuda.CUstream) -> cuda.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
@@ -24,28 +29,25 @@ def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: t
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
|
||||
kernelValues = (
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
n,
|
||||
)
|
||||
kernelTypes = (
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
)
|
||||
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
|
||||
|
||||
|
||||
def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
|
||||
def generate_vector_add(**kwargs) -> str:
|
||||
return f"""
|
||||
#ifdef __CUDACC_RTC__
|
||||
#ifndef NVRTC_JIT_COMPILATION
|
||||
#define NVRTC_JIT_COMPILATION
|
||||
#endif
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
@@ -63,14 +65,14 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
|
||||
}}
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&vector_add<{kwargs['T']}>;
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', run_vector_add, [
|
||||
super().__init__(path, 'vector_add', launch_vector_add, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
@@ -79,38 +81,25 @@ class VectorAddRuntime(jit.Runtime):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# NVCC
|
||||
print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
|
||||
print('Generated code:')
|
||||
code = generate_vector_add(T='float')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
|
||||
print()
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
ref_output = a + b
|
||||
torch.testing.assert_close(c, ref_output)
|
||||
for compiler_name in ('NVCC', 'NVRTC'):
|
||||
# Get compiler
|
||||
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
|
||||
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
|
||||
|
||||
print('JIT test for NVCC passed\n')
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = compiler_cls.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
# NVRTC
|
||||
print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
|
||||
print('Generated code:')
|
||||
code = generate_vector_add(T='__nv_bfloat16')
|
||||
print(code)
|
||||
print('Building ...')
|
||||
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
ref_output = a + b
|
||||
torch.testing.assert_close(c, ref_output)
|
||||
|
||||
print('JIT test for NVRTC passed')
|
||||
# Run and check
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
torch.testing.assert_close(c, a + b)
|
||||
print(f'JIT test for {compiler_name} passed\n')
|
||||
|
||||
Reference in New Issue
Block a user