Larger block N candidates

This commit is contained in:
Chenggang Zhao 2025-04-09 10:11:43 +08:00
parent 48a5f071be
commit a6524d411a

View File

@ -76,7 +76,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
block_ms = (64, 128, 256)
else:
block_ms = (get_m_alignment_for_contiguous_layout(), )
block_ns = tuple(range(16, 129, 8))
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
@ -85,7 +85,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
# Decide block sizes by waves
best_block_m, best_block_n = None, None
for block_m in block_ms:
for block_n in block_ns:
# NOTES: the block sizes can not be too large, so at least one dim less than 128
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
success = False
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
if best_block_m is None or best_block_n is None: