diff --git a/README.md b/README.md index 6abda62..dab1f05 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,6 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [ ] Skip useless computation on M -- [ ] Share pipeline stages between scheduled blocks -- [ ] TMA store pipeline - [ ] NVRTC as a faster compiler - [ ] Sanitizer for testing - [ ] Weight gradient kernels for dense models diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index cb438b7..c6fd29d 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -130,7 +130,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_config, sm90_capacity = None, None, 232448 - stage_candidates = (8, 7, 6, 5, 4, 3) + stage_candidates = tuple(filter(lambda s: s <= k // 128, (8, 7, 6, 5, 4, 3))) if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3)