update readme

chenhongmin.will 2025-02-28 22:18:04 +08:00
parent c7143a7bda
commit 9887a5501e
2 changed files with 14 additions and 13 deletions


@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 Currently released:
-- BF16, FP16
+- BF16, FP16, E4M3
 - Paged kvcache with block size of 64
 ## Quick start
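For context on what the Quick start section points to, here is a minimal sketch of the decode-time call pattern, assuming the flash_mla Python bindings (get_mla_metadata, flash_mla_with_kvcache) and the tensor shapes used in the test changes below; treat the exact signatures as indicative rather than authoritative:

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Shapes mirroring the test below: d = 576 (512 latent + 64 rope dims), dv = 512.
b, s_q, h_q, h_kv, d, dv, block_size = 128, 1, 128, 1, 576, 512, 64
max_seqlen_pad = 4096  # padded cache length, a multiple of block_size
cache_seqlens = torch.full((b,), 4096, dtype=torch.int32, device="cuda")

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
blocked_k = torch.randn(b * max_seqlen_pad // block_size, block_size, h_kv, d,
                        dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32,
                           device="cuda").view(b, max_seqlen_pad // block_size)

# Scheduling metadata is computed once per decoding step and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv
)

out, lse = flash_mla_with_kvcache(
    q, blocked_k, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)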


@@ -42,11 +42,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
     print(
-        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}"
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
     )
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -73,8 +74,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         cache_seqlens, s_q * h_q // h_kv, h_kv
     )
     init_dtype = q.dtype
     def prepare_fp8_input():
-        q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
         if use_fp8:
             nonlocal q, blocked_k, blocked_v
@@ -89,8 +91,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
-    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
     if use_fp8:
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
         q = q_fp8
         blocked_k = blocked_k_fp8
         blocked_v = blocked_v_fp8
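prepare_fp8_input is only partially visible in this diff. As a hedged illustration of what a per-tensor E4M3 quantization helper of this kind typically does (a sketch, not the test's actual implementation), it casts q and the blocked KV cache down to torch.float8_e4m3fn and keeps descale factors so the reference path can recover the original magnitudes:

import torch

def quantize_e4m3(x: torch.Tensor):
    # Per-tensor E4M3 quantization sketch: scale into the representable range
    # (|x| <= finfo.max, about 448), cast down, and return the inverse scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax
    x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    descale = 1.0 / scale  # applied later as (x_fp8.float() * descale).to(init_dtype)
    return x_fp8, descale

Whether the real helper uses per-tensor or finer-grained scales is not visible here; the ref_mla hunk below only shows that a single descale_q / descale_k pair is applied.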
@@ -110,10 +113,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         )
     def ref_mla():
-        if use_fp8:
-            q_ = (q.to(torch.float) * descale_q).to(init_dtype)
-            blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype)
-            blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype)
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
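Note that blocked_v_ is dequantized with descale_k rather than a separate scale. That is consistent with the usual MLA cache layout, where the value tensor is a view into the first dv channels of the latent K cache and therefore shares its quantization scale, i.e. something along the lines of:

# Assumed relationship, not shown in this diff: V is a slice of the latent KV cache,
# so it inherits the per-tensor descale factor of blocked_k.
blocked_v = blocked_k[..., :dv]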
@@ -158,14 +160,13 @@ def main(torch_dtype):
     h_kv = 1
     d, dv = 576, 512
     causal = False
-    use_fp8 = torch_dtype == torch.float8_e4m3fn
     for b in [128]:
         for s in [4096, 8192]:
             for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
                 for s_q in [1, 2]: # MTP = 1, 2
                     for varlen in [False, True]:
-                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)
 if __name__ == "__main__":
@@ -183,7 +184,7 @@ if __name__ == "__main__":
     torch_dtype = torch.bfloat16
     if args.dtype == "fp16":
         torch_dtype = torch.float16
-    elif args.dtype = "e4m3":
-        torch.dtype = torch.float8_e4m3fn
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
     main(torch_dtype)
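Taken together, the test now takes a single torch_dtype argument, derives use_fp8 from it internally, and the __main__ block maps the --dtype flag (name inferred from args.dtype above) onto torch dtypes. The fixed if/elif chain is equivalent to a small lookup table; a self-contained sketch, with dtype_from_flag as a hypothetical helper name rather than anything in the repo:

import torch

def dtype_from_flag(name: str) -> torch.dtype:
    # "bf16" mirrors the script's default of torch.bfloat16 shown above.
    return {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "e4m3": torch.float8_e4m3fn,
    }[name]

assert dtype_from_flag("e4m3") == torch.float8_e4m3fn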