mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
update ut
This commit is contained in:
parent
9887a5501e
commit
90289837fc
@@ -140,9 +140,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtyp
|
|||||||
|
|
||||||
t = triton.testing.do_bench(flash_mla)
|
t = triton.testing.do_bench(flash_mla)
|
||||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
|
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||||
torch.finfo(q.dtype).bits // 8
|
|
||||||
)
|
|
||||||
print(
|
print(
|
||||||
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
|
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user