Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00

Commit 9887a5501e ("update readme"), parent c7143a7bda.

The first hunk below updates the README feature list; the remaining hunks rework the FP8 (E4M3) handling in the test script.
@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16, FP16
+- BF16, FP16, E4M3
 - Paged kvcache with block size of 64
 
 ## Quick start
@@ -42,11 +42,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -
 
 
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
     print(
-        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}"
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
     )
 
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
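The signature change above replaces the use_fp8 boolean with a torch dtype, and the FP8 flag is then derived inside the test. A minimal, self-contained sketch of that derivation (illustrative only, not repository code; requires a PyTorch build that ships torch.float8_e4m3fn):

    import torch

    # FP8 is selected purely from the requested dtype, mirroring the line added in this hunk.
    for torch_dtype in (torch.bfloat16, torch.float16, torch.float8_e4m3fn):
        use_fp8 = torch_dtype == torch.float8_e4m3fn
        print(f"{torch_dtype=}, {use_fp8=}")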
@@ -73,8 +74,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         cache_seqlens, s_q * h_q // h_kv, h_kv
     )
 
+    init_dtype = q.dtype
     def prepare_fp8_input():
-        q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
 
         if use_fp8:
             nonlocal q, blocked_k, blocked_v
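This hunk only touches the placeholder initialization of prepare_fp8_input (adding the blocked_v_fp8 slot) and records init_dtype for later dequantization; the helper's body is not shown here. As a rough guide to what such a helper typically does, the following is a generic per-tensor E4M3 quantization sketch with a descale factor. quantize_e4m3_per_tensor is a hypothetical name and not necessarily how the repository implements it:

    import torch

    def quantize_e4m3_per_tensor(x: torch.Tensor):
        # Scale so the largest magnitude lands near the E4M3 maximum, cast to FP8,
        # and keep the inverse scale ("descale") so a reference path can recover x.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = finfo.max / x.abs().max().clamp(min=1e-12)
        x_fp8 = (x.float() * scale).clamp(-finfo.max, finfo.max).to(torch.float8_e4m3fn)
        descale = (1.0 / scale).float()
        return x_fp8, descale

    x = torch.randn(4, 8, dtype=torch.bfloat16)
    x_fp8, descale_x = quantize_e4m3_per_tensor(x)
    x_roundtrip = (x_fp8.to(torch.float) * descale_x).to(torch.bfloat16)  # same pattern as ref_mla below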
@@ -89,8 +91,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
 
 
-    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
+
     if use_fp8:
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
         q = q_fp8
         blocked_k = blocked_k_fp8
         blocked_v = blocked_v_fp8
@@ -110,10 +113,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     )
 
     def ref_mla():
-        if use_fp8:
-            q_ = (q.to(torch.float) * descale_q).to(init_dtype)
-            blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype)
-            blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype)
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
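The reference path now dequantizes inline: each FP8 tensor is promoted to float32, multiplied by its per-tensor descale, and cast back to init_dtype, while BF16/FP16 inputs pass through untouched. Note that blocked_v_ reuses descale_k, which suggests the value cache is a slice of the same blocked key tensor. A tiny self-contained sketch of the pattern (not repository code):

    import torch

    def maybe_dequant(x, descale, init_dtype, use_fp8):
        # FP8 path: dequantize in float32, then cast back to the original test dtype.
        return (x.to(torch.float) * descale).to(init_dtype) if use_fp8 else x

    q = torch.randn(2, 1, 16, 576, dtype=torch.bfloat16)
    assert maybe_dequant(q, None, q.dtype, use_fp8=False).dtype == torch.bfloat16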
@@ -158,14 +160,13 @@ def main(torch_dtype):
     h_kv = 1
     d, dv = 576, 512
     causal = False
-    use_fp8 = torch_dtype == torch.float8_e4m3fn
 
     for b in [128]:
         for s in [4096, 8192]:
             for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
                 for s_q in [1, 2]: # MTP = 1, 2
                     for varlen in [False, True]:
-                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)
 
 
 if __name__ == "__main__":
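For each dtype, main() sweeps the full configuration grid shown in the context lines above; a quick self-contained count of the invocations it issues:

    import itertools

    # 1 batch size x 2 mean seqlens x 4 query-head counts x 2 query lengths x 2 varlen modes
    configs = list(itertools.product([128], [4096, 8192], [16, 32, 64, 128], [1, 2], [False, True]))
    assert len(configs) == 32  # test_flash_mla calls per dtype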
@@ -183,7 +184,7 @@ if __name__ == "__main__":
     torch_dtype = torch.bfloat16
     if args.dtype == "fp16":
         torch_dtype = torch.float16
-    elif args.dtype = "e4m3":
-        torch.dtype = torch.float8_e4m3fn
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
 
     main(torch_dtype)
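The fixed branch uses a proper equality comparison and assigns to the local torch_dtype instead of incorrectly setting the torch.dtype attribute. An equivalent, slightly more compact way to express the same selection, sketched here with an assumed --dtype flag since the parser setup is outside this hunk:

    import argparse
    import torch

    DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "e4m3": torch.float8_e4m3fn}

    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=list(DTYPE_MAP), default="bf16")  # assumed flag name
    args = parser.parse_args([])  # empty argv so the sketch runs standalone
    torch_dtype = DTYPE_MAP[args.dtype]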