Merge branch 'main' into wgrad-gemm

This commit is contained in:
Zhean Xu
2025-05-09 12:30:13 +08:00
21 changed files with 1091 additions and 822 deletions

View File

@@ -1,3 +1,8 @@
# PyTorch has its own NVRTC, which may have a lower version than the system
# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch
import cuda.bindings.nvrtc as nvrtc
print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
import random
import torch
from typing import List, Tuple

View File

@@ -1,64 +1,103 @@
import ctypes
import os
import torch
from typing import Any
import cuda.bindings.driver as cbd
from deep_gemm import jit
# Essential debugging staffs
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class Capture:
def __init__(self) -> None:
self.read_fd = None
self.write_fd = None
self.saved_stdout = None
self.captured = None
def __enter__(self) -> Any:
self.read_fd, self.write_fd = os.pipe()
self.saved_stdout = os.dup(1)
os.dup2(self.write_fd, 1)
return self
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, [
'A',
'B',
'C',
'STREAM',
])
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
os.dup2(self.saved_stdout, 1)
os.close(self.write_fd)
with os.fdopen(self.read_fd, 'r') as f:
self.captured = f.read()
@staticmethod
def generate(**kwargs) -> str:
return f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#endif
def capture(self) -> str:
return self.captured
#include <cuda_fp8.h>
#include <cuda_bf16.h>
template <typename T>
__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < n) {{
c[i] = a[i] + b[i];
}}
}}
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
# noinspection PyShadowingNames,PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cbd.CUstream) -> cbd.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.numel(),
)
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
if __name__ == '__main__':
# Runtime
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
# Templates
print('Generated code:')
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
body = "\n"
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
body += 'std::cout << enable_double_streams << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
code = jit.generate((), args, body)
code = VectorAddRuntime.generate(T='float')
print(code)
print()
# Build
print('Building ...')
func = jit.build('test_func', args, code)
for compiler_name in ('NVCC', 'NVRTC'):
# Get compiler
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
# Test correctness
print('Running ...')
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
with Capture() as capture:
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
output = capture.capture()
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
assert output == ref_output, f'{output=}, {ref_output=}'
# Build
print('Building ...')
func = compiler_cls.build('test_func', code, VectorAddRuntime)
print('JIT test passed')
# Run and check
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cbd.CUresult.CUDA_SUCCESS, ret
torch.testing.assert_close(c, a + b)
print(f'JIT test for {compiler_name} passed\n')