feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Zihua Wu 2025-04-23 02:34:23 -07:00
parent 6c982791eb
commit 46762b6903
4 changed files with 156 additions and 90 deletions
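
This change generalizes the JIT layer so it can compile and launch arbitrary kernels rather than only the FP8 GEMM kernel: Compiler.compile and Compiler.build gain kernel_name_pattern and runtime_cls parameters, Runtime is parameterized by a launch callable (caller) and an ordered argument-name list (args), the previously hard-coded behaviour moves into a new Fp8GemmRuntime subclass, and RuntimeCache looks entries up through get(path, runtime_cls). Below is a minimal usage sketch of the two entry points, based only on the signatures in this diff; the names 'my_gemm' and 'my_kernel', the source strings gemm_code and my_code, and MyRuntime (a Runtime subclass analogous to the VectorAddRuntime defined in the rewritten test below) are placeholders:

    from deep_gemm import jit

    # Unchanged high-level path: build() still fills in the fp8_gemm_kernel
    # pattern and Fp8GemmRuntime on behalf of the caller.
    gemm_runtime = jit.build('my_gemm', gemm_code)

    # Generalized path: pick the compiler, the kernel-name pattern and the
    # Runtime subclass explicitly, as the rewritten test below does.
    my_runtime = jit.NvccCompiler.build('my_kernel', my_code, 'my_kernel', MyRuntime)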

View File

@@ -1,3 +1,3 @@
-from .compiler import get_nvcc_compiler, build
+from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
 from .template import generate
 from .runtime import Runtime

View File

@@ -7,14 +7,14 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Type

 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME

 from . import interleave_ffma
-from .runtime import Runtime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol

 runtime_cache = RuntimeCache()

@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         pass

     @staticmethod

@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]

     @classmethod
-    def build(cls, name: str, code: str) -> Runtime:
+    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
         include_dirs = cls.include_dirs()

@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
         # Check runtime cache or file system hit
         global runtime_cache
-        if runtime_cache[path] is not None:
+        cached_runtime = runtime_cache.get(path, runtime_cls)
+        if cached_runtime is not None:
             if os.getenv('DG_JIT_DEBUG', None):
                 print(f'Using cached JIT runtime {name} during build')
-            return runtime_cache[path]
+            return cached_runtime

         # Compile into a temporary CU file
         os.makedirs(path, exist_ok=True)

@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path)
+        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):

@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
         os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')

         # Put cache and return
-        runtime_cache[path] = Runtime(path, kernel_name)
-        return runtime_cache[path]
+        runtime = runtime_cls(path, kernel_name)
+        runtime_cache[path] = runtime
+        return runtime


 class NvccCompiler(Compiler):

@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
                 f'--compiler-options={",".join(cxx_flags)}']

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')

@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
         assert result.returncode == 0, f'Failed to compile {src_path}'

         # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, 'fp8_gemm_kernel')
+        return get_symbol(target_path, kernel_name_pattern)


 class NvrtcCompiler(Compiler):

@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
         return base_flags

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise Exception(f"Failed to create program: {res}")
-        kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
+        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
         kernel_name = kernel_regex.search(code).group(
             0).replace('\n', '').replace(' ', '')
         res = nvrtc.nvrtcAddNameExpression(

@@ -308,6 +310,6 @@ class NvrtcCompiler(Compiler):
 def build(name: str, code: str) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code)
+        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
     else:
-        return NvccCompiler.build(name, code)
+        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
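
Note that kernel_name_pattern means something slightly different to the two compilers: NvccCompiler.compile forwards it to get_symbol, which looks the pattern up in cuobjdump's symbol listing, while NvrtcCompiler.compile treats it as a regular expression over the generated source so the instantiated template name can be registered via nvrtcAddNameExpression. The build() wrapper in the last hunk above therefore passes the plain name 'fp8_gemm_kernel' on the NVCC path and the regex r'fp8_gemm_kernel<[\S\s]*?>' on the NVRTC path. A sketch for a hypothetical templated kernel my_kernel, where 'demo', code and MyRuntime are placeholders:

    runtime = jit.NvccCompiler.build('demo', code, kernel_name_pattern='my_kernel',
                                     runtime_cls=MyRuntime)
    runtime = jit.NvrtcCompiler.build('demo', code, kernel_name_pattern=r'my_kernel<[\S\s]*?>',
                                      runtime_cls=MyRuntime)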

View File

@@ -2,7 +2,7 @@ import os
 import platform
 import time
 import subprocess
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type

 import cuda.bindings.driver as cuda
 from torch.utils.cpp_extension import CUDA_HOME

@@ -15,7 +15,8 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
         raise Exception("CUDA_HOME is not set")
     cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
+    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
+               '-symbols', file_path]
     result = subprocess.run(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, text=True)
     assert result.returncode == 0

@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
 class Runtime:
-    def __init__(self, path: str, kernel_name: str) -> None:
+    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
+        self.caller = caller
+        self.args = args

         assert self.is_path_valid(self.path)

     @staticmethod

@@ -62,7 +64,7 @@ class Runtime:
             if self.kernel is not None:
                 self.lib = lib
             else:
-                raise Exception("Failed to find fp8 gemm kernel")
+                raise Exception("Failed to find kernel")

         end_time = time.time_ns()
         elapsed_time = (end_time - start_time) / 1000

@@ -70,21 +72,9 @@ class Runtime:
             print(
                 f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

-        return run_gemm(
+        return self.caller(
             self.kernel,
-            kwargs['NUM_TMA_MULTICAST'],
-            kwargs['M'],
-            kwargs['BLOCK_M'],
-            kwargs['GMEM_D'],
-            kwargs['SCALES_B'],
-            kwargs['GROUPED_LAYOUT'],
-            kwargs['NUM_SMS'],
-            kwargs['SMEM_SIZE'],
-            kwargs['TENSOR_MAP_A'],
-            kwargs['TENSOR_MAP_B'],
-            kwargs['TENSOR_MAP_SCALES_A'],
-            kwargs['TENSOR_MAP_D'],
-            kwargs['STREAM'],
+            *[kwargs[arg] for arg in self.args]
         )

     def __del__(self) -> None:

@@ -94,22 +84,42 @@ class Runtime:
             raise Exception(f"Failed to unload library {self.path}: {res}")


+class Fp8GemmRuntime(Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_gemm, [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+
 class RuntimeCache:
     def __init__(self) -> None:
         self.cache = {}

-    def __getitem__(self, path: str) -> Optional[Runtime]:
+    def __setitem__(self, path, runtime) -> None:
+        self.cache[path] = runtime
+
+    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]

         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
-            runtime = Runtime(path, kernel_name)
+            kernel_name = open(os.path.join(
+                path, 'kernel.cubin.name'), 'r').read()
+            runtime = runtime_cls(path, kernel_name)
             self.cache[path] = runtime
             return runtime

         return None
-
-    def __setitem__(self, path, runtime) -> None:
-        self.cache[path] = runtime
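
With this layout, Runtime.__call__ loads the cubin once, then looks up each name in self.args among the keyword arguments it was called with and forwards the values, after the kernel handle, to self.caller, which returns a cuda.CUresult; a subclass only pins down the caller and the argument order. A sketch of invoking an Fp8GemmRuntime returned by Compiler.build, assuming cuda.bindings.driver is imported as cuda as above; all right-hand-side values are placeholders:

    ret = runtime(
        NUM_TMA_MULTICAST=1, M=m, BLOCK_M=block_m,
        GMEM_D=d, SCALES_B=scales_b, GROUPED_LAYOUT=grouped_layout,
        NUM_SMS=num_sms, SMEM_SIZE=smem_size,
        TENSOR_MAP_A=tensor_map_a, TENSOR_MAP_B=tensor_map_b,
        TENSOR_MAP_SCALES_A=tensor_map_scales_a, TENSOR_MAP_D=tensor_map_d,
        STREAM=torch.cuda.current_stream().cuda_stream)
    assert ret == cuda.CUresult.CUDA_SUCCESS, ret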

View File

@@ -1,64 +1,118 @@
+import ctypes
 import os
 import torch
-from typing import Any
+from typing import Any, Dict
+import cuda.bindings.driver as cuda

 from deep_gemm import jit

-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
-
-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, 'r') as f:
-            self.captured = f.read()
-
-    def capture(self) -> str:
-        return self.captured
+def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
+    assert a.shape == b.shape == c.shape
+    assert a.device == b.device == c.device
+    assert a.dim() == 1
+
+    n = a.numel()
+
+    config = cuda.CUlaunchConfig()
+    config.gridDimX = (n + 127) // 128
+    config.gridDimY = 1
+    config.gridDimZ = 1
+    config.blockDimX = 128
+    config.blockDimY = 1
+    config.blockDimZ = 1
+    config.hStream = stream
+
+    kernelValues = (
+        a.data_ptr(),
+        b.data_ptr(),
+        c.data_ptr(),
+        n,
+    )
+    kernelTypes = (
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_uint32,
+    )
+
+    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
+
+
+def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
+    return f"""
+#ifdef __CUDACC_RTC__
+#ifndef NVRTC_JIT_COMPILATION
+#define NVRTC_JIT_COMPILATION
+#endif
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#endif
+
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+
+template<typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i < N) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+#ifndef NVRTC_JIT_COMPILATION
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&vector_add<{kwargs['T']}>;
+}}
+#endif
+"""
+
+
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_vector_add, [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
+

 if __name__ == '__main__':
-    # Runtime
-    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
-
-    # Templates
+    # NVCC
+    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
-    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
-    body += 'std::cout << enable_double_streams << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
-    code = jit.generate((), args, body)
+    code = generate_vector_add(T='float')
     print(code)

-    # Build
     print('Building ...')
-    func = jit.build('test_func', args, code)
-
-    # Test correctness
-    print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
-    with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
-    output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
-
-    print('JIT test passed')
+    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
+
+    print('JIT test for NVCC passed\n')
+
+    # NVRTC
+    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
+    print('Generated code:')
+    code = generate_vector_add(T='__nv_bfloat16')
+    print(code)
+
+    print('Building ...')
+    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
+
+    print('JIT test for NVRTC passed')