# DeepGEMM/tests/test_jit.py
# Zihua Wu f6198492cb feat: drop support for CUDA<12.3
# Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
# 2025-04-25 18:56:40 -07:00
#
# 116 lines
# 3.0 KiB
# Python

import ctypes
import os
import torch
from typing import Any, Dict
import cuda.bindings.driver as cuda
from deep_gemm import jit
def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
    """Launch ``kernel`` over three 1-D tensors and return the launch status.

    The launch uses a 1-D grid of 128-thread blocks, with enough blocks to
    cover every element of ``a``, enqueued on ``stream``. The kernel receives
    the three tensors' data pointers followed by the element count.

    Args:
        kernel: Kernel handle to launch.
        a: First tensor whose data pointer is passed to the kernel.
        b: Second tensor; must match ``a`` in shape and device.
        c: Third tensor; must match ``a`` in shape and device.
        stream: Stream the launch is enqueued on.

    Returns:
        The first element of the tuple returned by ``cuda.cuLaunchKernelEx``.
    """
    # All three tensors must agree in shape and device, and be one-dimensional.
    assert a.shape == b.shape and b.shape == c.shape
    assert a.device == b.device and b.device == c.device
    assert a.dim() == 1

    threads_per_block = 128
    num_elements = a.numel()
    # Ceiling division so a trailing partial block is still launched.
    num_blocks = (num_elements + threads_per_block - 1) // threads_per_block

    launch_config = cuda.CUlaunchConfig()
    launch_config.gridDimX, launch_config.gridDimY, launch_config.gridDimZ = num_blocks, 1, 1
    launch_config.blockDimX, launch_config.blockDimY, launch_config.blockDimZ = threads_per_block, 1, 1
    launch_config.hStream = stream

    # Kernel parameters are passed as parallel tuples of values and ctypes
    # types: three pointers followed by a 32-bit unsigned element count.
    arg_values = (a.data_ptr(), b.data_ptr(), c.data_ptr(), num_elements)
    arg_types = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint32)

    launch_result = cuda.cuLaunchKernelEx(launch_config, kernel, (arg_values, arg_types), 0)
    return launch_result[0]
def generate_vector_add(**kwargs: Any) -> str:
    """Return CUDA C++ source for a templated element-wise vector-add kernel.

    The source defines ``vector_add<T>`` and a ``dummy_kernel`` that takes the
    address of ``vector_add<kwargs['T']>``, which forces the template to be
    instantiated for that element type.

    Args:
        **kwargs: Generation parameters. Only ``T`` is read: the CUDA element
            type name to instantiate, e.g. ``'float'`` or ``'__nv_bfloat16'``.
            (Annotated ``Any`` because the annotation on ``**kwargs``
            describes each keyword value, not the mapping as a whole.)

    Returns:
        The generated source text.

    Raises:
        KeyError: If ``T`` is not supplied.
    """
    # The template is an f-string, so literal C++ braces are doubled
    # ("{{" / "}}"); the only substitution is kwargs['T'].
    return f"""
#ifdef __CUDACC_RTC__
#ifndef NVRTC_JIT_COMPILATION
#define NVRTC_JIT_COMPILATION
#endif
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#endif
#include <cuda_fp8.h>
#include <cuda_bf16.h>
template<typename T>
__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < N) {{
c[i] = a[i] + b[i];
}}
}}
__global__ void dummy_kernel() {{
void *ptr = (void *)&vector_add<{kwargs['T']}>;
}}
"""
class VectorAddRuntime(jit.Runtime):
    """``jit.Runtime`` specialization for the ``vector_add`` test kernel.

    Registers the kernel name ``'vector_add'``, the ``run_vector_add``
    launcher, and the keyword names ``A``, ``B``, ``C`` and ``STREAM`` with
    the base class.
    """

    def __init__(self, path: str) -> None:
        """Initialize the base runtime for the artifact at ``path``."""
        # NOTE(review): presumably these names are forwarded, in this order,
        # to run_vector_add after the kernel handle — confirm in jit.Runtime.
        launch_arg_names = ['A', 'B', 'C', 'STREAM']
        super().__init__(path, 'vector_add', run_vector_add, launch_arg_names)
if __name__ == '__main__':
    def _build_and_check(compiler, element_type, dtype):
        """Generate, build and run vector_add with one compiler, then verify it.

        ``compiler`` is the jit compiler class whose ``build`` is called,
        ``element_type`` is the CUDA type name substituted into the generated
        source, and ``dtype`` is the matching torch dtype for the test tensors.
        """
        print('Generated code:')
        source = generate_vector_add(T=element_type)
        print(source)

        print('Building ...')
        launch = compiler.build('test_func', source, VectorAddRuntime)

        shape = (1024, )
        lhs = torch.randn(shape, dtype=dtype, device='cuda')
        rhs = torch.randn(shape, dtype=dtype, device='cuda')
        out = torch.empty_like(lhs)

        # Launch on torch's current stream and require a successful status
        # before comparing against the torch reference sum.
        status = launch(A=lhs, B=rhs, C=out, STREAM=torch.cuda.current_stream().cuda_stream)
        assert status == cuda.CUresult.CUDA_SUCCESS, status
        expected = lhs + rhs
        torch.testing.assert_close(out, expected)

    # NVCC: float32 elements
    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
    _build_and_check(jit.NvccCompiler, 'float', torch.float32)
    print('JIT test for NVCC passed\n')

    # NVRTC: bfloat16 elements
    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
    _build_and_check(jit.NvrtcCompiler, '__nv_bfloat16', torch.bfloat16)
    print('JIT test for NVRTC passed')