Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-05-05 23:44:22 +00:00)
feat: make API more general
Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Commit 46762b6903 (parent 6c982791eb)
deep_gemm/jit/__init__.py
@@ -1,3 +1,3 @@
-from .compiler import get_nvcc_compiler, build
+from .compiler import get_nvcc_compiler, build, NvccCompiler, NvrtcCompiler
 from .template import generate
 from .runtime import Runtime
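
The hunk above only widens the package's public surface: besides get_nvcc_compiler, build, generate and Runtime, the two concrete compiler backends are now re-exported so callers can pick one explicitly. A minimal sketch of what downstream code can import after this change (package path deep_gemm.jit as implied by the relative imports above):

    from deep_gemm.jit import NvccCompiler, NvrtcCompiler, Runtime, build, generate
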
deep_gemm/jit/compiler.py
@@ -7,14 +7,14 @@ import re
 import subprocess
 import time
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Type

 import cuda.bindings
 import cuda.bindings.nvrtc as nvrtc
 from torch.utils.cpp_extension import CUDA_HOME

 from . import interleave_ffma
-from .runtime import Runtime, RuntimeCache, get_symbol
+from .runtime import Runtime, Fp8GemmRuntime, RuntimeCache, get_symbol

 runtime_cache = RuntimeCache()

@@ -115,7 +115,7 @@ class Compiler(abc.ABC):

     @classmethod
     @abc.abstractmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         pass

     @staticmethod
@@ -132,7 +132,7 @@ class Compiler(abc.ABC):
         return [get_jit_include_dir()]

     @classmethod
-    def build(cls, name: str, code: str) -> Runtime:
+    def build(cls, name: str, code: str, kernel_name_pattern: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Runtime:
         # Compiler flags
         flags = cls.flags()
         include_dirs = cls.include_dirs()
@@ -146,10 +146,11 @@

         # Check runtime cache or file system hit
         global runtime_cache
-        if runtime_cache[path] is not None:
+        cached_runtime = runtime_cache.get(path, runtime_cls)
+        if cached_runtime is not None:
             if os.getenv('DG_JIT_DEBUG', None):
                 print(f'Using cached JIT runtime {name} during build')
-            return runtime_cache[path]
+            return cached_runtime

         # Compile into a temporary CU file
         os.makedirs(path, exist_ok=True)
@@ -157,7 +158,7 @@
         tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

         start_time = time.time()
-        kernel_name = cls.compile(name, code, tmp_cubin_path)
+        kernel_name = cls.compile(name, code, tmp_cubin_path, kernel_name_pattern)
         end_time = time.time()
         elapsed_time = end_time - start_time
         if os.getenv('DG_JIT_DEBUG', None):
@@ -176,8 +177,9 @@
             os.replace(f'{tmp_cubin_path}.name', f'{cubin_path}.name')

         # Put cache and return
-        runtime_cache[path] = Runtime(path, kernel_name)
-        return runtime_cache[path]
+        runtime = runtime_cls(path, kernel_name)
+        runtime_cache[path] = runtime
+        return runtime


 class NvccCompiler(Compiler):
@@ -200,7 +202,7 @@
                 f'--compiler-options={",".join(cxx_flags)}']

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         # Write the code
         path = os.path.join(get_cache_dir(), name)
         src_path = os.path.join(path, 'kernel.cu')
@@ -220,7 +222,7 @@
         assert result.returncode == 0, f'Failed to compile {src_path}'

         # NVCC needs to get the symbol name from the cubin file using `cuobjdump`
-        return get_symbol(target_path, 'fp8_gemm_kernel')
+        return get_symbol(target_path, kernel_name_pattern)


 class NvrtcCompiler(Compiler):
@@ -249,14 +251,14 @@
         return base_flags

     @classmethod
-    def compile(cls, name: str, code: str, target_path: str) -> str:
+    def compile(cls, name: str, code: str, target_path: str, kernel_name_pattern: str) -> str:
         code_bytes = bytes(code, 'utf-8')
         res, program = nvrtc.nvrtcCreateProgram(
             code_bytes, bytes(name, 'utf-8'), 0, [], [])
         if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
             raise Exception(f"Failed to create program: {res}")

-        kernel_regex = re.compile(r'fp8_gemm_kernel<[\S\s]*?>', re.MULTILINE)
+        kernel_regex = re.compile(kernel_name_pattern, re.MULTILINE)
         kernel_name = kernel_regex.search(code).group(
             0).replace('\n', '').replace(' ', '')
         res = nvrtc.nvrtcAddNameExpression(
@@ -308,6 +310,6 @@

 def build(name: str, code: str) -> Runtime:
     if os.getenv('DG_JIT_USE_NVRTC', '0') in ['1', 'true', 'True']:
-        return NvrtcCompiler.build(name, code)
+        return NvrtcCompiler.build(name, code, kernel_name_pattern=r'fp8_gemm_kernel<[\S\s]*?>')
     else:
-        return NvccCompiler.build(name, code)
+        return NvccCompiler.build(name, code, kernel_name_pattern='fp8_gemm_kernel')
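
Taken together, the compiler hunks above decouple the JIT layer from the FP8 GEMM kernel: compile() and build() now receive a kernel_name_pattern (a plain symbol name that get_symbol greps out of cuobjdump output on the NVCC path, a regex matched against the generated source on the NVRTC path), and build() additionally takes a runtime_cls that defaults to Fp8GemmRuntime so the existing module-level build() keeps its old behaviour. A minimal sketch of how a caller might drive the two backends directly; the kernel name 'my_kernel' and the source string are illustrative only, not part of the diff:

    from deep_gemm.jit import NvccCompiler, NvrtcCompiler

    code = '...'  # generated CUDA source assumed to define a __global__ my_kernel<T>(...)

    # NVCC: the pattern is the symbol name to look up in the produced cubin.
    runtime = NvccCompiler.build('demo', code, kernel_name_pattern='my_kernel')

    # NVRTC: the pattern is a regex used to recover the template instantiation
    # (e.g. my_kernel<float>) from the source before registering it as a name expression.
    runtime = NvrtcCompiler.build('demo', code,
                                  kernel_name_pattern=r'my_kernel<[\S\s]*?>')

    # Either call also accepts runtime_cls=<Runtime subclass> to control how the
    # compiled cubin is wrapped; it defaults to Fp8GemmRuntime (see the next file).
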
deep_gemm/jit/runtime.py
@@ -2,7 +2,7 @@ import os
 import platform
 import time
 import subprocess
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, Type

 import cuda.bindings.driver as cuda
 from torch.utils.cpp_extension import CUDA_HOME
@@ -13,9 +13,10 @@ from .utils import run_gemm
 def get_symbol(file_path: str, pattern: str) -> Optional[str]:
     if CUDA_HOME is None:
         raise Exception("CUDA_HOME is not set")

     cuobjdump_bin = 'cuobjdump.exe' if platform.system() == 'Windows' else 'cuobjdump'
-    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin), '-symbols', file_path]
+    command = [os.path.join(CUDA_HOME, 'bin', cuobjdump_bin),
+               '-symbols', file_path]
     result = subprocess.run(command, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE, text=True)
     assert result.returncode == 0
@@ -26,12 +27,13 @@ def get_symbol(file_path: str, pattern: str) -> Optional[str]:


 class Runtime:
-    def __init__(self, path: str, kernel_name: str) -> None:
+    def __init__(self, path: str, kernel_name: str, caller: Callable[..., cuda.CUresult], args: List[str]) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
-
+        self.caller = caller
+        self.args = args
         assert self.is_path_valid(self.path)

     @staticmethod
@@ -62,7 +64,7 @@ class Runtime:
         if self.kernel is not None:
             self.lib = lib
         else:
-            raise Exception("Failed to find fp8 gemm kernel")
+            raise Exception("Failed to find kernel")

         end_time = time.time_ns()
         elapsed_time = (end_time - start_time) / 1000
@@ -70,21 +72,9 @@
             print(
                 f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

-        return run_gemm(
+        return self.caller(
             self.kernel,
-            kwargs['NUM_TMA_MULTICAST'],
-            kwargs['M'],
-            kwargs['BLOCK_M'],
-            kwargs['GMEM_D'],
-            kwargs['SCALES_B'],
-            kwargs['GROUPED_LAYOUT'],
-            kwargs['NUM_SMS'],
-            kwargs['SMEM_SIZE'],
-            kwargs['TENSOR_MAP_A'],
-            kwargs['TENSOR_MAP_B'],
-            kwargs['TENSOR_MAP_SCALES_A'],
-            kwargs['TENSOR_MAP_D'],
-            kwargs['STREAM'],
+            *[kwargs[arg] for arg in self.args]
         )

     def __del__(self) -> None:
@@ -94,22 +84,42 @@
             raise Exception(f"Failed to unload library {self.path}: {res}")


+class Fp8GemmRuntime(Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_gemm, [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+
 class RuntimeCache:
     def __init__(self) -> None:
         self.cache = {}

-    def __getitem__(self, path: str) -> Optional[Runtime]:
+    def __setitem__(self, path, runtime) -> None:
+        self.cache[path] = runtime
+
+    def get(self, path: str, runtime_cls: Type[Runtime] = Fp8GemmRuntime) -> Optional[Runtime]:
         # In Python runtime
         if path in self.cache:
             return self.cache[path]

         # Already compiled
         if os.path.exists(path) and Runtime.is_path_valid(path):
-            kernel_name = open(os.path.join(path, 'kernel.cubin.name'), 'r').read()
-            runtime = Runtime(path, kernel_name)
+            kernel_name = open(os.path.join(
+                path, 'kernel.cubin.name'), 'r').read()
+            runtime = runtime_cls(path, kernel_name)
             self.cache[path] = runtime
             return runtime
-        return None
-
-    def __setitem__(self, path, runtime) -> None:
-        self.cache[path] = runtime
+        return None
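
With the hunks above, Runtime is now a thin dispatcher: it stores a caller plus an ordered list of keyword names, and __call__ loads the cubin lazily, then forwards kwargs[arg] for each listed name to the caller after the CUkernel handle; Fp8GemmRuntime simply pins run_gemm and the original GEMM argument list. A hypothetical subclass for some other kernel would look roughly like this (run_my_kernel, its launch logic and the argument names are illustrative, not part of the diff):

    import cuda.bindings.driver as cuda

    from deep_gemm import jit


    def run_my_kernel(kernel: cuda.CUkernel, x, y, stream) -> cuda.CUresult:
        # Launch 'kernel' on 'stream' however its signature requires, e.g. via
        # cuda.cuLaunchKernelEx as run_vector_add does in the test file below.
        ...


    class MyKernelRuntime(jit.Runtime):
        def __init__(self, path: str, kernel_name: str) -> None:
            # Runtime.__call__ will pass self.kernel first, then
            # kwargs['X'], kwargs['Y'], kwargs['STREAM'] in this order.
            super().__init__(path, kernel_name, run_my_kernel, ['X', 'Y', 'STREAM'])

    # A runtime built this way is invoked with keyword arguments only:
    #   ret = runtime(X=x, Y=y, STREAM=stream.cuda_stream)
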
tests/test_jit.py
@@ -1,64 +1,118 @@
+import ctypes
 import os
 import torch
-from typing import Any
+from typing import Any, Dict

+import cuda.bindings.driver as cuda
+
 from deep_gemm import jit


-class Capture:
-    def __init__(self) -> None:
-        self.read_fd = None
-        self.write_fd = None
-        self.saved_stdout = None
-        self.captured = None
+def run_vector_add(kernel: cuda.CUkernel, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, stream: cuda.CUstream) -> cuda.CUresult:
+    assert a.shape == b.shape == c.shape
+    assert a.device == b.device == c.device
+    assert a.dim() == 1

-    def __enter__(self) -> Any:
-        self.read_fd, self.write_fd = os.pipe()
-        self.saved_stdout = os.dup(1)
-        os.dup2(self.write_fd, 1)
-        return self
+    n = a.numel()

-    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-        os.dup2(self.saved_stdout, 1)
-        os.close(self.write_fd)
-        with os.fdopen(self.read_fd, 'r') as f:
-            self.captured = f.read()
+    config = cuda.CUlaunchConfig()
+    config.gridDimX = (n + 127) // 128
+    config.gridDimY = 1
+    config.gridDimZ = 1
+    config.blockDimX = 128
+    config.blockDimY = 1
+    config.blockDimZ = 1
+    config.hStream = stream

-    def capture(self) -> str:
-        return self.captured
+    kernelValues = (
+        a.data_ptr(),
+        b.data_ptr(),
+        c.data_ptr(),
+        n,
+    )
+    kernelTypes = (
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_uint32,
+    )
+
+    return cuda.cuLaunchKernelEx(config, kernel, (kernelValues, kernelTypes), 0)[0]


+def generate_vector_add(**kwargs: Dict[str, Any]) -> str:
+    return f"""
+#ifdef __CUDACC_RTC__
+#ifndef NVRTC_JIT_COMPILATION
+#define NVRTC_JIT_COMPILATION
+#endif
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#endif
+
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>
+
+template<typename T>
+__global__ void vector_add(T* a, T* b, T* c, uint32_t N) {{
+    uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i < N) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+#ifndef NVRTC_JIT_COMPILATION
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&vector_add<{kwargs['T']}>;
+}}
+#endif
+"""
+
+
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str, kernel_name: str) -> None:
+        super().__init__(path, kernel_name, run_vector_add, [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
+
+
 if __name__ == '__main__':
-    # Runtime
-    print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
-
-    # Templates
+    # NVCC
+    print(f'NVCC compiler version: {jit.NvccCompiler.__version__()}\n')
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
-    body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
-    body += 'std::cout << enable_double_streams << std::endl;\n'
-    body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
-    code = jit.generate((), args, body)
+    code = generate_vector_add(T='float')
     print(code)
-
-    # Build
     print('Building ...')
-    func = jit.build('test_func', args, code)
+    func = jit.NvccCompiler.build('test_func', code, 'vector_add', VectorAddRuntime)

-    # Test correctness
-    print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
-    with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
-    output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)

-    print('JIT test passed')
+    print('JIT test for NVCC passed\n')
+
+    # NVRTC
+    print(f'NVRTC compiler version: {jit.NvrtcCompiler.__version__()}\n')
+    print('Generated code:')
+    code = generate_vector_add(T='__nv_bfloat16')
+    print(code)
+    print('Building ...')
+    func = jit.NvrtcCompiler.build('test_func', code, r'vector_add<[\S\s]*?>', VectorAddRuntime)
+
+    a = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    b = torch.randn((1024, ), dtype=torch.bfloat16, device='cuda')
+    c = torch.empty_like(a)
+    ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
+    assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+    ref_output = a + b
+    torch.testing.assert_close(c, ref_output)
+
+    print('JIT test for NVRTC passed')
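
For completeness, the caching path is generalized the same way: RuntimeCache.get(), which replaced __getitem__ in the runtime hunks above, takes the Runtime subclass used to re-wrap an on-disk hit, so a previously compiled kernel directory can be reattached without recompiling. A rough sketch, with the cache path purely hypothetical and the directory layout (kernel.cubin plus kernel.cubin.name) taken from the hunks above:

    from deep_gemm.jit.runtime import Fp8GemmRuntime, RuntimeCache  # module path as implied above

    cache = RuntimeCache()
    path = '/tmp/deep_gemm_cache/test_func'  # hypothetical directory from an earlier build
    # get() checks the in-process dict first, then the file system; runtime_cls
    # decides how a disk hit is wrapped (any Runtime subclass works, e.g. the
    # VectorAddRuntime defined in the test above).
    runtime = cache.get(path, runtime_cls=Fp8GemmRuntime)
    if runtime is None:
        # Nothing cached yet: fall back to NvccCompiler.build / NvrtcCompiler.build.
        pass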