Fix pipeline stage edge cases

This commit is contained in:
Chenggang Zhao 2025-05-07 11:40:34 +08:00
parent bfe983c4c2
commit daec8fd2fc

View File

@ -105,10 +105,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Always pick the longest one
# NOTES: for double B scales, the best number of stages may be reduced
best_num_stages, best_smem_config, sm90_capacity = None, None, 232448
stage_candidates = tuple(filter(lambda s: s <= k // 128, (8, 7, 6, 5, 4, 3)))
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1)))
if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
# Unrolling both stages and `num_former_iters` will cause large code size
stage_candidates = (4, 3)
stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1)))
for num_stages in stage_candidates:
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
if best_smem_config[0] <= sm90_capacity: