From e1e9fa98f80f34c3b155fd483c38227abb5f400d Mon Sep 17 00:00:00 2001 From: ljss <450993438@qq.com> Date: Tue, 25 Feb 2025 09:18:11 +0800 Subject: [PATCH] Style fix --- tests/test_flash_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index e676fa7..0abe9d2 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = ( float("nan") ) blocked_v = blocked_k[..., :dv]