mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Merge branch 'main' into wgrad-gemm
This commit is contained in:
@@ -1,3 +1,8 @@
|
||||
# PyTorch has its own NVRTC, which may have a lower version than the system
|
||||
# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch
|
||||
import cuda.bindings.nvrtc as nvrtc
|
||||
print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
|
||||
|
||||
import random
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -1,64 +1,103 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
from typing import Any
|
||||
import cuda.bindings.driver as cbd
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
# Essential debugging staffs
|
||||
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
|
||||
os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
|
||||
|
||||
class Capture:
|
||||
def __init__(self) -> None:
|
||||
self.read_fd = None
|
||||
self.write_fd = None
|
||||
self.saved_stdout = None
|
||||
self.captured = None
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
self.read_fd, self.write_fd = os.pipe()
|
||||
self.saved_stdout = os.dup(1)
|
||||
os.dup2(self.write_fd, 1)
|
||||
return self
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
os.dup2(self.saved_stdout, 1)
|
||||
os.close(self.write_fd)
|
||||
with os.fdopen(self.read_fd, 'r') as f:
|
||||
self.captured = f.read()
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
return f"""
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
def capture(self) -> str:
|
||||
return self.captured
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template <typename T>
|
||||
__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
|
||||
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (i < n) {{
|
||||
c[i] = a[i] + b[i];
|
||||
}}
|
||||
}}
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
|
||||
}}
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.gridDimX = (a.numel() + 127) // 128
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = 128
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
a.numel(),
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
)
|
||||
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Runtime
|
||||
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
|
||||
|
||||
# Templates
|
||||
print('Generated code:')
|
||||
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
|
||||
('enable_double_streams', bool), ('stream', torch.cuda.Stream))
|
||||
body = "\n"
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
|
||||
body += 'std::cout << enable_double_streams << std::endl;\n'
|
||||
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
|
||||
code = jit.generate((), args, body)
|
||||
code = VectorAddRuntime.generate(T='float')
|
||||
print(code)
|
||||
print()
|
||||
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = jit.build('test_func', args, code)
|
||||
for compiler_name in ('NVCC', 'NVRTC'):
|
||||
# Get compiler
|
||||
compiler_cls = getattr(jit, f'{compiler_name}Compiler')
|
||||
print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')
|
||||
|
||||
# Test correctness
|
||||
print('Running ...')
|
||||
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
|
||||
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
|
||||
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
|
||||
with Capture() as capture:
|
||||
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
|
||||
output = capture.capture()
|
||||
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
|
||||
assert output == ref_output, f'{output=}, {ref_output=}'
|
||||
# Build
|
||||
print('Building ...')
|
||||
func = compiler_cls.build('test_func', code, VectorAddRuntime)
|
||||
|
||||
print('JIT test passed')
|
||||
# Run and check
|
||||
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cbd.CUresult.CUDA_SUCCESS, ret
|
||||
torch.testing.assert_close(c, a + b)
|
||||
print(f'JIT test for {compiler_name} passed\n')
|
||||
|
||||
Reference in New Issue
Block a user