mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
change to use per_tensor
This commit is contained in:
@@ -67,8 +67,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
|
||||
if use_fp8:
|
||||
nonlocal q, blocked_k, blocked_v
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
descale_q = torch.ones((b), dtype=torch.float32)
|
||||
descale_k = torch.ones((b), dtype=torch.float32)
|
||||
descale_q = torch.ones((1), dtype=torch.float32)
|
||||
descale_k = torch.ones((1), dtype=torch.float32)
|
||||
|
||||
q_fp8 = q.to(fp8_dtype)
|
||||
blocked_k_fp8 = blocked_k.to(fp8_dtype)
|
||||
|
||||
Reference in New Issue
Block a user