mirror of
https://github.com/deepseek-ai/FlashMLA
synced 2025-06-26 18:15:54 +00:00
debug mode
This commit is contained in:
parent
f6fab1b915
commit
29de9e0c79
@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
|
|||||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||||
amax_diff = (x - y).abs().max().item()
|
amax_diff = (x - y).abs().max().item()
|
||||||
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
|
||||||
assert cos_diff < 1e-5
|
#assert cos_diff < 1e-5
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -131,12 +131,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
h_kv = 1
|
h_kv = 1
|
||||||
d, dv = 576, 512
|
d, dv = 576, 512
|
||||||
causal = True
|
causal = False
|
||||||
use_fp8 = False
|
use_fp8 = True
|
||||||
|
|
||||||
for b in [128]:
|
for b in [16]:
|
||||||
for s in [4096, 8192]:
|
for s in [4096]:
|
||||||
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
|
for h_q in [128]: # TP = 8, 4, 2, 1
|
||||||
for s_q in [1, 2]: # MTP = 1, 2
|
for s_q in [2]: # MTP = 1, 2
|
||||||
for varlen in [False, True]:
|
for varlen in [False]:
|
||||||
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
|
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8)
|
||||||
|
Loading…
Reference in New Issue
Block a user