mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Fix TMA multicast bugs
This commit is contained in:
parent
bff5724ded
commit
b4ecf9c3ff
@ -70,7 +70,8 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
|
||||
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
|
||||
Tuple[int, int, int, int, Tuple[int, bool], int]:
|
||||
if not is_grouped_contiguous:
|
||||
# TODO: for some cases, smaller M block is better, add them into tuning space
|
||||
block_ms = (64 if m <= 64 else 128, )
|
||||
@ -118,12 +119,13 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
best_tma_multicast_config = (1, True)
|
||||
|
||||
# Try to multicast on the larger block side first
|
||||
is_dense_gemm = (not is_grouped_contiguous) and (not is_grouped_masked)
|
||||
is_multicast_legal = {
|
||||
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and is_dense_gemm,
|
||||
}
|
||||
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
|
||||
if m >= 512 and is_multicast_legal[i] and num_groups == 1:
|
||||
if m >= 512 and is_multicast_legal[i]:
|
||||
best_tma_multicast_config = (2, i == 'A')
|
||||
break
|
||||
|
||||
|
@ -166,7 +166,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor]
|
||||
# Auto-tuning with compilation
|
||||
global includes, template
|
||||
num_sms = get_num_sms()
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms)
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_size = get_best_configs(expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
|
||||
|
||||
# Extra checks for TMA store
|
||||
if num_groups > 1 and m > block_m:
|
||||
|
Loading…
Reference in New Issue
Block a user