From 6da94d2d366b022d5251cee28b4422ecf4d95834 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Thu, 27 Feb 2025 18:20:57 +0800
Subject: [PATCH] Add extra TMA checks

---
 deep_gemm/jit_kernels/m_grouped_gemm.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 6d6e39b..28f838a 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -160,6 +160,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     global includes, template
     num_sms = get_num_sms()
     block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+
+    # Extra checks for TMA store
+    if num_groups > 1 and m > block_m:
+        assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
+
     args = (lhs, lhs_scales, rhs, rhs_scales, out, masked_m, m,
             torch.cuda.current_stream(), num_sms, smem_size)
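
For context, the added guard runs only when several groups share one output tensor (num_groups > 1) and each group's M extent spans more than one block (m > block_m); presumably a TMA store always writes a full block_m-tall tile, so an unaligned M could let a store cross into the next group's rows. Below is a minimal, self-contained sketch of the same precondition, not the library's API: check_tma_store_alignment is a hypothetical helper, and block_m stands in for the value that get_best_configs() selects at runtime.

    def check_tma_store_alignment(m: int, block_m: int, num_groups: int) -> None:
        # Mirror the patch's precondition: only constrain M when multiple
        # groups are stacked along M and M covers more than one block.
        if num_groups > 1 and m > block_m:
            assert m % block_m == 0, (
                f'For masked grouped GEMM, shape M should be multiple of the '
                f'block M (current block M: {block_m})'
            )

    # Assuming a block M of 64: an aligned M passes, an unaligned M trips the assert.
    check_tma_store_alignment(m=256, block_m=64, num_groups=4)  # OK: 256 % 64 == 0
    try:
        check_tma_store_alignment(m=200, block_m=64, num_groups=4)
    except AssertionError as e:
        print(e)  # 200 % 64 != 0, so the check rejects this shape

On the caller side, the natural way to satisfy the check is to pad each group's row count up to a multiple of the block M before invoking the kernel.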