fix: prevent expected_m from exceeding m in test_core

This commit is contained in:
xuzhean 2025-02-26 16:55:47 +08:00
parent eec7ab7f03
commit bc989405fe

View File

@ -123,7 +123,7 @@ def test_m_grouped_gemm_masked() -> None:
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = random.choice(masked_m_candidates)
expected_m = int(masked_m.float().mean()) + 1
expected_m = min(int(masked_m.float().mean()) + 1, m)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
for j in range(num_groups):
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])