Some lints and refactor

This commit is contained in:
Chenggang Zhao
2025-05-06 17:23:35 +08:00
parent 8aff6309d4
commit 981cc58932
18 changed files with 421 additions and 449 deletions

View File

@@ -1,14 +1,19 @@
import ctypes
import os
import torch
from typing import Any, Dict
import cuda.bindings.driver as cuda
from deep_gemm import jit
# Essential debugging staffs
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
# noinspection PyShadowingNames
def launch_vector_add(kernel: cuda.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cuda.CUstream) -> cuda.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
@@ -24,28 +29,25 @@ def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: t
config.blockDimZ = 1
config.hStream = stream
kernelValues = (
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
n,
)
kernelTypes = (
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
def generate_vector_add(**kwargs) -> str:
return f"""
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
@@ -63,14 +65,14 @@ __global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
}}
__global__ void dummy_kernel() {{
void *ptr = (void *)&vector_add<{kwargs['T']}>;
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', run_vector_add, [
super().__init__(path, 'vector_add', launch_vector_add, [
'A',
'B',
'C',
@@ -79,38 +81,25 @@ class VectorAddRuntime(jit.Runtime):
if __name__ == '__main__':
# NVCC
print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
print('Generated code:')
code = generate_vector_add(T='float')
print(code)
print('Building ...')
func = jit.NvccCompiler.build('test_func', code, VectorAddRuntime)
print()
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)
for compiler_name in ('NVCC', 'NVRTC'):
# Get compiler
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
print('JIT test for NVCC passed\n')
# Build
print('Building ...')
func = compiler_cls.build('test_func', code, VectorAddRuntime)
# NVRTC
print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
print('Generated code:')
code = generate_vector_add(T='__nv_bfloat16')
print(code)
print('Building ...')
func = jit.NvrtcCompiler.build('test_func', code, VectorAddRuntime)
a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
ref_output = a + b
torch.testing.assert_close(c, ref_output)
print('JIT test for NVRTC passed')
# Run and check
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
torch.testing.assert_close(c, a + b)
print(f'JIT test for {compiler_name} passed\n')