From 046fab64b775966cd027f9aca774484c2241ba99 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Tue, 25 Mar 2025 16:41:44 +0800
Subject: [PATCH] Fix grouped GEMM cases

---
 deep_gemm/jit_kernels/m_grouped_gemm.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 415fc67..bffe137 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -16,9 +16,10 @@ constexpr auto BLOCK_M = {BLOCK_M};
 constexpr auto BLOCK_N = {BLOCK_N};
 constexpr auto kNumStages = {NUM_STAGES};
 constexpr auto kNumTMAMulticast = {NUM_TMA_MULTICAST};
+constexpr auto kIsTMAMulticastOnA = {IS_TMA_MULTICAST_ON_A};
 
 // Make a templated grouped GEMM
-using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, GemmType::{GEMM_TYPE}>;
+using GemmType = Gemm<N, K, BLOCK_M, BLOCK_N, 128, {NUM_GROUPS}, kNumStages, kNumTMAMulticast, kIsTMAMulticastOnA, GemmType::{GEMM_TYPE}>;
 
 // Launch kernel
 auto tma_a_desc = GemmType::make_2d_tma_a_desc(lhs, m);
@@ -84,15 +85,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Ten
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms,
-                                                                                            is_grouped_contiguous=True)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(m, n, k, 1, num_sms, is_grouped_contiguous=True)
     args = (lhs, lhs_scales, rhs, rhs_scales, out,
             m_indices, m, num_groups,
             torch.cuda.current_stream(), num_sms, smem_size)
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedContiguous'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedContiguous'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
@@ -158,7 +161,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m:
@@ -170,7 +173,10 @@
     runtime = jit_tuner.compile_and_tune(
         name='m_grouped_gemm_fp8_fp8_bf16_nt',
         keys={'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'NUM_GROUPS': num_groups,
-              'NUM_STAGES': num_stages, 'NUM_TMA_MULTICAST': num_tma_multicast, 'GEMM_TYPE': 'GroupedMasked'},
+              'NUM_STAGES': num_stages,
+              'NUM_TMA_MULTICAST': tma_multicast_config[0],
+              'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
+              'GEMM_TYPE': 'GroupedMasked'},
         space=(),
         includes=includes,
         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
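
Note on the change above (commentary, not part of the patch): get_best_configs() now reports the TMA multicast setting as a (num_multicast, is_multicast_on_a) pair instead of a single integer, and both grouped GEMM entry points forward the two halves to the JIT template as separate keys, NUM_TMA_MULTICAST and IS_TMA_MULTICAST_ON_A, which the C++ template consumes as kNumTMAMulticast and kIsTMAMulticastOnA. The sketch below only restates that key mapping; the helper name build_jit_keys and the sample tuning values are illustrative assumptions, not DeepGEMM API.

# Minimal sketch of the new key layout (assumed helper, for illustration only).
from typing import Dict, Tuple, Union


def build_jit_keys(n: int, k: int, block_m: int, block_n: int,
                   num_groups: int, num_stages: int,
                   tma_multicast_config: Tuple[int, bool],
                   gemm_type: str) -> Dict[str, Union[int, bool, str]]:
    # The multicast config is now a (num_multicast, is_multicast_on_a) pair.
    num_tma_multicast, is_tma_multicast_on_a = tma_multicast_config
    return {'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n,
            'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages,
            'NUM_TMA_MULTICAST': num_tma_multicast,          # -> kNumTMAMulticast
            'IS_TMA_MULTICAST_ON_A': is_tma_multicast_on_a,  # -> kIsTMAMulticastOnA
            'GEMM_TYPE': gemm_type}


if __name__ == '__main__':
    # Hypothetical tuned values; real ones come from get_best_configs().
    keys = build_jit_keys(n=4096, k=7168, block_m=64, block_n=128,
                          num_groups=8, num_stages=5,
                          tma_multicast_config=(2, True),
                          gemm_type='GroupedContiguous')
    print(keys['NUM_TMA_MULTICAST'], keys['IS_TMA_MULTICAST_ON_A'])  # 2 True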