diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 9c5cd90..8db5db0 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -7,7 +7,7 @@ import triton
 from flash_mla import get_mla_metadata, flash_mla_with_kvcache
 
 
-def scaled_dot_product_attention(query, key, value, is_causal=False):
+def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     query = query.float()
     key = key.float()
     value = value.float()
@@ -76,6 +76,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                h_q=h_q,
+                h_kv=h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
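
For context: threading `h_q` and `h_kv` through the reference implementation suggests grouped-query attention, where `h_kv` key/value heads are shared across `h_q` query heads. Below is a minimal sketch of what the updated reference function might look like, assuming a `repeat_interleave`-based head expansion and an `(output, LSE)` return pair; neither detail is confirmed by the hunks above.

```python
import torch


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    # Assumed shapes: query (h_q, s_q, d), key (h_kv, s_k, d), value (h_kv, s_k, dv),
    # matching the transpose(0, 1) calls at the test's call site.
    query = query.float()
    key = key.float()
    value = value.float()
    # Hypothetical GQA handling: replicate each KV head for the h_q // h_kv
    # query heads that share it, so the batched matmuls below line up.
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / (query.shape[-1] ** 0.5)
    if is_causal:
        s_q, s_k = query.shape[-2], key.shape[-2]
        # Bottom-right-aligned causal mask, so a short query attends to the
        # full KV-cache prefix.
        mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_weight = attn_weight.masked_fill(~mask, float("-inf"))
    lse = attn_weight.logsumexp(dim=-1)
    out = torch.softmax(attn_weight, dim=-1) @ value
    return out, lse
```

Returning the log-sum-exp alongside the output would let the test compare both values against `flash_mla_with_kvcache`, which likely exposes both as well.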