Support TMA multicast on B with m_grouped_gemm_contiguous. (#88)

This commit is contained in:
yukuai26
2025-04-21 09:43:17 +08:00
committed by GitHub
parent 83aa960b9b
commit 891f35adf5
5 changed files with 74 additions and 31 deletions

View File

@@ -146,10 +146,9 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked)
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
- 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm,
+ 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i]: