mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Refactor launch-related structures
This commit is contained in:
@@ -5,14 +5,10 @@ import torch
|
||||
import cuda.bindings.driver as cbd
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from .utils import get_tma_aligned_size
|
||||
from ..jit.runtime import Runtime
|
||||
|
||||
|
||||
class Layout(enum.Enum):
|
||||
RowMajor = 0
|
||||
ColMajor = 1
|
||||
|
||||
|
||||
class GemmType(enum.Enum):
|
||||
Normal = 0
|
||||
GroupedContiguous = 1
|
||||
@@ -61,19 +57,18 @@ def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int
|
||||
return get_num_math_warpgroups(block_m) * num_math_threads_per_group + num_tma_threads
|
||||
|
||||
|
||||
def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
||||
gmem_dim: Tuple[cbd.cuuint64_t, cbd.cuuint64_t],
|
||||
stride_in_bytes: cbd.cuuint64_t,
|
||||
smem_dim: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
||||
def make_2d_tma_copy_desc(t: torch.Tensor,
|
||||
gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
|
||||
smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
|
||||
swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
|
||||
tensor_dtype = tmap_type_map[global_address.dtype]
|
||||
tensor_dtype = tmap_type_map[t.dtype]
|
||||
res, tensor_map = cbd.cuTensorMapEncodeTiled(
|
||||
tensor_dtype,
|
||||
2,
|
||||
global_address.data_ptr(),
|
||||
gmem_dim,
|
||||
(stride_in_bytes, ),
|
||||
smem_dim,
|
||||
t.data_ptr(),
|
||||
gmem_dims,
|
||||
(gmem_outer_stride,),
|
||||
smem_dims,
|
||||
(cbd.cuuint32_t(1), cbd.cuuint32_t(1)),
|
||||
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
|
||||
swizzle_type,
|
||||
@@ -86,90 +81,61 @@ def make_2d_tma_copy_desc(global_address: torch.Tensor,
|
||||
return tensor_map
|
||||
|
||||
|
||||
def make_2d_tma_desc(global_address: torch.Tensor, layout: Layout,
|
||||
gmem_rows: int, gmem_cols: int, gmem_stride: int,
|
||||
smem_rows: int, smem_cols: int,
|
||||
def make_2d_tma_desc(t: torch.Tensor,
|
||||
gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
|
||||
smem_inner_dim: int, smem_outer_dim: int,
|
||||
swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
|
||||
if layout == Layout.RowMajor:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_cols), cbd.cuuint64_t(gmem_rows))
|
||||
smem_dim = (cbd.cuuint32_t(smem_cols), cbd.cuuint32_t(smem_rows))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
||||
else:
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_rows), cbd.cuuint64_t(gmem_cols))
|
||||
smem_dim = (cbd.cuuint32_t(smem_rows), cbd.cuuint32_t(smem_cols))
|
||||
return make_2d_tma_copy_desc(global_address, gmem_dim, cbd.cuuint64_t(gmem_stride * global_address.element_size()), smem_dim, swizzle_type)
|
||||
gmem_dim = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
|
||||
smem_dim = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
|
||||
return make_2d_tma_copy_desc(t, gmem_dim, cbd.cuuint64_t(gmem_outer_stride * t.element_size()), smem_dim, swizzle_type)
|
||||
|
||||
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_k: int,
|
||||
def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_m: int, shape_k: int, m_stride: int,
|
||||
block_m: int, block_k: int,
|
||||
num_groups: int, a_stride: int = 0) -> cbd.CUtensorMap:
|
||||
a_stride = shape_k if a_stride == 0 else a_stride
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_k, a_stride,
|
||||
block_m, block_k)
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(t,
|
||||
shape_k, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||
block_k, block_m)
|
||||
|
||||
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_k: int, shape_n: int,
|
||||
block_k: int, block_n: int,
|
||||
num_groups: int, b_stride: int = 0) -> cbd.CUtensorMap:
|
||||
b_stride = shape_k if b_stride == 0 else b_stride
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), b_stride,
|
||||
def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_n: int, shape_k: int, n_stride: int,
|
||||
block_n: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
return make_2d_tma_desc(t,
|
||||
shape_k, shape_n * (num_groups if gemm_type != GemmType.Normal else 1), n_stride,
|
||||
block_k, block_n)
|
||||
|
||||
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, global_address: torch.Tensor,
|
||||
shape_m: int, shape_n: int,
|
||||
def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_m: int, shape_n: int, m_stride: int,
|
||||
block_m: int, block_n: int,
|
||||
num_groups: int, swizzle_mode: int, d_stride: int = 0) -> cbd.CUtensorMap:
|
||||
num_groups: int,
|
||||
swizzle_mode: int) -> cbd.CUtensorMap:
|
||||
# Swizzling requires the inner box dim to be less than or equal to `kSwizzleDMode`
|
||||
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
|
||||
d_stride = shape_n if d_stride == 0 else d_stride
|
||||
return make_2d_tma_desc(global_address, Layout.RowMajor,
|
||||
shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n, d_stride,
|
||||
block_m, block_n if swizzle_mode == 0 else swizzle_mode // global_address.element_size(),
|
||||
return make_2d_tma_desc(t,
|
||||
shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride,
|
||||
block_n if swizzle_mode == 0 else swizzle_mode // t.element_size(), block_m,
|
||||
swizzle_type_map[swizzle_mode])
|
||||
|
||||
|
||||
def make_2d_tma_scales_a_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_m: int, shape_k: int, block_m: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
||||
def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
|
||||
shape_mn: int, shape_k: int,
|
||||
block_mn: int, block_k: int,
|
||||
num_groups: int) -> cbd.CUtensorMap:
|
||||
# Round the inner (M/N) extent up so that it is aligned to 16 bytes for TMA
|
||||
tma_alignment = 16 / global_address.element_size()
|
||||
shape_m = (shape_m + tma_alignment - 1) // tma_alignment * tma_alignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_m, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_m,
|
||||
block_m, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
def make_2d_tma_scales_b_desc(gemm_type: GemmType, global_address: torch.Tensor, shape_n: int, shape_k: int, block_n: int, block_k: int, num_groups: int = 1) -> cbd.CUtensorMap:
|
||||
# Make TMA aligned to 16 bytes
|
||||
tma_alignment = 16 / global_address.element_size()
|
||||
shape_n = (shape_n + tma_alignment - 1) // tma_alignment * tma_alignment
|
||||
|
||||
return make_2d_tma_desc(global_address, Layout.ColMajor,
|
||||
shape_n, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_n,
|
||||
block_n, 1, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
|
||||
return make_2d_tma_desc(t,
|
||||
shape_mn, (shape_k + block_k - 1) // block_k * (num_groups if gemm_type == GemmType.GroupedMasked else 1), shape_mn,
|
||||
block_mn, 1,
|
||||
cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
|
||||
|
||||
|
||||
class FP8GemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'M',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'SCALES_B',
|
||||
'GROUPED_LAYOUT',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
super().__init__(path)
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
@@ -213,21 +179,16 @@ static void __instantiate_kernel() {{
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_m: int,
|
||||
block_m: int, gmem_d: torch.Tensor, scales_b: torch.Tensor,
|
||||
grouped_layout: torch.Tensor, num_sms: int, smem_size: int,
|
||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_d: cbd.CUtensorMap,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
@@ -237,23 +198,23 @@ static void __instantiate_kernel() {{
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimX = kwargs['NUM_SMS']
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||
config.hStream = kwargs['STREAM']
|
||||
|
||||
arg_values = (
|
||||
scales_b.data_ptr(),
|
||||
grouped_layout.data_ptr(),
|
||||
shape_m,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_d,
|
||||
kwargs['SCALES_B'].data_ptr(),
|
||||
kwargs['GROUPED_LAYOUT'].data_ptr(),
|
||||
kwargs['M'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_void_p,
|
||||
@@ -269,20 +230,7 @@ static void __instantiate_kernel() {{
|
||||
|
||||
class FP8WGradGemmRuntime(Runtime):
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__(path, [
|
||||
'NUM_TMA_MULTICAST',
|
||||
'K',
|
||||
'BLOCK_M',
|
||||
'GMEM_D',
|
||||
'NUM_SMS',
|
||||
'SMEM_SIZE',
|
||||
'TENSOR_MAP_A',
|
||||
'TENSOR_MAP_B',
|
||||
'TENSOR_MAP_SCALES_A',
|
||||
'TENSOR_MAP_SCALES_B',
|
||||
'TENSOR_MAP_D',
|
||||
'STREAM',
|
||||
])
|
||||
super().__init__(path)
|
||||
|
||||
@staticmethod
|
||||
def generate(**kwargs) -> str:
|
||||
@@ -309,7 +257,7 @@ static void __instantiate_kernel() {{
|
||||
{kwargs['BLOCK_N']},
|
||||
{kwargs['BLOCK_K']},
|
||||
{kwargs['NUM_STAGES']},
|
||||
{kwargs['LAST_STAGES']},
|
||||
{kwargs['NUM_LAST_STAGES']},
|
||||
{kwargs['NUM_TMA_THREADS']},
|
||||
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
|
||||
{kwargs['NUM_TMA_MULTICAST']},
|
||||
@@ -323,21 +271,16 @@ static void __instantiate_kernel() {{
|
||||
|
||||
# noinspection PyMethodOverriding
|
||||
@staticmethod
|
||||
def launch(kernel: cbd.CUkernel, num_tma_multicast: int, shape_k: int,
|
||||
block_m: int, gmem_d: torch.Tensor, num_sms: int, smem_size: int,
|
||||
tensor_map_a: cbd.CUtensorMap, tensor_map_b: cbd.CUtensorMap,
|
||||
tensor_map_scales_a: cbd.CUtensorMap, tensor_map_scales_b: cbd.CUtensorMap,
|
||||
tensor_map_d: cbd.CUtensorMap,
|
||||
stream: cbd.CUstream) -> cbd.CUresult:
|
||||
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
res = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size, kernel, cbd.CUdevice(gmem_d.device.index))[0]
|
||||
if res != cbd.CUresult.CUDA_SUCCESS:
|
||||
raise Exception(f'Failed to set max dynamic shared memory size: {res}')
|
||||
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
||||
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
|
||||
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
|
||||
|
||||
attr_val = cbd.CUlaunchAttributeValue()
|
||||
attr_val.clusterDim.x = num_tma_multicast
|
||||
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
|
||||
attr_val.clusterDim.y = 1
|
||||
attr_val.clusterDim.z = 1
|
||||
attr = cbd.CUlaunchAttribute()
|
||||
@@ -347,22 +290,22 @@ static void __instantiate_kernel() {{
|
||||
config = cbd.CUlaunchConfig()
|
||||
config.numAttrs = 1
|
||||
config.attrs = [attr]
|
||||
config.gridDimX = num_sms
|
||||
config.gridDimX = kwargs['NUM_SMS']
|
||||
config.gridDimY = 1
|
||||
config.gridDimZ = 1
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, block_m)
|
||||
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
|
||||
config.blockDimY = 1
|
||||
config.blockDimZ = 1
|
||||
config.sharedMemBytes = smem_size
|
||||
config.hStream = stream
|
||||
config.sharedMemBytes = kwargs['SMEM_SIZE']
|
||||
config.hStream = kwargs['STREAM']
|
||||
|
||||
arg_values = (
|
||||
shape_k,
|
||||
tensor_map_a,
|
||||
tensor_map_b,
|
||||
tensor_map_scales_a,
|
||||
tensor_map_scales_b,
|
||||
tensor_map_d,
|
||||
kwargs['K'],
|
||||
kwargs['TENSOR_MAP_A'],
|
||||
kwargs['TENSOR_MAP_B'],
|
||||
kwargs['TENSOR_MAP_SCALES_A'],
|
||||
kwargs['TENSOR_MAP_SCALES_B'],
|
||||
kwargs['TENSOR_MAP_D'],
|
||||
)
|
||||
arg_types = (
|
||||
ctypes.c_uint32,
|
||||
|
||||
Reference in New Issue
Block a user