Feature:Support flashMLA decoding via flashAttn2(#29)

Changes:
1. Implement flashMLA with matrix absorption algorithm via flashAttn2
2. Add golden test on MXMACA platform
This commit is contained in:
Kevin Zhang
2025-02-24 23:55:21 +08:00
parent bcb90f2afd
commit e0557deb3a
18 changed files with 1197 additions and 702 deletions

View File

@@ -4,7 +4,11 @@ import random
import torch
import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
# from flash_mla import get_mla_metadata, flash_mla_with_kvcache
from flash_attn import (
get_mla_metadata,
flash_mla_with_kvcache
)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@@ -32,12 +36,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, block_size):
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
@@ -51,7 +55,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
@@ -107,10 +110,10 @@ if __name__ == "__main__":
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
for block_size in [1,4,16,64]:
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1]: # TODO: to support MTP=2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, block_size)