From 90289837fc2d445bb94ec58d9cd2bdf17a202648 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Sat, 1 Mar 2025 02:14:42 +0800 Subject: [PATCH] update ut --- tests/test_flash_mla.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 010bda3..0cd173c 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -140,9 +140,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtyp t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) print( f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" )