diff --git a/benchmark/bench_flash_mla.py b/benchmark/bench_flash_mla.py index 14e1352..95c75f2 100644 --- a/benchmark/bench_flash_mla.py +++ b/benchmark/bench_flash_mla.py @@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" - if target not in ["flash_infer", "flash_mla_triton"]: + if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]: # flash_infer has a different lse return value # flash_mla_triton doesn't return lse torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"