Merge pull request #14 from lancerts/minor-fix

minor fix test
This commit is contained in:
Jiashi Li 2025-02-24 13:13:58 +08:00 committed by GitHub
commit dd1161e396
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,7 +7,7 @@ import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
def scaled_dot_product_attention(query, key, value, is_causal=False):
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)