mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
Fix benchmark script
This commit is contained in:
parent
b31bfe72a8
commit
063ffa8ec1
@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
|
||||
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
|
||||
|
||||
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
|
||||
if target not in ["flash_infer", "flash_mla_triton"]:
|
||||
if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]:
|
||||
# flash_infer has a different lse return value
|
||||
# flash_mla_triton doesn't return lse
|
||||
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
|
||||
|
Loading…
Reference in New Issue
Block a user