From 4e055a6142143dc8cdd70e918fef9eb6dbef3949 Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Fri, 28 Feb 2025 18:59:02 +0800
Subject: [PATCH] reorg ut

---
 tests/test_flash_mla.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 03c9037..b840a97 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -63,8 +63,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     blocked_v = blocked_k[..., :dv]
 
     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
-
+    init_dtype = q.dtype
+
     def prepare_fp8_input():
         q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
 
@@ -78,33 +79,36 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
         blocked_k_fp8 = blocked_k.to(fp8_dtype)
         blocked_v_fp8 = blocked_v.to(fp8_dtype)
 
-        q = q_fp8.to(q.dtype) * descale_q
-        blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k
-        blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k
-        return q_fp8, blocked_k_fp8, descale_q, descale_k
+        return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
 
-    q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input()
-
+    if use_fp8:
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
+        q = q_fp8
+        blocked_k = blocked_k_fp8
+        blocked_v = blocked_v_fp8
+
     def flash_mla():
-        q_ = q; blocked_k_ = blocked_k
-        if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8
         return flash_mla_with_kvcache(
-            q_, blocked_k_, block_table, cache_seqlens, dv,
+            q, blocked_k, block_table, cache_seqlens, dv,
             tile_scheduler_metadata, num_splits, causal=causal,
             descale_q=descale_q, descale_k=descale_k,
         )
 
     def ref_mla():
+        if use_fp8:
+            q_ = (q.to(torch.float) * descale_q).to(init_dtype)
+            blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype)
+            blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype)
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
             O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 h_q=h_q, h_kv=h_kv, is_causal=causal,
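
Note on the reorganized test (not part of the patch): prepare_fp8_input now only quantizes and returns the fp8 tensors plus their descale factors; flash_mla passes the fp8 tensors straight to flash_mla_with_kvcache, while ref_mla dequantizes the same fp8 tensors with the same descales, so the kernel and the reference both start from identical quantization error. Below is a minimal sketch of that quantize/dequantize round trip, assuming torch.float8_e4m3fn and a per-tensor scale; fp8_dtype and the actual descale computation live outside the hunks shown, so treat quantize_fp8 as a hypothetical stand-in.

import torch

def quantize_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Hypothetical per-tensor scale: map the largest |value| to the fp8 max.
    fp8_max = torch.finfo(fp8_dtype).max
    descale = x.abs().max().float().clamp(min=1e-12) / fp8_max
    x_fp8 = (x.float() / descale).to(fp8_dtype)
    return x_fp8, descale  # descale plays the role of descale_q / descale_k

def dequantize(x_fp8: torch.Tensor, descale: torch.Tensor, out_dtype):
    # Mirrors the patched ref_mla: upcast to float, rescale, cast back to
    # the original dtype (init_dtype in the test).
    return (x_fp8.to(torch.float) * descale).to(out_dtype)

if __name__ == "__main__":
    q = torch.randn(2, 8, 64, dtype=torch.bfloat16)
    q_fp8, descale_q = quantize_fp8(q)
    q_ref = dequantize(q_fp8, descale_q, q.dtype)
    # Both the fp8 kernel input and the reference input carry the same
    # rounding error, so the comparison isolates the kernel's own accuracy.
    print((q_ref.float() - q.float()).abs().max())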