update readme

parent c7143a7bda
commit 9887a5501e

@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16, FP16
+- BF16, FP16, E4M3
 - Paged kvcache with block size of 64
 
 ## Quick start
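
The new E4M3 entry corresponds to the FP8 path exercised by the test changes below: tensors are cast to torch.float8_e4m3fn and passed together with per-tensor descale factors. A minimal sketch of that round trip, assuming a max-based per-tensor scale; quantize_e4m3 is an illustrative helper, not part of the repo, and the diff itself does not show how the scale is chosen:

```python
import torch

def quantize_e4m3(x: torch.Tensor):
    # Illustrative per-tensor scaling; only the descale-on-dequantize step
    # below is what the updated test actually relies on.
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    descale = (1.0 / scale).reshape(1)  # multiply back to recover magnitudes
    return x_fp8, descale

q = torch.randn(2, 16, 576, dtype=torch.bfloat16)
q_fp8, descale_q = quantize_e4m3(q)
q_deq = (q_fp8.to(torch.float) * descale_q).to(torch.bfloat16)  # same pattern as ref_mla below
```
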
@@ -42,11 +42,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 
 
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
     print(
-        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}"
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
     )
 
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
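
The test now takes the requested dtype itself rather than a use_fp8 boolean, and derives the flag from it. A tiny sketch of that dispatch; the loop over dtypes is only for illustration:

```python
import torch

# Mirror of the added line: FP8 handling keys off the requested dtype.
for torch_dtype in (torch.bfloat16, torch.float16, torch.float8_e4m3fn):
    use_fp8 = torch_dtype == torch.float8_e4m3fn
    print(f"{torch_dtype=} {use_fp8=}")
```
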
@@ -73,8 +74,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
         cache_seqlens, s_q * h_q // h_kv, h_kv
     )
 
+    init_dtype = q.dtype
     def prepare_fp8_input():
-        q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
 
         if use_fp8:
             nonlocal q, blocked_k, blocked_v
@@ -89,8 +91,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
         return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
 
 
-    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
     if use_fp8:
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
         q = q_fp8
         blocked_k = blocked_k_fp8
         blocked_v = blocked_v_fp8
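
Two related fixes: blocked_v_fp8 now gets a None default alongside the other outputs (previous hunk), and prepare_fp8_input() is only called on the FP8 path. A self-contained sketch of the gated pattern; the conversion inside the helper is not shown in the diff, so the direct casts, unit descale factors, and small shapes here are assumptions:

```python
import torch

use_fp8 = True
q = torch.randn(4, 8, dtype=torch.bfloat16)
blocked_k = torch.randn(4, 8, dtype=torch.bfloat16)
blocked_v = torch.randn(4, 8, dtype=torch.bfloat16)

def prepare_fp8_input():
    # Every output gets a default, including blocked_v_fp8, so nothing is unbound.
    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
    if use_fp8:
        descale_q = torch.ones(1, dtype=torch.float32)  # placeholder scales
        descale_k = torch.ones(1, dtype=torch.float32)
        q_fp8 = q.to(torch.float8_e4m3fn)
        blocked_k_fp8 = blocked_k.to(torch.float8_e4m3fn)
        blocked_v_fp8 = blocked_v.to(torch.float8_e4m3fn)
    return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k

if use_fp8:  # the call is now guarded, as in the hunk above
    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
    q, blocked_k, blocked_v = q_fp8, blocked_k_fp8, blocked_v_fp8
```
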
@@ -110,10 +113,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
     )
 
     def ref_mla():
-        if use_fp8:
-            q_ = (q.to(torch.float) * descale_q).to(init_dtype)
-            blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype)
-            blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype)
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
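
ref_mla now dequantizes inline with conditional expressions instead of a separate if block; note that both blocked_k and blocked_v reuse descale_k. The dequantize step in isolation, with dummy tensors and scales:

```python
import torch

init_dtype = torch.bfloat16
use_fp8 = True
descale_q = torch.tensor([0.5], dtype=torch.float32)   # dummy scales
descale_k = torch.tensor([0.25], dtype=torch.float32)
q = torch.randn(4, 8).to(torch.float8_e4m3fn)
blocked_k = torch.randn(4, 8).to(torch.float8_e4m3fn)
blocked_v = torch.randn(4, 8).to(torch.float8_e4m3fn)

# Same expressions as the new ref_mla lines: upcast, descale, cast back.
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
```
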
@@ -158,14 +160,13 @@ def main(torch_dtype):
     h_kv = 1
     d, dv = 576, 512
     causal = False
-    use_fp8 = torch_dtype == torch.float8_e4m3fn
 
     for b in [128]:
         for s in [4096, 8192]:
             for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
                 for s_q in [1, 2]: # MTP = 1, 2
                     for varlen in [False, True]:
-                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)
 
 
 if __name__ == "__main__":
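
The removed use_fp8 line now lives inside test_flash_mla, so main() simply forwards torch_dtype. The h_q values in the sweep correspond to the TP degrees in the comment, assuming 128 query heads in total (the 128 figure comes from that comment, not from this diff):

```python
# Sanity check of the "TP = 8, 4, 2, 1" comment, assuming 128 query heads total.
total_h_q = 128
for tp in (8, 4, 2, 1):
    print(f"TP={tp}: h_q per device = {total_h_q // tp}")
# -> 16, 32, 64, 128, matching the values in the h_q loop
```
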
@@ -183,7 +184,7 @@ if __name__ == "__main__":
     torch_dtype = torch.bfloat16
     if args.dtype == "fp16":
         torch_dtype = torch.float16
-    elif args.dtype = "e4m3":
-        torch.dtype = torch.float8_e4m3fn
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
 
     main(torch_dtype)
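
The old branch would not even parse (a single = in the elif condition is a SyntaxError) and assigned to torch.dtype rather than torch_dtype; the fix makes the e4m3 selection usable. Assuming the argparse option behind args.dtype is named --dtype, the resulting mapping is:

```python
import torch

# Assumed CLI mapping; the flag name --dtype is inferred from args.dtype above.
dtype_by_flag = {"bf16": torch.bfloat16, "fp16": torch.float16, "e4m3": torch.float8_e4m3fn}
torch_dtype = dtype_by_flag["e4m3"]  # now reachable after the fix
print(torch_dtype)                   # torch.float8_e4m3fn
```
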