From b4ecf9c3ffbb571b691f24a3b7661edd475565b1 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 7 Apr 2025 14:34:42 +0800
Subject: [PATCH] Fix TMA multicast bugs

---
 deep_gemm/jit_kernels/gemm.py           | 8 +++++---
 deep_gemm/jit_kernels/m_grouped_gemm.py | 2 +-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index cbd6fc2..3f031c1 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -70,7 +70,8 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
 
 @lru_cache(maxsize=None)
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
-                     is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
+                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
+        Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
         # TODO: for some cases, smaller M block is better, add them into tuning space
         block_ms = (64 if m <= 64 else 128, )
@@ -118,12 +119,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     best_tma_multicast_config = (1, True)
 
     # Try to multicast on the larger block side first
+    is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked)
     is_multicast_legal = {
         'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
-        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm,
     }
     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
-        if m >= 512 and is_multicast_legal[i] and num_groups == 1:
+        if m >= 512 and is_multicast_legal[i]:
             best_tma_multicast_config = (2, i == 'A')
             break
 
diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py
index 908e9e8..253b6d5 100644
--- a/deep_gemm/jit_kernels/m_grouped_gemm.py
+++ b/deep_gemm/jit_kernels/m_grouped_gemm.py
@@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
     # Auto-tuning with compilation
     global includes, template
     num_sms = get_num_sms()
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
 
     # Extra checks for TMA store
     if num_groups > 1 and m > block_m: