Merge pull request #12 from sazczmh/main

tests: Triton 3.2.0 had remove the fast_flush parameter from do_bench
This commit is contained in:
Jiashi Li 2025-02-24 11:57:41 +08:00 committed by GitHub
commit accc1695ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,7 +87,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla, fast_flush=False)
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")