Style fix

This commit is contained in:
ljss 2025-02-25 09:18:11 +08:00
parent a3b74b8574
commit e1e9fa98f8

View File

@ -60,7 +60,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
).view(b, max_seqlen_pad // block_size) ).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b): for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan") float("nan")
) )
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]