Clean up config heuristics

This commit is contained in:
Chenggang Zhao 2025-04-09 10:01:15 +08:00
parent ce65d5e33c
commit 48a5f071be

View File

@ -73,14 +73,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
Tuple[int, int, int, int, Tuple[int, bool], int]:
if not is_grouped_contiguous:
# TODO: for some cases, smaller M block is better, add them into tuning space
# block_ms = (64 if m <= 64 else 128, )
if m <= 64:
block_ms = (64, )
elif m <= 128:
block_ms = (64, 128, )
else:
block_ms = (64, 128, 256, )
block_ms = (64, 128, 256)
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8))
@ -103,7 +96,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
best_util = get_last_wave_util(best_block_m, best_block_n)
success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n)))
success = util > best_util
if util == best_util:
# Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n
# Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m
# Case 3: both `block_m` and `block_n` differ; the larger `block_n` is better
success |= block_m != best_block_m and block_n > best_block_n
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
assert best_block_m is not None and best_block_n is not None