From 46762b690355eb2c7026c571b9f1d2e58d89b293 Mon Sep 17 00:00:00 2001
From: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Date: Wed, 23 Apr 2025 02:34:23 -0700
Subject: [PATCH] feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
---
 deep_gemm/jit/__init__.py |   2 +-
 deep_gemm/jit/compiler.py |  32 +++++----
 deep_gemm/jit/runtime.py  |  64 ++++++++++-------
 tests/test_jit.py         | 148 ++++++++++++++++++++++++++------------
 4 files changed, 156 insertions(+), 90 deletions(-)

diff --git a/deep_gemm/jit/__init__.py b/deep_gemm/jit/__init__.py
index 999eafb..8e1ba3a 100644
--- a/deep_gemm/jit/__init__.py
+++ b/deep_gemm/jit/__init__.py
@@ -1,3 +1,3 @@
-from .compiler import get_nvcc_compiler, build
+from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
 from .template import generate
 from .runtime import Runtime
diff --git a/deep_gemm/jit/compiler.py b/deep_gemm/jit/compiler.py
index 628f652..549be37 100644
--- a/deep_gemm/jit/compiler.py
+++ b/deep_gemm/jit/compiler.py
@@ -7,14 +7,14 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Type
 
 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME
 
 from . import interleave_ffma
-from .runtime import Runtime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol
 
 runtime_cache = RuntimeCache()
 
@@ -115,7 +115,7 @@ class Compiler(abc.ABC):
 
     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         pass
 
     @staticmethod
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]
 
     @classmethod
-    def build(cls, name: str, code: str) -> Runtime:
+    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
         include_dirs = cls.include_dirs()
@@ -146,10 +146,11 @@ class Compiler(abc.ABC):
 
         # Check runtime cache or file system hit
         global runtime_cache
-        if runtime_cache[path] is not None:
+        cached_runtime = runtime_cache.get(path, runtime_cls)
+        if cached_runtime is not None:
             if os.getenv('DG_JIT_DEBUG', None):
                 print(f'Using cached JIT runtime {name} during build')
-            return runtime_cache[path]
+            return cached_runtime
 
         # Compile into a temporary CU file
         os.makedirs(path, exist_ok=True)
@@ -157,7 +158,7 @@ class Compiler(abc.ABC):
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')
 
         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path)
+        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):
@@ -176,8 +177,9 @@ class Compiler(abc.ABC):
             os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')
 
         # Put cache and return
-        runtime_cache[path] = Runtime(path, kernel_name)
-        return runtime_cache[path]
+        runtime = runtime_cls(path, kernel_name)
+        runtime_cache[path] = runtime
+        return runtime
 
 
 class NvccCompiler(Compiler):
@@ -200,7 +202,7 @@ class NvccCompiler(Compiler):
                 f'--compiler-options={",".join(cxx_flags)}']
 
     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')
@@ -220,7 +222,7 @@ class NvccCompiler(Compiler):
         assert result.returncode == 0, f'Failed to compile {src_path}'
 
         # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, 'fp8_gemm_kernel')
+        return get_symbol(target_path, kernel_name_pattern)
 
 
 class NvrtcCompiler(Compiler):
@@ -249,14 +251,14 @@ class NvrtcCompiler(Compiler):
         return base_flags
 
     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise Exception(f"Failed to create program: {res}")
 
-        kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
+        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
         kernel_name = kernel_regex.search(code).group(
             0).replace('\n', '').replace(' ', '')
         res = nvrtc.nvrtcAddNameExpression(
@@ -308,6 +310,6 @@
 
 def build(name: str, code: str) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code)
+        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
     else:
-        return NvccCompiler.build(name, code)
+        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
diff --git a/deep_gemm/jit/runtime.py b/deep_gemm/jit/runtime.py
index 3338151..2588dae 100644
--- a/deep_gemm/jit/runtime.py
+++ b/deep_gemm/jit/runtime.py
@@ -2,7 +2,7 @@ import os
 import platform
 import time
 import subprocess
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import cuda.bindings.driver as cuda
 from torch.utils.cpp_extension import CUDA_HOME
@@ -13,9 +13,10 @@ from .utils import run_gemm
 def get_symbol(file_path: str, pattern: str) -> Optional[str]:
     if CUDA_HOME is None:
         raise Exception("CUDA_HOME is not set")
-    
+
     cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
+    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
+               '-symbols', file_path]
     result = subprocess.run(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, text=True)
     assert result.returncode == 0
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:
 
 
 class Runtime:
-    def __init__(self, path: str, kernel_name: str) -> None:
+    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
-        
+        self.caller = caller
+        self.args = args
         assert self.is_path_valid(self.path)
 
     @staticmethod
@@ -62,7 +64,7 @@ class Runtime:
             if self.kernel is not None:
                 self.lib = lib
             else:
-                raise Exception("Failed to find fp8 gemm kernel")
+                raise Exception("Failed to find kernel")
 
             end_time = time.time_ns()
             elapsed_time = (end_time - start_time) / 1000
@@ -70,21 +72,9 @@ class Runtime:
                 print(
                     f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
 
-        return run_gemm(
+        return self.caller(
             self.kernel,
-            kwargs['NUM_TMA_MULTICAST'],
-            kwargs['M'],
-            kwargs['BLOCK_M'],
-            kwargs['GMEM_D'],
-            kwargs['SCALES_B'],
-            kwargs['GROUPED_LAYOUT'],
-            kwargs['NUM_SMS'],
-            kwargs['SMEM_SIZE'],
-            kwargs['TENSOR_MAP_A'],
-            kwargs['TENSOR_MAP_B'],
-            kwargs['TENSOR_MAP_SCALES_A'],
-            kwargs['TENSOR_MAP_D'],
-            kwargs['STREAM'],
+            *[kwargs[arg] for arg in self.args]
         )
 
     def __del__(self) -> None:
@@ -94,22 +84,42 @@ class Runtime:
             raise Exception(f"Failed to unload library {self.path}: {res}")
 
 
+class Fp8GemmRuntime(Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_gemm, [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+
 class RuntimeCache:
     def __init__(self) -> None:
         self.cache = {}
 
-    def __getitem__(self, path: str) -> Optional[Runtime]:
+    def __setitem__(self, path, runtime) -> None:
+        self.cache[path] = runtime
+
+    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]
 
         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
-            runtime = Runtime(path, kernel_name)
+            kernel_name = open(os.path.join(
+                path, 'kernel.cubin.name'), 'r').read()
+            runtime = runtime_cls(path, kernel_name)
             self.cache[path] = runtime
             return runtime
-        return None
-
-    def __setitem__(self, path, runtime) -> None:
-        self.cache[path] = runtime
+        return None
\ No newline at end of file
diff --git a/tests/test_jit.py b/tests/test_jit.py
index 78bc77b..00bea1a 100644
--- a/tests/test_jit.py
+++ b/tests/test_jit.py
@@ -1,64 +1,118 @@
+import ctypes
 import os
 import torch
-from typing import Any
+from typing import Any, Dict
+
+import cuda.bindings.driver as cuda
 
 from deep_gemm import jit
 
 
-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
+def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
+    assert a.shape == b.shape == c.shape
+    assert a.device == b.device == c.device
+    assert a.dim() == 1
 
-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
+    n = a.numel()
 
-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, 'r') as f:
-            self.captured = f.read()
+    config = cuda.CUlaunchConfig()
+    config.gridDimX = (n + 127) // 128
+    config.gridDimY = 1
+    config.gridDimZ = 1
+    config.blockDimX = 128
+    config.blockDimY = 1
+    config.blockDimZ = 1
+    config.hStream = stream
 
-    def capture(self) -> str:
-        return self.captured
+    kernelValues = (
+        a.data_ptr(),
+        b.data_ptr(),
+        c.data_ptr(),
+        n,
+    )
+    kernelTypes = (
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_uint32,
+    )
+
+    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]
+
+
+def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
+    return f"""
+#ifdef __CUDACC_RTC__
+#ifndef NVRTC_JIT_COMPILATION
+#define NVRTC_JIT_COMPILATION
+#endif
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cstdint>
+#endif
+
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+
+template <typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i < N) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+#ifndef NVRTC_JIT_COMPILATION
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&vector_add<{kwargs['T']}>;
+}}
+#endif
+"""
+
+
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_vector_add, [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
 
 
 if __name__ == '__main__':
-    # Runtime
-    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
-
-    # Templates
+    # NVCC
+    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
-    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
-    body += 'std::cout << enable_double_streams << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
-    code = jit.generate((), args, body)
+    code = generate_vector_add(T='float')
     print(code)
-
-    # Build
    print('Building ...')
-    func = jit.build('test_func', args, code)
+    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)
 
-    # Test correctness
-    print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
-    with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
-    output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
 
-    print('JIT test passed')
+    print('JIT test for NVCC passed\n')
+
+    # NVRTC
+    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
+    print('Generated code:')
+    code = generate_vector_add(T='__nv_bfloat16')
+    print(code)
+    print('Building ...')
+    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
 
+    print('JIT test for NVRTC passed')
\ No newline at end of file
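
Usage sketch (not part of the patch): the snippet below shows how a downstream kernel could plug into the generalized API introduced above, following the same pattern as tests/test_jit.py. The scale_kernel source, the run_scale launcher, the ScaleRuntime class, and the SRC/DST/ALPHA/N/STREAM argument names are hypothetical and invented for illustration; only the Runtime(path, kernel_name, caller, args) constructor, the kwargs-based Runtime.__call__, and the Compiler.build(name, code, kernel_name_pattern, runtime_cls) signature come from this patch.

# Hypothetical example, assuming the patch above is applied.
import ctypes

import cuda.bindings.driver as cuda
import torch

from deep_gemm import jit

CODE = r'''
#include <cstdint>

__global__ void scale_kernel(float* src, float* dst, float alpha, uint32_t n) {
    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < n) dst[i] = alpha * src[i];
}
'''


def run_scale(kernel: cuda.CUkernel, src: torch.Tensor, dst: torch.Tensor,
              alpha: float, n: int, stream: cuda.CUstream) -> cuda.CUresult:
    # One thread per element, same launch pattern as run_vector_add in the test.
    config = cuda.CUlaunchConfig()
    config.gridDimX = (n + 127) // 128
    config.gridDimY = config.gridDimZ = 1
    config.blockDimX = 128
    config.blockDimY = config.blockDimZ = 1
    config.hStream = stream
    values = (src.data_ptr(), dst.data_ptr(), alpha, n)
    types = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_float, ctypes.c_uint32)
    return cuda.cuLaunchKernelEx(config, kernel, (values, types), 0)[0]


class ScaleRuntime(jit.Runtime):
    def __init__(self, path: str, kernel_name: str) -> None:
        # The argument-name list maps __call__ kwargs onto run_scale's parameters.
        super().__init__(path, kernel_name, run_scale,
                         ['SRC', 'DST', 'ALPHA', 'N', 'STREAM'])


# For NVCC the pattern is a substring of the kernel symbol looked up via cuobjdump;
# the NVRTC path would instead take a regex matching the kernel name in the source.
runtime = jit.NvccCompiler.build('scale', CODE, 'scale_kernel', ScaleRuntime)

x = torch.randn(1024, device='cuda')
y = torch.empty_like(x)
ret = runtime(SRC=x, DST=y, ALPHA=2.0, N=x.numel(),
              STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret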