mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-04 05:46:13 +00:00
fix: prevent expected_m from exceeding m in test_core
This commit is contained in:
parent
eec7ab7f03
commit
bc989405fe
@@ -123,7 +123,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = random.choice(masked_m_candidates)
|
||||
- expected_m = int(masked_m.float().mean()) + 1
|
||||
+ expected_m = min(int(masked_m.float().mean()) + 1, m)
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
|
||||
|
Loading…
Reference in New Issue
Block a user