mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-04 03:36:55 +00:00
minor fix test
commit 4fbaa9527c (parent accc1695ee)
@@ -7,7 +7,7 @@ import triton
 from flash_mla import get_mla_metadata, flash_mla_with_kvcache
 
 
-def scaled_dot_product_attention(query, key, value, is_causal=False):
+def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
     value = value.float()
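The hunk above only changes the signature: the pure-PyTorch reference attention now receives the query and KV head counts explicitly. A minimal sketch of how such a reference could use them, assuming h_q / h_kv describe grouped-query attention and the KV heads are expanded to match the query heads; the function body below, including the repeat_interleave step, is an assumption and is not shown in the diff:

import torch

def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    # Sketch only; shapes assumed: query (h_q, s_q, d), key (h_kv, s_k, d), value (h_kv, s_k, dv).
    query = query.float()
    key = key.float()
    value = value.float()
    # Assumption: replicate each KV head h_q // h_kv times so it lines up with the query heads.
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn = query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5
    if is_causal:
        # Causal mask that also handles s_q < s_k (queries aligned to the end of the keys).
        s_q, s_k = query.shape[-2], key.shape[-2]
        mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn = attn.masked_fill(~mask, float("-inf"))
    attn = attn.softmax(dim=-1)
    return attn @ value  # (h_q, s_q, dv)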
@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
                 q[i].transpose(0, 1),
                 blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                 blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                h_q=h_q,
+                h_kv=h_kv,
                 is_causal=causal,
             )
             out[i] = O.transpose(0, 1)
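For context, the call site above now forwards the head counts as keyword arguments alongside the per-request query and KV slices. A toy invocation of the sketch from the previous hunk; the shapes below are illustrative and are not the values the test actually uses:

# Hypothetical shapes for illustration only; d / dv are plausible MLA-style head dims, not taken from the test.
h_q, h_kv, s_q, s_k, d, dv = 16, 1, 1, 64, 576, 512
q_i = torch.randn(s_q, h_q, d)
k_i = torch.randn(s_k, h_kv, d)
v_i = torch.randn(s_k, h_kv, dv)
O = scaled_dot_product_attention(
    q_i.transpose(0, 1),   # (h_q, s_q, d)
    k_i.transpose(0, 1),   # (h_kv, s_k, d)
    v_i.transpose(0, 1),   # (h_kv, s_k, dv)
    h_q=h_q,
    h_kv=h_kv,
    is_causal=True,
)
out_i = O.transpose(0, 1)  # back to (s_q, h_q, dv)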