Refactor runtime

Move JIT code generation and kernel launching out of free functions (generate, launch, launch_vector_add) passed around via generator=/caller=, and into static generate() and launch() methods on the Runtime subclasses (FP8GemmRuntime, the test-only VectorAddRuntime). JITTuner.compile_and_tune drops its generator parameter and calls runtime_cls.generate directly.

Chenggang Zhao 2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -1,19 +1,18 @@
import os
import time
from typing import Any, Callable, Dict, List, Optional, Type
import cuda.bindings.driver as cbd
from typing import List, Optional, Type
import cuda.bindings.driver as cuda
class Runtime:
def __init__(self, path: str, kernel_name: str = None,
caller: Callable[..., cuda.CUresult] = None,
args: List[str] = None) -> None:
self.path = path
self.lib = None
self.kernel = None
self.kernel_name = kernel_name
self.caller = caller
self.args = args
assert self.is_path_valid(self.path)
@@ -27,6 +26,14 @@ class Runtime:
files = ['kernel.cubin']
return all(os.path.exists(os.path.join(path, file)) for file in files)
@staticmethod
def generate(**kwargs) -> str:
raise NotImplementedError
@staticmethod
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
raise NotImplementedError
def __call__(self, **kwargs) -> cuda.CUresult:
# Load CUBIN
if self.kernel is None:
@@ -62,7 +69,8 @@ class Runtime:
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
return self.caller(
# noinspection PyArgumentList
return self.launch(
self.kernel,
*[kwargs[arg] for arg in self.args]
)
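
For reference, a minimal sketch of the subclass contract after this refactor; the names below (EchoRuntime, the one-line kernel source) are illustrative and not part of this commit, and the structure mirrors VectorAddRuntime in the test file further down.

import ctypes
import torch
import cuda.bindings.driver as cbd
from deep_gemm import jit

class EchoRuntime(jit.Runtime):
    def __init__(self, path: str) -> None:
        # Only the ordered keyword names are declared here; no `caller` is passed anymore.
        super().__init__(path, 'echo', ['X', 'STREAM'])

    @staticmethod
    def generate(**kwargs) -> str:
        # CUDA source that the JIT compiler turns into kernel.cubin (illustrative).
        return '__global__ void echo(float *x) {}'

    # noinspection PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, x: torch.Tensor, stream: cbd.CUstream) -> cbd.CUresult:
        config = cbd.CUlaunchConfig()
        config.gridDimX = config.gridDimY = config.gridDimZ = 1
        config.blockDimX = config.blockDimY = config.blockDimZ = 1
        config.hStream = stream
        args = ((x.data_ptr(),), (ctypes.c_void_p,))
        return cbd.cuLaunchKernelEx(config, kernel, args, 0)[0]

Calling an instance, e.g. runtime(X=x, STREAM=stream), loads kernel.cubin from `path`, resolves the kernel symbol, and forwards the keyword values positionally to launch() via the *[kwargs[arg] for arg in self.args] expansion shown above.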

View File

@@ -3,8 +3,10 @@ import torch
from functools import lru_cache
from typing import Tuple
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)

View File

@@ -2,8 +2,10 @@ import torch
from typing import Tuple
from .gemm import get_best_configs
from .runtime import FP8GemmRuntime, generate
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
from .tuner import jit_tuner
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@@ -103,7 +105,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Ten
'GEMM_TYPE': GemmType.GroupedContiguous},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)
@@ -209,7 +210,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor]
'GEMM_TYPE': GemmType.GroupedMasked},
space=(),
kwargs=kwargs,
generator=generate,
runtime_cls=FP8GemmRuntime,
)

View File

@@ -8,66 +8,6 @@ from typing import Any, Dict, Tuple
from ..jit.runtime import Runtime
def generate(**kwargs: Dict[str, Any]) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated code:\n{code}')
return code
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'fp8_gemm', launch, [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
class Layout(enum.Enum):
RowMajor = 0
ColMajor = 1
@@ -200,7 +140,67 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'fp8_gemm', [
'NUM_TMA_MULTICAST',
'M',
'BLOCK_M',
'GMEM_D',
'SCALES_B',
'GROUPED_LAYOUT',
'NUM_SMS',
'SMEM_SIZE',
'TENSOR_MAP_A',
'TENSOR_MAP_B',
'TENSOR_MAP_SCALES_A',
'TENSOR_MAP_D',
'STREAM',
])
@staticmethod
def generate(**kwargs) -> str:
code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
__global__ void dummy_kernel() {{
void *ptr = (void *)&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>;
}}
'''
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 GEMM code:\n{code}')
return code
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
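
A hedged usage sketch (the right-hand-side variables are placeholders, not from this commit): once the tuner has built an FP8GemmRuntime, invoking the instance forwards keyword arguments to launch() positionally, in the order declared in __init__ above.

# `runtime` is an FP8GemmRuntime instance; the values are placeholders for the
# real tensors, descriptors, and scalars computed at the call site.
ret = runtime(
    NUM_TMA_MULTICAST=num_tma_multicast, M=m, BLOCK_M=block_m,
    GMEM_D=d, SCALES_B=scales_b, GROUPED_LAYOUT=grouped_layout,
    NUM_SMS=num_sms, SMEM_SIZE=smem_size,
    TENSOR_MAP_A=tensor_map_a, TENSOR_MAP_B=tensor_map_b,
    TENSOR_MAP_SCALES_A=tensor_map_scales_a, TENSOR_MAP_D=tensor_map_d,
    STREAM=torch.cuda.current_stream().cuda_stream,
)
assert ret == cbd.CUresult.CUDA_SUCCESS, ret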

View File

@@ -12,7 +12,7 @@ class JITTuner:
self.tuned = {}
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
# NOTES: we always assume the space, template and GPU devices will not change
# NOTES: the function must have no accumulated side effects
keys = {k: keys[k] for k in sorted(keys.keys())}
@@ -34,7 +34,7 @@ class JITTuner:
assert isinstance(tuned_keys, dict)
full_keys = copy.deepcopy(keys)
full_keys.update(tuned_keys)
code = generator(**kwargs, **full_keys)
code = runtime_cls.generate(**kwargs, **full_keys)
kernels.append((build(name, code, runtime_cls), full_keys))
# TODO: fix tuning with space > 1
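
For comparison, a sketch of what a call site might look like after this change (the key and kwarg values are placeholders): the generator= argument disappears because compile_and_tune now asks runtime_cls itself for the source.

# Hypothetical call: runtime_cls.generate(**kwargs, **full_keys) is invoked
# internally, and runtime_cls.launch is used when the returned runtime is called.
runtime, best_keys = jit_tuner.compile_and_tune(
    name='gemm_fp8_fp8_bf16_nt',                 # placeholder kernel name
    keys={'N': n, 'K': k, 'BLOCK_M': block_m},   # placeholder subset of template keys
    space=(),
    kwargs=kwargs,                               # launch-time arguments for FP8GemmRuntime.launch
    runtime_cls=FP8GemmRuntime,                  # supplies both generate() and launch()
)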

View File

@@ -1,7 +1,7 @@
import ctypes
import os
import torch
import cuda.bindings.driver as cuda
import cuda.bindings.driver as cbd
from deep_gemm import jit
@@ -10,42 +10,17 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
# noinspection PyShadowingNames
def launch_vector_add(kernel: cuda.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cuda.CUstream) -> cuda.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', [
'A',
'B',
'C',
'STREAM',
])
n = a.numel()
config = cuda.CUlaunchConfig()
config.gridDimX = (n + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
n,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
def generate_vector_add(**kwargs) -> str:
@staticmethod
def generate(**kwargs) -> str:
return f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
@@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
}}
"""
# noinspection PyShadowingNames,PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel,
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
stream: cbd.CUstream) -> cbd.CUresult:
assert a.shape == b.shape == c.shape
assert a.device == b.device == c.device
assert a.dim() == 1
class VectorAddRuntime(jit.Runtime):
def __init__(self, path: str) -> None:
super().__init__(path, 'vector_add', launch_vector_add, [
'A',
'B',
'C',
'STREAM',
])
config = cbd.CUlaunchConfig()
config.gridDimX = (a.numel() + 127) // 128
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = 128
config.blockDimY = 1
config.blockDimZ = 1
config.hStream = stream
arg_values = (
a.data_ptr(),
b.data_ptr(),
c.data_ptr(),
a.numel(),
)
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
if __name__ == '__main__':
print('Generated code:')
code = generate_vector_add(T='float')
code = VectorAddRuntime.generate(T='float')
print(code)
print()
@@ -100,6 +98,6 @@ if __name__ == '__main__':
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
c = torch.empty_like(a)
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
assert ret == cbd.CUresult.CUDA_SUCCESS, ret
torch.testing.assert_close(c, a + b)
print(f'JIT test for {compiler_name} passed\n')