Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-06-26 23:15:49 +00:00

Commit 6078b25424: Add swizzling params
Parent: 2e7e58011b
@@ -395,6 +395,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
 template <uint32_t SHAPE_N, uint32_t SHAPE_K,
           uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
           uint32_t BLOCK_N_PADDING,
+          uint32_t kSwizzleDMode,
           uint32_t kNumGroups, uint32_t kNumStages,
           uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
           GemmType kGemmType>
@@ -17,12 +17,14 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto BLOCK_K = 128;
 constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
+constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
+constexpr auto kNumGroups = 1;
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};

 // Make a templated GEMM
-using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, 1, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;
+using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::Normal>;

 // Launch kernel
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
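The `{BLOCK_N}`, `{SWIZZLE_D_MODE}` and similar tokens in the generated C++ above are placeholders filled in from the `keys` dict passed to the JIT tuner later in this diff. Below is only an illustration of that idea, assuming plain textual substitution; `render_template` and the sample values are hypothetical and not part of DeepGEMM.

    # Illustration only: filling {NAME} placeholders in generated source from a keys dict.
    # The real substitution lives in DeepGEMM's JIT tuner; this is a minimal sketch.
    cpp_template = (
        "constexpr auto BLOCK_N = {BLOCK_N};\n"
        "constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};\n"
        "constexpr auto kNumStages = {NUM_STAGES};\n"
    )

    def render_template(template: str, keys: dict) -> str:
        # Replace each {NAME} token individually; per-key str.replace avoids fighting
        # with the many other braces that appear in real kernel code.
        for name, value in keys.items():
            template = template.replace('{' + name + '}', str(value))
        return template

    print(render_template(cpp_template, {'BLOCK_N': 128, 'SWIZZLE_D_MODE': 128, 'NUM_STAGES': 5}))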
@@ -41,15 +43,28 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in
     return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0


+def get_swizzle_mode(block_n: int) -> int:
+    # TODO: remove some candidates if slow
+    elem_size = 2
+    for mode_bytes in (128, 64, 32):
+        if (block_n * elem_size) % mode_bytes == 0:
+            return mode_bytes
+    return 0
+
+
 def get_block_n_padding_for_smem_d(block_n: int) -> int:
+    # NOTES: padding is for solving bank conflicts, but wastes shared memory space
     elem_size, requirement = 2, (4, 8)
     bank_stride = (block_n * elem_size) // 4
     padding = (requirement[0] - bank_stride) % requirement[1]
     return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size


-def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> int:
-    block_n_padding = get_block_n_padding_for_smem_d(block_n)
+def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128) -> Tuple[int, int, int]:
+    # Try swizzle first, as it does not waste shared memory
+    swizzle_mode = get_swizzle_mode(block_n)
+    block_n_padding = get_block_n_padding_for_smem_d(block_n) if swizzle_mode == 0 else 0
+
     smem_d = block_m * (block_n + block_n_padding) * 2
     smem_a_per_stage = block_m * block_k
     smem_scales_a_per_stage = block_m * 4
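For a concrete feel of the two helpers introduced above, here is a self-contained check (the functions are copied from the diff) over a few block_n values. 128 and 112 are representative block sizes; 100 is a made-up value included only to show the fallback case where padding, not swizzling, is applied.

    # Self-contained check of the helpers added above, with illustrative block_n values.
    def get_swizzle_mode(block_n: int) -> int:
        elem_size = 2
        for mode_bytes in (128, 64, 32):
            if (block_n * elem_size) % mode_bytes == 0:
                return mode_bytes
        return 0

    def get_block_n_padding_for_smem_d(block_n: int) -> int:
        elem_size, requirement = 2, (4, 8)
        bank_stride = (block_n * elem_size) // 4
        padding = (requirement[0] - bank_stride) % requirement[1]
        return (((padding + requirement[1]) if padding < 0 else padding) * 4) // elem_size

    for block_n in (128, 112, 100):
        swizzle = get_swizzle_mode(block_n)
        # Mirror get_smem_config: padding is only considered when no swizzle mode fits
        padding = get_block_n_padding_for_smem_d(block_n) if swizzle == 0 else 0
        print(block_n, swizzle, padding)
        # Expected: 128 -> (128, 0), 112 -> (32, 0), 100 -> (0, 4)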
@@ -64,13 +79,17 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_size += num_stages * smem_b_per_stage
     smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8
     smem_size += smem_barrier
-    return smem_size
+
+    # Swizzle and padding are not compatible
+    assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1
+
+    return smem_size, swizzle_mode, block_n_padding


 @lru_cache(maxsize=None)
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                      is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
-        Tuple[int, int, int, int, Tuple[int, bool], int]:
+        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
     if not is_grouped_contiguous:
         block_ms = (64, 128, 256)
     else:
@@ -109,16 +128,17 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,

     # Always pick the longest one
     # NOTES: for double B scales, the best number of stages may be reduced
-    best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
+    best_num_stages, best_smem_config, sm90_capacity = None, None, 232448
     stage_candidates = (8, 7, 6, 5, 4, 3)
     if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
         # Unrolling both stages and `num_former_iters` will cause large code size
         stage_candidates = (4, 3)
     for num_stages in stage_candidates:
-        best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
-        if best_smem_size <= sm90_capacity:
+        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
+        if best_smem_config[0] <= sm90_capacity:
             best_num_stages = num_stages
             break
+    assert best_smem_config is not None
     assert best_num_stages is not None

     # Decide the number of TMA multicast and whether broadcast on A
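The loop above keeps the first (largest) stage count whose shared-memory footprint fits the 232448-byte SM90 budget. Below is a simplified, runnable sketch of that selection; `rough_smem_bytes` is a stand-in estimate, not the real `get_smem_config`, which also accounts for B scales, barriers, and swizzle/padding.

    # Simplified sketch of "pick the largest num_stages that fits in SM90 shared memory".
    SM90_SMEM_CAPACITY = 232448

    def rough_smem_bytes(num_stages: int, block_m: int = 128, block_n: int = 128, block_k: int = 128) -> int:
        smem_d = block_m * block_n * 2          # BF16 output tile
        smem_a = block_m * block_k              # FP8 A tile per stage
        smem_b = block_n * block_k              # FP8 B tile per stage
        smem_scales_a = block_m * 4             # FP32 A scales per stage
        return smem_d + num_stages * (smem_a + smem_b + smem_scales_a)

    best_num_stages = None
    for num_stages in (8, 7, 6, 5, 4, 3):
        if rough_smem_bytes(num_stages) <= SM90_SMEM_CAPACITY:
            best_num_stages = num_stages
            break
    assert best_num_stages is not None
    print(best_num_stages)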
@@ -142,7 +162,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
     assert num_min_sms <= num_sms

-    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_size
+    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config


 def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
@@ -192,12 +212,13 @@ def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms)
-    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'SWIZZLE_D_MODE': smem_config[1],
+              'BLOCK_N_PADDING': smem_config[2],
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
               'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1]},
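To keep the indices above straight: `smem_config[0]` is the dynamic shared-memory size passed through `args` at launch, `smem_config[1]` feeds `{SWIZZLE_D_MODE}`, and `smem_config[2]` feeds `{BLOCK_N_PADDING}`. A tiny sketch with made-up numbers:

    # Made-up smem_config triple, wired the same way as in the call above
    smem_config = (232448, 128, 0)       # (smem_size, swizzle_mode, block_n_padding)

    launch_smem_size = smem_config[0]    # dynamic shared memory for the kernel launch
    codegen_keys = {
        'SWIZZLE_D_MODE': smem_config[1],    # becomes kSwizzleDMode in the generated C++
        'BLOCK_N_PADDING': smem_config[2],   # becomes BLOCK_N_PADDING in the generated C++
    }
    print(launch_smem_size, codegen_keys)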
@@ -16,13 +16,14 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto BLOCK_K = 128;
 constexpr auto BLOCK_N_PADDING = {BLOCK_N_PADDING};
+constexpr auto kSwizzleDMode = {SWIZZLE_D_MODE};
 constexpr auto kNumGroups = {NUM_GROUPS};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
 constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};

 // Make a templated grouped GEMM
-using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
+using gemm_t = Gemm<N, K, BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_N_PADDING, kSwizzleDMode, kNumGroups, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;

 // Launch kernel
 auto tma_a_desc = gemm_t::make_2d_tma_a_desc(lhs, m);
@@ -87,14 +88,15 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'SWIZZLE_D_MODE': smem_config[1],
+              'BLOCK_N_PADDING': smem_config[2],
               'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],
@@ -165,7 +167,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)

     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -173,11 +175,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]

     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             masked_m, m,
-            torch.cuda.current_stream(), num_sms, smem_size)
+            torch.cuda.current_stream(), num_sms, smem_config[0])
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
-              'BLOCK_N_PADDING': get_block_n_padding_for_smem_d(block_n),
+              'SWIZZLE_D_MODE': smem_config[1],
+              'BLOCK_N_PADDING': smem_config[2],
               'NUM_GROUPS': num_groups,
               'NUM_STAGES': num_stages,
               'NUM_TMA_MULTICAST': tma_multicast_config[0],