diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 0abe9d2..67c9d93 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -127,7 +127,7 @@ def main(torch_dtype): causal = True for b in [128]: - for s in [4096, 8192]: + for s in [4096, 8192, 16384]: for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: