Mirror of https://github.com/deepseek-ai/FlashMLA, synced 2025-06-26 18:15:54 +00:00
Commit 4e055a6142 ("reorg ut"), parent bfe38ab106
@@ -63,8 +63,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     blocked_v = blocked_k[..., :dv]
 
     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
 
+    init_dtype = q.dtype
 
     def prepare_fp8_input():
         q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
 
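Note: the body of prepare_fp8_input() that actually builds q_fp8 and the descale factors falls between these two hunks and is not shown. Judging from the descale_q / descale_k names and the dequantization recipe ref_mla uses after this commit, it presumably does per-tensor absmax quantization to torch.float8_e4m3fn. A minimal sketch under that assumption (quantize_fp8 is a hypothetical helper, not code from this repo):

import torch

def quantize_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Per-tensor absmax scaling: map the largest magnitude onto the fp8 range,
    # and keep the scale as a "descale" factor for later dequantization.
    fp8_max = torch.finfo(fp8_dtype).max
    descale = x.abs().amax().float().clamp(min=1e-12) / fp8_max
    x_fp8 = (x.float() / descale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_fp8, descale

# Illustrative shapes only, not the test's:
q = torch.randn(2, 1, 16, 576, dtype=torch.bfloat16)
q_fp8, descale_q = quantize_fp8(q)
# Dequantization pattern matching ref_mla after this change:
q_ref = (q_fp8.to(torch.float) * descale_q).to(torch.bfloat16)

Multiplying the fp8 tensor back by the descale factor is exactly the (x.to(torch.float) * descale).to(init_dtype) pattern that ref_mla applies below.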
@@ -78,33 +79,36 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         blocked_k_fp8 = blocked_k.to(fp8_dtype)
         blocked_v_fp8 = blocked_v.to(fp8_dtype)
 
-        q = q_fp8.to(q.dtype) * descale_q
-        blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k
-        blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k
-        return q_fp8, blocked_k_fp8, descale_q, descale_k
+        return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
 
 
-    q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input()
+    if use_fp8:
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
+        q = q_fp8
+        blocked_k = blocked_k_fp8
+        blocked_v = blocked_v_fp8
 
     def flash_mla():
-        q_ = q; blocked_k_ = blocked_k
-        if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8
         return flash_mla_with_kvcache(
-            q_, blocked_k_, block_table, cache_seqlens, dv,
+            q, blocked_k, block_table, cache_seqlens, dv,
             tile_scheduler_metadata, num_splits, causal=causal,
             descale_q=descale_q, descale_k=descale_k,
         )
 
     def ref_mla():
+        if use_fp8:
+            q_ = (q.to(torch.float) * descale_q).to(init_dtype)
+            blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype)
+            blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype)
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
             O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 h_q=h_q,
                 h_kv=h_kv,
                 is_causal=causal,
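ref_mla delegates the per-request attention to a scaled_dot_product_attention helper that is defined elsewhere in the test file and untouched by this commit. For context, a minimal sketch of what such a reference typically looks like, assuming GQA is handled by repeating each KV head h_q // h_kv times and that both the output and the log-sum-exp are returned (the name sdpa_ref and the exact masking are assumptions):

import math
import torch

def sdpa_ref(query, key, value, h_q, h_kv, is_causal=False):
    # query: [h_q, s_q, d], key: [h_kv, s_k, d], value: [h_kv, s_k, dv]
    # Assumed GQA handling: each KV head serves h_q // h_kv query heads.
    query, key, value = query.float(), key.float(), value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    if is_causal:
        s_q, s_k = attn.shape[-2], attn.shape[-1]
        # Decode-style causal mask: the s_q new queries sit at the end of the key sequence.
        keep = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn = attn.masked_fill(~keep, float("-inf"))
    lse = attn.logsumexp(dim=-1)        # [h_q, s_q]
    out = attn.softmax(dim=-1) @ value  # [h_q, s_q, dv]
    return out, lse

# Shape-only smoke test (illustrative sizes, not the test's):
q = torch.randn(16, 1, 576); k = torch.randn(1, 64, 576); v = torch.randn(1, 64, 512)
o, lse = sdpa_ref(q, k, v, h_q=16, h_kv=1, is_causal=True)
print(o.shape, lse.shape)  # torch.Size([16, 1, 512]) torch.Size([16, 1])

Returning LSE alongside the output is what lets the test compare both tensors produced by flash_mla_with_kvcache against the reference.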