Fix grouped GEMM cases

Chenggang Zhao 2025-03-25 16:41:44 +08:00
parent 7768319ffe
commit 046fab64b7


@@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};

 // Make a templated grouped GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;

 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
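For context, the new {IS_TMA_MULTICAST_ON_A} placeholder is a compile-time boolean that the JIT tuner fills into the generated C++ alongside the existing integer keys. The sketch below is a hypothetical illustration only (helper names invented here; it assumes simple name-based string substitution with Python bools rendered as C++ true/false, which is not necessarily how DeepGEMM's jit_tuner works internally):

# Hypothetical sketch of rendering boolean tuning keys into the C++ template.
CODE_TEMPLATE = (
    "constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};\n"
    "constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};\n"
)

def render(template: str, keys: dict) -> str:
    # Convert Python values into C++ literals before formatting the template:
    # ints pass through verbatim, bools become `true`/`false`.
    cpp_keys = {k: (('true' if v else 'false') if isinstance(v, bool) else v)
                for k, v in keys.items()}
    return template.format(**cpp_keys)

print(render(CODE_TEMPLATE, {'NUM_TMA_MULTICAST': 2, 'IS_TMA_MULTICAST_ON_A': True}))
# constexpr auto kNumTMAMulticast = 2;
# constexpr auto kIsTMAMulticastOnA = true;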
@@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
-                                                                                            is_grouped_contiguous=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
             torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedContiguous'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
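The Python-side change mirrors the template change: get_best_configs now returns the multicast setting as a pair, (number of multicast CTAs, whether multicast is applied to operand A), instead of a bare integer, and the two elements feed the NUM_TMA_MULTICAST and IS_TMA_MULTICAST_ON_A specialization keys. A minimal sketch of that convention (the helper name below is illustrative, not part of the library):

from typing import Dict, Tuple

def unpack_multicast_config(tma_multicast_config: Tuple[int, bool]) -> Dict[str, object]:
    # Split the (count, on-A flag) pair into the two JIT specialization keys
    # used in the compile_and_tune() calls above.
    num_tma_multicast, is_tma_multicast_on_a = tma_multicast_config
    return {
        'NUM_TMA_MULTICAST': num_tma_multicast,
        'IS_TMA_MULTICAST_ON_A': is_tma_multicast_on_a,
    }

assert unpack_multicast_config((2, True)) == {
    'NUM_TMA_MULTICAST': 2, 'IS_TMA_MULTICAST_ON_A': True}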
@@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)

     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -170,7 +173,10 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedMasked'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
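How the tuner decides between multicasting A or B is outside this diff. Purely as an illustration of the kind of choice the (num_tma_multicast, is_tma_multicast_on_a) pair can encode, here is a hypothetical chooser; the function name and heuristic are assumptions, not DeepGEMM's actual get_best_configs logic:

from typing import Tuple

def pick_tma_multicast_config(num_m_blocks: int, num_n_blocks: int) -> Tuple[int, bool]:
    # Illustrative heuristic only. Multicasting a tile helps when at least two
    # CTAs in a cluster can share it.
    if num_m_blocks < 2 and num_n_blocks < 2:
        return 1, True  # no multicast; the on-A flag is then irrelevant
    # Broadcast the operand with more reuse across neighboring CTAs:
    # A tiles are shared by CTAs along N, B tiles by CTAs along M.
    is_on_a = num_n_blocks >= num_m_blocks
    return 2, is_on_a

print(pick_tma_multicast_config(4, 16))  # (2, True): multicast A across N blocks
print(pick_tma_multicast_config(32, 3))  # (2, False): multicast B across M blocks

The intuition behind such a split: CTAs that share an M block but differ in N block load the same A tile, so broadcasting A pays off when the launch grid is wide along N, and symmetrically for B along M.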