minor fix test

This commit is contained in:
lancerts 2025-02-23 20:12:49 -08:00
parent accc1695ee
commit 4fbaa9527c

View File

@ -7,7 +7,7 @@ import triton
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
def scaled_dot_product_attention(query, key, value, is_causal=False):
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)