update ut

This commit is contained in:
chenhongmin.will 2025-02-26 08:13:56 +08:00
parent 870418802a
commit ef644a56e0

View File

@ -61,22 +61,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
descale_q, descale_k = None, None def prepare_fp8_input():
if use_fp8: q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((b), dtype=torch.float32) if use_fp8:
descale_k = torch.ones((b), dtype=torch.float32) nonlocal q, blocked_k, blocked_v
fp8_dtype = torch.float8_e4m3fn
q_fp8 = q.to(fp8_dtype) descale_q = torch.ones((b), dtype=torch.float32)
blocked_k_fp8 = blocked_k.to(fp8_dtype) descale_k = torch.ones((b), dtype=torch.float32)
blocked_v_fp8 = blocked_v.to(fp8_dtype)
q = q_fp8.to(q.dtype) q_fp8 = q.to(fp8_dtype)
blocked_k = blocked_k_fp8.to(blocked_k.dtype) blocked_k_fp8 = blocked_k.to(fp8_dtype)
blocked_v = blocked_v_fp8.to(blocked_v.dtype) blocked_v_fp8 = blocked_v.to(fp8_dtype)
q = q_fp8.to(q.dtype) * descale_q
blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k
blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k
return q_fp8, blocked_k_fp8, descale_q, descale_k
q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input()
def flash_mla(): def flash_mla():
q_ = q_fp8 if use_fp8 else q q_ = q; blocked_k_ = blocked_k
blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8
return flash_mla_with_kvcache( return flash_mla_with_kvcache(
q_, blocked_k_, block_table, cache_seqlens, dv, q_, blocked_k_, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=causal, tile_scheduler_metadata, num_splits, causal=causal,