diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index e676fa7..0abe9d2 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( float("nan") ) blocked_v = blocked_k[..., :dv]