Mirror of https://github.com/deepseek-ai/DeepGEMM (synced 2025-06-26 23:15:49 +00:00)
Refactor runtime

parent 981cc58932
commit 317e83581d
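The change moves kernel-specific code generation and launching out of free functions and into the runtime classes themselves: `Runtime` loses its `caller` argument, gains abstract `generate`/`launch` staticmethods, and `__call__` now dispatches to `self.launch`; `FP8GemmRuntime` (and `VectorAddRuntime` in the JIT test) override both, and `JITTuner.compile_and_tune` drops its `generator` parameter in favour of `runtime_cls.generate`. The hunks below appear to touch the JIT runtime module, the FP8 GEMM and grouped-GEMM kernel wrappers, the tuner, and the JIT test. As a rough illustration of the resulting contract, here is a minimal, self-contained sketch; `ToyRuntime` and `EchoRuntime` are stand-in names for illustration only and are not part of the repository:

# Minimal sketch of the post-refactor contract (illustrative only; the real
# Runtime lives in deep_gemm/jit and lazily loads a compiled CUBIN from `path`).
from typing import List


class ToyRuntime:
    def __init__(self, path: str, kernel_name: str, args: List[str]) -> None:
        # No `caller` argument any more: launching now belongs to the subclass.
        self.path, self.kernel_name, self.args = path, kernel_name, args
        self.kernel = None  # the real class loads this from `path` on first call

    @staticmethod
    def generate(**kwargs) -> str:
        # Each subclass renders its own CUDA source from the tuning keys.
        raise NotImplementedError

    @staticmethod
    def launch(kernel, **kwargs):
        # Each subclass knows how to launch its own kernel.
        raise NotImplementedError

    def __call__(self, **kwargs):
        # Mirrors the refactored Runtime.__call__: forward declared args to launch.
        return self.launch(self.kernel, *[kwargs[arg] for arg in self.args])


class EchoRuntime(ToyRuntime):
    def __init__(self, path: str) -> None:
        super().__init__(path, 'echo', ['X'])

    @staticmethod
    def generate(**kwargs) -> str:
        return f'// kernel specialised for N={kwargs["N"]}'

    @staticmethod
    def launch(kernel, x):
        print('launching with', x)
        return 0


if __name__ == '__main__':
    print(EchoRuntime.generate(N=128))    # tuner side: runtime_cls.generate(...)
    print(EchoRuntime('/tmp/echo')(X=3))  # call side: kwargs forwarded to launch

Keeping `generate` and `launch` as staticmethods on the runtime class means the tuner only needs a `runtime_cls` handle: it renders source via `runtime_cls.generate(**keys)`, and the built runtime later forwards call-site kwargs to the matching `launch`.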
@@ -1,19 +1,18 @@
 import os
 import time
-from typing import Any, Callable, Dict, List, Optional, Type
+import cuda.bindings.driver as cbd
+from typing import List, Optional, Type
 
 import cuda.bindings.driver as cuda
 
 
 class Runtime:
     def __init__(self, path: str, kernel_name: str = None,
-                 caller: Callable[..., cuda.CUresult] = None,
                  args: List[str] = None) -> None:
         self.path = path
         self.lib = None
         self.kernel = None
         self.kernel_name = kernel_name
-        self.caller = caller
         self.args = args
         assert self.is_path_valid(self.path)
 
@@ -27,6 +26,14 @@ class Runtime:
         files = ['kernel.cubin']
         return all(os.path.exists(os.path.join(path, file)) for file in files)
 
+    @staticmethod
+    def generate(**kwargs) -> str:
+        raise NotImplemented
+
+    @staticmethod
+    def launch(kernel: cbd.CUkernel, **kwargs) -> cbd.CUresult:
+        raise NotImplemented
+
     def __call__(self, **kwargs) -> cuda.CUresult:
         # Load CUBIN
         if self.kernel is None:
@@ -62,7 +69,8 @@ class Runtime:
             if int(os.getenv('DG_JIT_DEBUG', 0)):
                 print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} us.')
 
-        return self.caller(
+        # noinspection PyArgumentList
+        return self.launch(
             self.kernel,
             *[kwargs[arg] for arg in self.args]
         )
@@ -3,8 +3,10 @@ import torch
 from functools import lru_cache
 from typing import Tuple
 
-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
 
@@ -238,7 +240,6 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
@@ -2,8 +2,10 @@ import torch
 from typing import Tuple
 
 from .gemm import get_best_configs
-from .runtime import FP8GemmRuntime, generate
-from .runtime import GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, make_2d_tma_d_desc, make_2d_tma_scales_a_desc
+from .runtime import (
+    FP8GemmRuntime, GemmType,
+    make_2d_tma_a_desc, make_2d_tma_b_desc,
+    make_2d_tma_d_desc, make_2d_tma_scales_a_desc)
 from .tuner import jit_tuner
 from .utils import get_col_major_tma_aligned_tensor, get_num_sms
 
@@ -103,7 +105,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
               'GEMM_TYPE': GemmType.GroupedContiguous},
         space=(),
         kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
@@ -209,7 +210,6 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
               'GEMM_TYPE': GemmType.GroupedMasked},
         space=(),
        kwargs=kwargs,
-        generator=generate,
         runtime_cls=FP8GemmRuntime,
     )
 
@@ -8,66 +8,6 @@ from typing import Any, Dict, Tuple
 from ..jit.runtime import Runtime
 
 
-def generate(**kwargs: Dict[str, Any]) -> str:
-    code = f'''
-#ifdef __CUDACC_RTC__
-#include <deep_gemm/nvrtc_std.cuh>
-#else
-#include <cuda.h>
-#include <string>
-#endif
-
-#include <cuda_bf16.h>
-#include <cuda_fp8.h>
-
-#include <deep_gemm/fp8_gemm.cuh>
-
-using namespace deep_gemm;
-
-__global__ void dummy_kernel() {{
-    void *ptr = (void *)&fp8_gemm_kernel<
-        {kwargs['N']},
-        {kwargs['K']},
-        {kwargs['BLOCK_M']},
-        {kwargs['BLOCK_N']},
-        {kwargs['BLOCK_K']},
-        {kwargs['BLOCK_N_PADDING']},
-        {kwargs['SWIZZLE_D_MODE']},
-        {kwargs['NUM_GROUPS']},
-        {kwargs['NUM_STAGES']},
-        {kwargs['NUM_TMA_THREADS']},
-        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
-        {kwargs['NUM_TMA_MULTICAST']},
-        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
-        GemmType::{kwargs['GEMM_TYPE']}
-    >;
-}}
-'''
-
-    if int(os.getenv('DG_JIT_DEBUG', 0)):
-        print(f'Generated code:\n{code}')
-    return code
-
-
-class FP8GemmRuntime(Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'fp8_gemm', launch, [
-            'NUM_TMA_MULTICAST',
-            'M',
-            'BLOCK_M',
-            'GMEM_D',
-            'SCALES_B',
-            'GROUPED_LAYOUT',
-            'NUM_SMS',
-            'SMEM_SIZE',
-            'TENSOR_MAP_A',
-            'TENSOR_MAP_B',
-            'TENSOR_MAP_SCALES_A',
-            'TENSOR_MAP_D',
-            'STREAM',
-        ])
-
-
 class Layout(enum.Enum):
     RowMajor = 0
     ColMajor = 1
@@ -200,6 +140,66 @@ def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
                            block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
 
 
+class FP8GemmRuntime(Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'fp8_gemm', [
+            'NUM_TMA_MULTICAST',
+            'M',
+            'BLOCK_M',
+            'GMEM_D',
+            'SCALES_B',
+            'GROUPED_LAYOUT',
+            'NUM_SMS',
+            'SMEM_SIZE',
+            'TENSOR_MAP_A',
+            'TENSOR_MAP_B',
+            'TENSOR_MAP_SCALES_A',
+            'TENSOR_MAP_D',
+            'STREAM',
+        ])
+
+    @staticmethod
+    def generate(**kwargs) -> str:
+        code = f'''
+#ifdef __CUDACC_RTC__
+#include <deep_gemm/nvrtc_std.cuh>
+#else
+#include <cuda.h>
+#include <string>
+#endif
+
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+
+#include <deep_gemm/fp8_gemm.cuh>
+
+using namespace deep_gemm;
+
+__global__ void dummy_kernel() {{
+    void *ptr = (void *)&fp8_gemm_kernel<
+        {kwargs['N']},
+        {kwargs['K']},
+        {kwargs['BLOCK_M']},
+        {kwargs['BLOCK_N']},
+        {kwargs['BLOCK_K']},
+        {kwargs['BLOCK_N_PADDING']},
+        {kwargs['SWIZZLE_D_MODE']},
+        {kwargs['NUM_GROUPS']},
+        {kwargs['NUM_STAGES']},
+        {kwargs['NUM_TMA_THREADS']},
+        {kwargs['NUM_MATH_THREADS_PER_GROUP']},
+        {kwargs['NUM_TMA_MULTICAST']},
+        {'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
+        GemmType::{kwargs['GEMM_TYPE']}
+    >;
+}}
+'''
+        if int(os.getenv('DG_JIT_DEBUG', 0)):
+            print(f'Generated FP8 GEMM code:\n{code}')
+        return code
+
+    # noinspection PyMethodOverriding
+    @staticmethod
     def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
                block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
                grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
@@ -12,7 +12,7 @@ class JITTuner:
         self.tuned = {}
 
     def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple,
-                         kwargs: Dict[str, Any], generator: Callable[..., str], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
+                         kwargs: Dict[str, Any], runtime_cls: Type[Runtime]) -> Tuple[Runtime, Dict[str, Any]]:
         # NOTES: we always assume the space, template and GPU devices will not change
         # NOTES: the function must have no accumulated side effects
         keys = {k: keys[k] for k in sorted(keys.keys())}
@@ -34,7 +34,7 @@ class JITTuner:
             assert isinstance(tuned_keys, dict)
             full_keys = copy.deepcopy(keys)
             full_keys.update(tuned_keys)
-            code = generator(**kwargs, **full_keys)
+            code = runtime_cls.generate(**kwargs, **full_keys)
             kernels.append((build(name, code, runtime_cls), full_keys))
 
         # TODO: fix tuning with space > 1
@@ -1,7 +1,7 @@
 import ctypes
 import os
 import torch
-import cuda.bindings.driver as cuda
+import cuda.bindings.driver as cbd
 
 from deep_gemm import jit
 
@@ -10,42 +10,17 @@ os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
 os.environ['DG_DISABLE_CACHE'] = os.getenv('DG_DISABLE_CACHE', '1')
 
 
-# noinspection PyShadowingNames
-def launch_vector_add(kernel: cuda.CUkernel,
-                      a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
-                      stream: cuda.CUstream) -> cuda.CUresult:
-    assert a.shape == b.shape == c.shape
-    assert a.device == b.device == c.device
-    assert a.dim() == 1
-
-    n = a.numel()
-    config = cuda.CUlaunchConfig()
-    config.gridDimX = (n + 127) // 128
-    config.gridDimY = 1
-    config.gridDimZ = 1
-    config.blockDimX = 128
-    config.blockDimY = 1
-    config.blockDimZ = 1
-    config.hStream = stream
-
-    arg_values = (
-        a.data_ptr(),
-        b.data_ptr(),
-        c.data_ptr(),
-        n,
-    )
-    arg_types = (
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_void_p,
-        ctypes.c_uint32,
-    )
-
-    return cuda.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
-
-
-def generate_vector_add(**kwargs) -> str:
+class VectorAddRuntime(jit.Runtime):
+    def __init__(self, path: str) -> None:
+        super().__init__(path, 'vector_add', [
+            'A',
+            'B',
+            'C',
+            'STREAM',
+        ])
+
+    @staticmethod
+    def generate(**kwargs) -> str:
         return f"""
 #ifdef __CUDACC_RTC__
 #include <deep_gemm/nvrtc_std.cuh>
@@ -69,20 +44,43 @@ __global__ void dummy_kernel() {{
 }}
 """
 
-
-class VectorAddRuntime(jit.Runtime):
-    def __init__(self, path: str) -> None:
-        super().__init__(path, 'vector_add', launch_vector_add, [
-            'A',
-            'B',
-            'C',
-            'STREAM',
-        ])
+    # noinspection PyShadowingNames,PyMethodOverriding
+    @staticmethod
+    def launch(kernel: cbd.CUkernel,
+               a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
+               stream: cbd.CUstream) -> cbd.CUresult:
+        assert a.shape == b.shape == c.shape
+        assert a.device == b.device == c.device
+        assert a.dim() == 1
+
+        config = cbd.CUlaunchConfig()
+        config.gridDimX = (a.numel() + 127) // 128
+        config.gridDimY = 1
+        config.gridDimZ = 1
+        config.blockDimX = 128
+        config.blockDimY = 1
+        config.blockDimZ = 1
+        config.hStream = stream
+
+        arg_values = (
+            a.data_ptr(),
+            b.data_ptr(),
+            c.data_ptr(),
+            a.numel(),
+        )
+        arg_types = (
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_void_p,
+            ctypes.c_uint32,
+        )
+
+        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
 
 
 if __name__ == '__main__':
     print('Generated code:')
-    code = generate_vector_add(T='float')
+    code = VectorAddRuntime.generate(T='float')
     print(code)
     print()
 
@@ -100,6 +98,6 @@ if __name__ == '__main__':
         b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
         c = torch.empty_like(a)
         ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
-        assert ret == cuda.CUresult.CUDA_SUCCESS, ret
+        assert ret == cbd.CUresult.CUDA_SUCCESS, ret
         torch.testing.assert_close(c, a + b)
         print(f'JIT test for {compiler_name} passed\n')