mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
fix mma0
This commit is contained in:
@@ -27,13 +27,17 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
|
||||
return attn_weight @ value, lse
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
amax_diff = (x - y).abs().max().item()
|
||||
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
if use_fp8:
|
||||
assert cos_diff < 1e-3
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -111,7 +115,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 =
|
||||
|
||||
out_flash, lse_flash = flash_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
cal_diff(out_flash, out_torch, "out")
|
||||
cal_diff(out_flash, out_torch, "out", use_fp8)
|
||||
cal_diff(lse_flash, lse_torch, "lse")
|
||||
|
||||
t = triton.testing.do_bench(flash_mla)
|
||||
|
||||
Reference in New Issue
Block a user