feat: make API more general
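
Generalize the JIT layer so it is no longer hard-wired to the fp8 GEMM kernel:
Compiler.compile and Compiler.build now take a caller-supplied
kernel_name_pattern (and build a runtime_cls), Runtime carries its launch
callable plus the ordered argument names it forwards, the fp8-specific launch
arguments move into the new Fp8GemmRuntime, and the JIT test exercises the
generalized path with a templated vector_add kernel under both NVCC and NVRTC.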

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Zihua Wu 2025-04-23 02:34:23 -07:00
parent 6c982791eb
commit 46762b6903
4 changed files with 156 additions and 90 deletions

View File

@@ -1,3 +1,3 @@
-from .compiler import get_nvcc_compiler, build
+from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
 from .template import generate
 from .runtime import Runtime

View File

@@ -7,14 +7,14 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Type

 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME

 from . import interleave_ffma
-from .runtime import Runtime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol

 runtime_cache = RuntimeCache()
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         pass

     @staticmethod
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]

     @classmethod
-    def build(cls, name: str, code: str) -> Runtime:
+    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
         include_dirs = cls.include_dirs()
@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
         # Check runtime cache or file system hit
         global runtime_cache
-        if runtime_cache[path] is not None:
+        cached_runtime = runtime_cache.get(path, runtime_cls)
+        if cached_runtime is not None:
             if os.getenv('DG_JIT_DEBUG', None):
                 print(f'Using cached JIT runtime {name} during build')
-            return runtime_cache[path]
+            return cached_runtime

         # Compile into a temporary CU file
         os.makedirs(path, exist_ok=True)
@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path)
+        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):
@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
         os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')

         # Put cache and return
-        runtime_cache[path] = Runtime(path, kernel_name)
-        return runtime_cache[path]
+        runtime = runtime_cls(path, kernel_name)
+        runtime_cache[path] = runtime
+        return runtime


 class NvccCompiler(Compiler):
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
                 f'--compiler-options={",".join(cxx_flags)}']

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
         assert result.returncode == 0, f'Failed to compile {src_path}'

         # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, 'fp8_gemm_kernel')
+        return get_symbol(target_path, kernel_name_pattern)


 class NvrtcCompiler(Compiler):
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
         return base_flags

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise Exception(f"Failed to create program: {res}")

-        kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
+        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
         kernel_name = kernel_regex.search(code).group(
             0).replace('\n', '').replace(' ', '')
         res = nvrtc.nvrtcAddNameExpression(
@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):
 def build(name: str, code: str) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code)
+        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
     else:
-        return NvccCompiler.build(name, code)
+        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
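
The two compile overrides explain the asymmetry in kernel_name_pattern: NVCC resolves the mangled symbol from the generated cubin via cuobjdump, so a plain substring is enough, while NVRTC extracts the instantiated template expression from the source, so it takes a regex. A minimal usage sketch of the generalized entry points (the vector_add names and VectorAddRuntime mirror the test file further down; code is an assumed CUDA source string, not part of this module):

# Sketch only: `code`, the kernel names, and VectorAddRuntime are assumptions
# borrowed from the test below.
nvcc_runtime = NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
nvrtc_runtime = NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)

When runtime_cls is omitted, build falls back to Fp8GemmRuntime, which is how the module-level build helper above keeps its original fp8 GEMM behavior.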

View File

@@ -2,7 +2,7 @@ import os
 import platform
 import time
 import subprocess
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type

 import cuda.bindings.driver as cuda
 from torch.utils.cpp_extension import CUDA_HOME
@@ -13,9 +13,10 @@ from .utils import run_gemm
 def get_symbol(file_path: str, pattern: str) -> Optional[str]:
     if CUDA_HOME is None:
         raise Exception("CUDA_HOME is not set")

     cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
+    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
+               '-symbols', file_path]
     result = subprocess.run(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, text=True)
     assert result.returncode == 0
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
 class Runtime:
-    def __init__(self, path: str, kernel_name: str) -> None:
+    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
+        self.caller = caller
+        self.args = args
         assert self.is_path_valid(self.path)

     @staticmethod
@@ -62,7 +64,7 @@ class Runtime:
             if self.kernel is not None:
                 self.lib = lib
             else:
-                raise Exception("Failed to find fp8 gemm kernel")
+                raise Exception("Failed to find kernel")

             end_time = time.time_ns()
             elapsed_time = (end_time - start_time) / 1000
@@ -70,21 +72,9 @@ class Runtime:
                 print(
                     f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

-        return run_gemm(
+        return self.caller(
             self.kernel,
-            kwargs['NUM_TMA_MULTICAST'],
-            kwargs['M'],
-            kwargs['BLOCK_M'],
-            kwargs['GMEM_D'],
-            kwargs['SCALES_B'],
-            kwargs['GROUPED_LAYOUT'],
-            kwargs['NUM_SMS'],
-            kwargs['SMEM_SIZE'],
-            kwargs['TENSOR_MAP_A'],
-            kwargs['TENSOR_MAP_B'],
-            kwargs['TENSOR_MAP_SCALES_A'],
-            kwargs['TENSOR_MAP_D'],
-            kwargs['STREAM'],
+            *[kwargs[arg] for arg in self.args]
         )

     def __del__(self) -> None:
@@ -94,22 +84,42 @@ class Runtime:
                 raise Exception(f"Failed to unload library {self.path}: {res}")


+class Fp8GemmRuntime(Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_gemm, [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+
 class RuntimeCache:
     def __init__(self) -> None:
         self.cache = {}

-    def __getitem__(self, path: str) -> Optional[Runtime]:
+    def __setitem__(self, path, runtime) -> None:
+        self.cache[path] = runtime
+
+    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]

         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
-            runtime = Runtime(path, kernel_name)
+            kernel_name = open(os.path.join(
+                path, 'kernel.cubin.name'), 'r').read()
+            runtime = runtime_cls(path, kernel_name)
             self.cache[path] = runtime
             return runtime
-        return None
-
-    def __setitem__(self, path, runtime) -> None:
-        self.cache[path] = runtime
+        return None
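
With this split, Runtime.__call__ only loads the cubin, then pulls each name in self.args out of its keyword arguments and forwards them, in order, to self.caller after the kernel handle. A new runtime is therefore just a binding of a launch function to an argument list; an illustrative sketch (the launcher and names are hypothetical, not part of the library):

# Hypothetical subclass: run_my_kernel(kernel, x, y, stream) would be a
# launcher in the style of run_gemm / run_vector_add below.
class MyKernelRuntime(Runtime):
    def __init__(self, path: str, kernel_name: str) -> None:
        super().__init__(path, kernel_name, run_my_kernel, ['X', 'Y', 'STREAM'])

# Calling runtime(X=x, Y=y, STREAM=s) then resolves to
# run_my_kernel(kernel, x, y, s).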

View File

@@ -1,64 +1,118 @@
+import ctypes
 import os
 import torch
-from typing import Any
+from typing import Any, Dict

+import cuda.bindings.driver as cuda
+
 from deep_gemm import jit


-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
+def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
+    assert a.shape == b.shape == c.shape
+    assert a.device == b.device == c.device
+    assert a.dim() == 1

-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
+    n = a.numel()

-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, 'r') as f:
-            self.captured = f.read()
+    config = cuda.CUlaunchConfig()
+    config.gridDimX = (n + 127) // 128
+    config.gridDimY = 1
+    config.gridDimZ = 1
+    config.blockDimX = 128
+    config.blockDimY = 1
+    config.blockDimZ = 1
+    config.hStream = stream

-    def capture(self) -> str:
-        return self.captured
+    kernelValues = (
+        a.data_ptr(),
+        b.data_ptr(),
+        c.data_ptr(),
+        n,
+    )
+    kernelTypes = (
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_uint32,
+    )
+
+    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
+
+
+def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
+    return f"""
+#ifdef __CUDACC_RTC__
+#ifndef NVRTC_JIT_COMPILATION
+#define NVRTC_JIT_COMPILATION
+#endif
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#endif
+
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+
+template<typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i < N) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+#ifndef NVRTC_JIT_COMPILATION
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&vector_add<{kwargs['T']}>;
+}}
+#endif
+"""
+
+
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_vector_add, [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])


 if __name__ == '__main__':
-    # Runtime
-    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
-
-    # Templates
+    # NVCC
+    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
-    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
-    body += 'std::cout << enable_double_streams << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
-    code = jit.generate((), args, body)
+    code = generate_vector_add(T='float')
     print(code)

     # Build
     print('Building ...')
-    func = jit.build('test_func', args, code)
+    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)

     # Test correctness
     print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
-    with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
-    output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)

-    print('JIT test passed')
+    print('JIT test for NVCC passed\n')
+
+    # NVRTC
+    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
+    print('Generated code:')
+    code = generate_vector_add(T='__nv_bfloat16')
+    print(code)
+
+    print('Building ...')
+    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
+    print('JIT test for NVRTC passed')
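
The test drives both compilers explicitly; callers of the module-level helper still select the backend through the environment, and the DG_JIT_DEBUG flag seen throughout the diff applies there too. A small sketch (assuming code holds an fp8 GEMM source string, since the helper hard-codes the fp8_gemm_kernel pattern; the 'fp8_gemm' name is illustrative):

import os
from deep_gemm import jit

os.environ['DG_JIT_USE_NVRTC'] = '1'   # route the default fp8 path through NVRTC
os.environ['DG_JIT_DEBUG'] = '1'       # print cache hits and compile timings
runtime = jit.build('fp8_gemm', code)  # 'code' is an assumed fp8 GEMM source string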