diff --git a/tests/test_core.py b/tests/test_core.py index b430903..a227c3a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -123,7 +123,7 @@ def test_m_grouped_gemm_masked() -> None: masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): masked_m[j] = random.choice(masked_m_candidates) - expected_m = int(masked_m.float().mean()) + 1 + expected_m = min(int(masked_m.float().mean()) + 1, m) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m) for j in range(num_groups): diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])