mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor runtime
This commit is contained in:
parent
981cc58932
commit
317e83581d
@ -1,19 +1,18 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class Runtime:
|
||||
def __init__(self, path: str, kernel_name: str = None,
|
||||
caller: Callable[..., cuda.CUresult] = None,
|
||||
args: List[str] = None) -> None:
|
||||
self.path = path
|
||||
self.lib = None
|
||||
self.kernel = None
|
||||
self.kernel_name = kernel_name
|
||||
self.caller = caller
|
||||
self.args = args
|
||||
assert self.is_path_valid(self.path)
|
||||
|
||||
@ -27,6 +26,14 @@ class Runtime:
|
||||
files = ['kernel.cubin']
|
||||
return all(os.path.exists(os.path.join(path, file)) for file in files)
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
raise NotImplemented
|
||||
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
|
||||
raise NotImplemented
|
||||
|
||||
def __call__(self, **kwargs) -> cuda.CUresult:
|
||||
# Load CUBIN
|
||||
if self.kernel is None:
|
||||
@ -62,7 +69,8 @@ class Runtime:
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
|
||||
|
||||
return self.caller(
|
||||
# noinspection PyArgumentList
|
||||
return self.launch(
|
||||
self.kernel,
|
||||
*[kwargs[arg] for arg in self.args]
|
||||
)
|
||||
|
||||
@ -3,8 +3,10 @@ import torch
|
||||
from functools import lru_cache
|
||||
from typing import Tuple
|
||||
|
||||
from .runtime import FP8GemmRuntime, generate
|
||||
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
from .runtime import (
|
||||
FP8GemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
|
||||
|
||||
@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
|
||||
@ -2,8 +2,10 @@ import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .gemm import get_best_configs
|
||||
from .runtime import FP8GemmRuntime, generate
|
||||
from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
|
||||
from .runtime import (
|
||||
FP8GemmRuntime, GemmType,
|
||||
make_2d_tma_a_desc, make_2d_tma_b_desc,
|
||||
make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
|
||||
from .tuner import jit_tuner
|
||||
from .utils import get_col_major_tma_aligned_tensor, get_num_sms
|
||||
|
||||
@ -103,7 +105,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
|
||||
'GEMM_TYPE': GemmType.GroupedContiguous},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
@ -209,7 +210,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
'GEMM_TYPE': GemmType.GroupedMasked},
|
||||
space=(),
|
||||
kwargs=kwargs,
|
||||
generator=generate,
|
||||
runtime_cls=FP8GemmRuntime,
|
||||
)
|
||||
|
||||
|
||||
@ -8,66 +8,6 @@ from typing import Any, Dict, Tuple
|
||||
from ..jit.runtime import Runtime
|
||||
|
||||
|
||||
def generate(**kwargs: Dict[str, Any]) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
'''
|
||||
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated code:\n{code}')
|
||||
return code
|
||||
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', launch, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
|
||||
class Layout(enum.Enum):
|
||||
RowMajor = 0
|
||||
ColMajor = 1
|
||||
@ -200,7 +140,67 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'fp8_gemm', [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
code = f'''
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include <deep_gemm/fp8_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
__global__ void dummy_kernel() {{
|
||||
void *ptr = (void *)&fp8_gemm_kernel<
|
||||
{kwargs['N']},
|
||||
{kwargs['K']},
|
||||
{kwargs['BLOCK_M']},
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['BLOCK_N_PADDING']},
|
||||
{kwargs['SWIZZLE_D_MODE']},
|
||||
{kwargs['NUM_GROUPS']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
|
||||
GemmType::{kwargs['GEMM_TYPE']}
|
||||
>;
|
||||
}}
|
||||
'''
|
||||
if int(os.getenv('DG_JIT_DEBUG', 0)):
|
||||
print(f'Generated FP8 GEMM code:\n{code}')
|
||||
return code
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
|
||||
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
|
||||
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
|
||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
||||
|
||||
@ -12,7 +12,7 @@ class JITTuner:
|
||||
self.tuned = {}
|
||||
|
||||
def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
|
||||
kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
|
||||
kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
|
||||
# NOTES: we always assume the space, template and GPU devices will not change
|
||||
# NOTES: the function must have no accumulated side effects
|
||||
keys = {k: keys[k] for k in sorted(keys.keys())}
|
||||
@ -34,7 +34,7 @@ class JITTuner:
|
||||
assert isinstance(tuned_keys, dict)
|
||||
full_keys = copy.deepcopy(keys)
|
||||
full_keys.update(tuned_keys)
|
||||
code = generator(**kwargs, **full_keys)
|
||||
code = runtime_cls.generate(**kwargs, **full_keys)
|
||||
kernels.append((build(name, code, runtime_cls), full_keys))
|
||||
|
||||
# TODO: fix tuning with space > 1
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import ctypes
|
||||
import os
|
||||
import torch
|
||||
import cuda.bindings.driver as cuda
|
||||
import cuda.bindings.driver as cbd
|
||||
|
||||
from deep_gemm import jit
|
||||
|
||||
@ -10,42 +10,17 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
|
||||
os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
|
||||
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def launch_vector_add(kernel: cuda.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cuda.CUstream) -> cuda.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
'STREAM',
|
||||
])
|
||||
|
||||
n = a.numel()
|
||||
|
||||
config = cuda.CUlaunchConfig()
|
||||
config.gridDimX = (n + 127) // 128
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = 128
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
n,
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
)
|
||||
|
||||
return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
|
||||
|
||||
|
||||
def generate_vector_add(**kwargs) -> str:
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
return f"""
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
|
||||
}}
|
||||
"""
|
||||
|
||||
# noinspection PyShadowingNames,PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel,
|
||||
a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
assert a.shape == b.shape == c.shape
|
||||
assert a.device == b.device == c.device
|
||||
assert a.dim() == 1
|
||||
|
||||
class VectorAddRuntime(jit.Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, 'vector_add', launch_vector_add, [
|
||||
'A',
|
||||
'B',
|
||||
'C',
|
||||
'STREAM',
|
||||
])
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.gridDimX = (a.numel() + 127) // 128
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = 128
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.hStream = stream
|
||||
|
||||
arg_values = (
|
||||
a.data_ptr(),
|
||||
b.data_ptr(),
|
||||
c.data_ptr(),
|
||||
a.numel(),
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_uint32,
|
||||
)
|
||||
|
||||
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Generated code:')
|
||||
code = generate_vector_add(T='float')
|
||||
code = VectorAddRuntime.generate(T='float')
|
||||
print(code)
|
||||
print()
|
||||
|
||||
@ -100,6 +98,6 @@ if __name__ == '__main__':
|
||||
b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
|
||||
c = torch.empty_like(a)
|
||||
ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
|
||||
assert ret == cuda.CUresult.CUDA_SUCCESS, ret
|
||||
assert ret == cbd.CUresult.CUDA_SUCCESS, ret
|
||||
torch.testing.assert_close(c, a + b)
|
||||
print(f'JIT test for {compiler_name} passed\n')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user