Add extra TMA checks

This commit is contained in:
Chenggang Zhao 2025-02-27 18:20:57 +08:00
parent ca13ce0fab
commit 6da94d2d36

View File

@ -160,6 +160,11 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
global includes, template
num_sms = get_num_sms()
block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
# Extra checks for TMA store
if num_groups > 1 and m > block_m:
assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'
args = (lhs, lhs_scales, rhs, rhs_scales, out,
masked_m, m,
torch.cuda.current_stream(), num_sms, smem_size)