mirror of https://github.com/deepseek-ai/FlashMLA (synced 2025-04-23 07:34:31 +00:00)
commit dd1161e396
@@ -7,7 +7,7 @@ import triton

 from flash_mla import get_mla_metadata, flash_mla_with_kvcache


-def scaled_dot_product_attention(query, key, value, is_causal=False):
+def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
     value = value.float()
@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
             q[i].transpose(0, 1),
             blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
             blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+            h_q=h_q,
+            h_kv=h_kv,
             is_causal=causal,
         )
         out[i] = O.transpose(0, 1)
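
The diff only shows the new signature and the first few lines of the reference function, so the body below is a hedged sketch rather than the repository's code. It assumes h_q and h_kv are threaded through so the reference path can broadcast a smaller number of KV heads up to the query head count (grouped-query / multi-query attention), with tensors laid out as (heads, seq_len, head_dim); the scaling, masking, and matmul details are illustrative assumptions.

import math

import torch


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    # Assumed shapes: query (h_q, s_q, d), key (h_kv, s_k, d), value (h_kv, s_k, dv).
    query = query.float()
    key = key.float()
    value = value.float()
    if h_q != h_kv:
        # Assumption: each KV head serves h_q // h_kv query heads (GQA/MQA),
        # so repeat the KV heads up to the query head count before attention.
        key = key.repeat_interleave(h_q // h_kv, dim=0)
        value = value.repeat_interleave(h_q // h_kv, dim=0)
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    if is_causal:
        s_q, s_k = query.shape[-2], key.shape[-2]
        # Causal mask aligned to the end of the KV sequence, a common
        # convention when s_q new queries attend over a longer cache.
        mask = torch.ones(s_q, s_k).tril(diagonal=s_k - s_q).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
    return scores.softmax(dim=-1) @ value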
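
A minimal usage sketch under the same assumptions (h_kv divides h_q; tensors laid out as (heads, seq, dim)), mirroring the keyword arguments the updated test now passes:

q = torch.randn(16, 4, 64)    # h_q = 16 query heads, s_q = 4, d = 64
k = torch.randn(1, 128, 64)   # h_kv = 1 shared KV head (MQA-style cache), s_k = 128
v = torch.randn(1, 128, 32)   # dv = 32 may differ from d
out = scaled_dot_product_attention(q, k, v, h_q=16, h_kv=1, is_causal=True)
print(out.shape)  # torch.Size([16, 4, 32])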