change to use per_tensor

This commit is contained in:
chenhongmin.will
2025-02-26 10:17:29 +08:00
parent 4b314cd655
commit f6fab1b915
3 changed files with 12 additions and 6 deletions

View File

@@ -67,8 +67,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
if use_fp8:
nonlocal q, blocked_k, blocked_v
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((b), dtype=torch.float32)
descale_k = torch.ones((b), dtype=torch.float32)
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q_fp8 = q.to(fp8_dtype)
blocked_k_fp8 = blocked_k.to(fp8_dtype)