fix some bug

This commit is contained in:
Wangzheee
2025-06-20 06:53:24 +00:00
parent d29b20cd16
commit 26a603f518
5 changed files with 6 additions and 29 deletions

View File

@@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
#block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))
block_ns = tuple(range(16, 129, 8))
# Avoid bank conflicts for FP32 output
if is_fp32_out:
block_ns = [x for x in block_ns if x % 16 == 8]

View File

@@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor]
else:
m_per_expert_threshold = 32 # H100
if expected_m>= m_per_expert_threshold:
if expected_m> m_per_expert_threshold:
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False)