Refactor runtime

Chenggang Zhao 2025-05-06 17:45:42 +08:00
parent 981cc58932
commit 317e83581d
6 changed files with 177 additions and 170 deletions

View File

@@ -1,19 +1,18 @@
 import os
 import time
-from typing import Any, Callable, Dict, List, Optional, Type
+import cuda.bindings.driver as cbd
+from typing import List, Optional, Type
 import cuda.bindings.driver as cuda


 class Runtime:
     def __init__(self, path: str, kernel_name: str = None,
-                 caller: Callable[..., cuda.CUresult] = None,
                  args: List[str] = None) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
-        self.caller = caller
         self.args = args

         assert self.is_path_valid(self.path)
@@ -27,6 +26,14 @@ class Runtime:
         files = ['kernel.cubin']
         return all(os.path.exists(os.path.join(path, file)) for file in files)

+    @staticmethod
+    def generate(**kwargs) -> str:
+        raise NotImplemented
+
+    @staticmethod
+    def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
+        raise NotImplemented
+
     def __call__(self, **kwargs) -> cuda.CUresult:
         # Load CUBIN
         if self.kernel is None:
@@ -62,7 +69,8 @@ class Runtime:
         if int(os.getenv('DG_JIT_DEBUG', 0)):
             print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')

-        return self.caller(
+        # noinspection PyArgumentList
+        return self.launch(
             self.kernel,
             *[kwargs[arg] for arg in self.args]
         )
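
In short, the base class now owns both code generation and kernel launch as overridable staticmethods, and __call__ forwards keyword arguments positionally in the order given by self.args. The same contract, restated as a self-contained toy outside the diff (a minimal sketch: ToyRuntime and AddOneRuntime are hypothetical stand-ins, with CUBIN loading, path checks and CUDA types omitted; the real subclasses appear in the later hunks of this commit):

from typing import List


class ToyRuntime:
    # Illustrative stand-in for the refactored Runtime contract; not the real class.
    def __init__(self, kernel_name: str, args: List[str]) -> None:
        self.kernel_name = kernel_name
        self.args = args
        self.kernel = object()  # stands in for the loaded cbd.CUkernel

    @staticmethod
    def generate(**kwargs) -> str:
        raise NotImplementedError

    @staticmethod
    def launch(kernel, **kwargs):
        raise NotImplementedError

    def __call__(self, **kwargs):
        # Mirrors `return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])`:
        # keyword arguments are forwarded positionally in the order given by self.args.
        return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])


class AddOneRuntime(ToyRuntime):
    # Hypothetical subclass, analogous in shape to VectorAddRuntime in the test diff below.
    def __init__(self) -> None:
        super().__init__('add_one', ['X'])

    @staticmethod
    def generate(**kwargs) -> str:
        return f"// would emit CUDA source specialized for N={kwargs['N']}"

    @staticmethod
    def launch(kernel, x):
        return x + 1  # a real subclass would fill a CUlaunchConfig and call cuLaunchKernelEx


assert 'N=16' in AddOneRuntime.generate(N=16)
assert AddOneRuntime()(X=41) == 42

Keeping generate and launch on the class lets callers hand a single runtime_cls to the tuner instead of threading separate generator and caller callables through every call site.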

View File

@@ -3,8 +3,10 @@ import torch
 from functools import lru_cache
 from typing import Tuple

-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
@@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )

View File

@@ -2,8 +2,10 @@ import torch
 from typing import Tuple

 from .gemm import get_best_configs
-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
@@ -103,7 +105,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
               'GEMM_TYPE': GemmType.GroupedContiguous},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
@@ -209,7 +210,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
               'GEMM_TYPE': GemmType.GroupedMasked},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )

View File

@@ -8,66 +8,6 @@ from typing import Any, Dict, Tuple
 from ..jit.runtime import Runtime


-def generate(**kwargs: Dict[str, Any]) -> str:
-    code = f'''
-#ifdef __CUDACC_RTC__
-#include <deep_gemm/nvrtc_std.cuh>
-#else
-#include <cuda.h>
-#include <string>
-#endif
-
-#include <cuda_bf16.h>
-#include <cuda_fp8.h>
-#include <deep_gemm/fp8_gemm.cuh>
-
-using namespace deep_gemm;
-
-__global__ void dummy_kernel() {{
-    void *ptr = (void *)&fp8_gemm_kernel<
-        {kwargs['N']},
-        {kwargs['K']},
-        {kwargs['BLOCK_M']},
-        {kwargs['BLOCK_N']},
-        {kwargs['BLOCK_K']},
-        {kwargs['BLOCK_N_PADDING']},
-        {kwargs['SWIZZLE_D_MODE']},
-        {kwargs['NUM_GROUPS']},
-        {kwargs['NUM_STAGES']},
-        {kwargs['NUM_TMA_THREADS']},
-        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
-        {kwargs['NUM_TMA_MULTICAST']},
-        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
-        GemmType::{kwargs['GEMM_TYPE']}
-    >;
-}}
-'''
-
-    if int(os.getenv('DG_JIT_DEBUG', 0)):
-        print(f'Generated code:\n{code}')
-
-    return code
-
-
-class FP8GemmRuntime(Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'fp8_gemm', launch, [
-            'NUM_TMA_MULTICAST',
-            'M',
-            'BLOCK_M',
-            'GMEM_D',
-            'SCALES_B',
-            'GROUPED_LAYOUT',
-            'NUM_SMS',
-            'SMEM_SIZE',
-            'TENSOR_MAP_A',
-            'TENSOR_MAP_B',
-            'TENSOR_MAP_SCALES_A',
-            'TENSOR_MAP_D',
-            'STREAM',
-        ])
-
-
 class Layout(enum.Enum):
     RowMajor = 0
     ColMajor = 1
@@ -200,6 +140,66 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
                             block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)


+class FP8GemmRuntime(Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'fp8_gemm', [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+    @staticmethod
+    def generate(**kwargs) -> str:
+        code = f'''
+#ifdef __CUDACC_RTC__
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#include <string>
+#endif
+
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <deep_gemm/fp8_gemm.cuh>
+
+using namespace deep_gemm;
+
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&fp8_gemm_kernel<
+        {kwargs['N']},
+        {kwargs['K']},
+        {kwargs['BLOCK_M']},
+        {kwargs['BLOCK_N']},
+        {kwargs['BLOCK_K']},
+        {kwargs['BLOCK_N_PADDING']},
+        {kwargs['SWIZZLE_D_MODE']},
+        {kwargs['NUM_GROUPS']},
+        {kwargs['NUM_STAGES']},
+        {kwargs['NUM_TMA_THREADS']},
+        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
+        {kwargs['NUM_TMA_MULTICAST']},
+        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
+        GemmType::{kwargs['GEMM_TYPE']}
+    >;
+}}
+'''
+
+        if int(os.getenv('DG_JIT_DEBUG', 0)):
+            print(f'Generated FP8 GEMM code:\n{code}')
+
+        return code
+
+    # noinspection PyMethodOverriding
+    @staticmethod
     def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
                block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
                grouped_layout: torch.Tensor, num_sms: int, smem_size: int,

View File

@@ -12,7 +12,7 @@ class JITTuner:
         self.tuned = {}

     def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
-                         kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
+                         kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
         # NOTES: we always assume the space, template and GPU devices will not change
         # NOTES: the function must have no accumulated side effects
         keys = {k: keys[k] for k in sorted(keys.keys())}
@@ -34,7 +34,7 @@ class JITTuner:
             assert isinstance(tuned_keys, dict)
             full_keys = copy.deepcopy(keys)
             full_keys.update(tuned_keys)
-            code = generator(**kwargs, **full_keys)
+            code = runtime_cls.generate(**kwargs, **full_keys)
             kernels.append((build(name, code, runtime_cls), full_keys))

         # TODO: fix tuning with space > 1
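
The net effect on the tuner is that the code generator is resolved from the runtime class itself rather than passed alongside it, which is why the GEMM call sites above simply drop generator=generate. A minimal sketch of that before/after shape (DemoRuntime and both helper functions are hypothetical, not DeepGEMM APIs):

class DemoRuntime:
    # Hypothetical runtime class that owns its code generator as a staticmethod.
    @staticmethod
    def generate(**kwargs) -> str:
        return f"// specialized for BLOCK_M={kwargs['BLOCK_M']}"


def build_code_old(generator, **full_keys) -> str:
    # Pre-refactor shape: a generator callable travels separately from the runtime class.
    return generator(**full_keys)


def build_code_new(runtime_cls, **full_keys) -> str:
    # Post-refactor shape: the tuner asks the runtime class itself for the code.
    return runtime_cls.generate(**full_keys)


assert build_code_old(DemoRuntime.generate, BLOCK_M=128) == build_code_new(DemoRuntime, BLOCK_M=128)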

View File

@@ -1,7 +1,7 @@
 import ctypes
 import os
 import torch
-import cuda.bindings.driver as cuda
+import cuda.bindings.driver as cbd

 from deep_gemm import jit
@@ -10,42 +10,17 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
 os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')


-# noinspection PyShadowingNames
-def launch_vector_add(kernel: cuda.CUkernel,
-                      a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
-                      stream: cuda.CUstream) -> cuda.CUresult:
-    assert a.shape == b.shape == c.shape
-    assert a.device == b.device == c.device
-    assert a.dim() == 1
-
-    n = a.numel()
-
-    config = cuda.CUlaunchConfig()
-    config.gridDimX = (n + 127) // 128
-    config.gridDimY = 1
-    config.gridDimZ = 1
-    config.blockDimX = 128
-    config.blockDimY = 1
-    config.blockDimZ = 1
-    config.hStream = stream
-
-    arg_values = (
-        a.data_ptr(),
-        b.data_ptr(),
-        c.data_ptr(),
-        n,
-    )
-    arg_types = (
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_uint32,
-    )
-
-    return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
-
-
-def generate_vector_add(**kwargs) -> str:
-    return f"""
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'vector_add', [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
+
+    @staticmethod
+    def generate(**kwargs) -> str:
+        return f"""
 #ifdef __CUDACC_RTC__
 #include <deep_gemm/nvrtc_std.cuh>
@@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
 }}
 """

-
-class VectorAddRuntime(jit.Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'vector_add', launch_vector_add, [
-            'A',
-            'B',
-            'C',
-            'STREAM',
-        ])
+    # noinspection PyShadowingNames,PyMethodOverriding
+    @staticmethod
+    def launch(kernel: cbd.CUkernel,
+               a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
+               stream: cbd.CUstream) -> cbd.CUresult:
+        assert a.shape == b.shape == c.shape
+        assert a.device == b.device == c.device
+        assert a.dim() == 1
+
+        config = cbd.CUlaunchConfig()
+        config.gridDimX = (a.numel() + 127) // 128
+        config.gridDimY = 1
+        config.gridDimZ = 1
+        config.blockDimX = 128
+        config.blockDimY = 1
+        config.blockDimZ = 1
+        config.hStream = stream
+
+        arg_values = (
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            a.numel(),
+        )
+        arg_types = (
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_uint32,
+        )
+
+        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]


 if __name__ == '__main__':
     print('Generated code:')
-    code = generate_vector_add(T='float')
+    code = VectorAddRuntime.generate(T='float')
     print(code)
     print()
@@ -100,6 +98,6 @@ if __name__ == '__main__':
         b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
         c = torch.empty_like(a)
         ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
-        assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+        assert ret == cbd.CUresult.CUDA_SUCCESS, ret
         torch.testing.assert_close(c, a + b)
         print(f'JIT test for {compiler_name} passed\n')