From ef644a56e0976cb74d2a57efabded73dfc05deeb Mon Sep 17 00:00:00 2001
From: "chenhongmin.will"
Date: Wed, 26 Feb 2025 08:13:56 +0800
Subject: [PATCH] update ut

---
 tests/test_flash_mla.py | 36 ++++++++++++++++++++++--------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 37cbb10..f700864 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -61,22 +61,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
     tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
 
-    descale_q, descale_k = None, None
-    if use_fp8:
-        fp8_dtype = torch.float8_e4m3fn
-        descale_q = torch.ones((b), dtype=torch.float32)
-        descale_k = torch.ones((b), dtype=torch.float32)
-
-        q_fp8 = q.to(fp8_dtype)
-        blocked_k_fp8 = blocked_k.to(fp8_dtype)
-        blocked_v_fp8 = blocked_v.to(fp8_dtype)
-        q = q_fp8.to(q.dtype)
-        blocked_k = blocked_k_fp8.to(blocked_k.dtype)
-        blocked_v = blocked_v_fp8.to(blocked_v.dtype)
+    def prepare_fp8_input():
+        q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
+
+        if use_fp8:
+            nonlocal q, blocked_k, blocked_v
+            fp8_dtype = torch.float8_e4m3fn
+            descale_q = torch.ones((b), dtype=torch.float32)
+            descale_k = torch.ones((b), dtype=torch.float32)
+
+            q_fp8 = q.to(fp8_dtype)
+            blocked_k_fp8 = blocked_k.to(fp8_dtype)
+            blocked_v_fp8 = blocked_v.to(fp8_dtype)
+
+            q = q_fp8.to(q.dtype) * descale_q
+            blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k
+            blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k
+        return q_fp8, blocked_k_fp8, descale_q, descale_k
+
+
+    q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input()
 
     def flash_mla():
-        q_ = q_fp8 if use_fp8 else q
-        blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k
+        q_ = q; blocked_k_ = blocked_k
+        if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8
         return flash_mla_with_kvcache(
             q_, blocked_k_, block_table, cache_seqlens, dv,
             tile_scheduler_metadata, num_splits, causal=causal,
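
Note (not part of the patch): below is a minimal, self-contained sketch of the FP8 round-trip that prepare_fp8_input() performs, i.e. cast to torch.float8_e4m3fn for the kernel input, then dequantize with a per-batch descale factor to build the reference input. The tensor sizes are made up for illustration, and the view() reshape is an added assumption so the (b,)-shaped descale broadcasts over the batch dimension; the patch multiplies by the raw tensor, which is harmless in the test because the descales are all ones.

    import torch

    # Hypothetical sizes; the real test derives these from its parametrization.
    b, s_q, h_q, d = 2, 1, 16, 64
    q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16)

    fp8_dtype = torch.float8_e4m3fn
    descale_q = torch.ones(b, dtype=torch.float32)  # all ones, as in the test

    # Quantize: this is what is passed to flash_mla_with_kvcache when use_fp8 is set.
    q_fp8 = q.to(fp8_dtype)

    # Dequantize for the reference path. The .view(b, 1, 1, 1) is an assumed
    # reshape so the per-batch factor broadcasts over (s_q, h_q, d).
    q_ref = q_fp8.to(q.dtype) * descale_q.view(b, 1, 1, 1).to(q.dtype)

    # With all-ones descales, q_ref differs from q only by FP8 rounding error,
    # so the reference comparison measures exactly that quantization loss.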