diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 1860b43..859eb34 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -79,10 +79,12 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, elif num_waves < best_num_waves: success = True elif num_waves == best_num_waves: + div_n = bool(128 % block_n) + best_div_n = bool(128 % best_block_n) # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or (util == best_util and (block_n >= best_block_n and block_m <= best_block_m)) + success = util > best_util or (util == best_util and (block_m > best_block_m or block_m == best_block_m and (div_n < best_div_n or div_n == best_div_n and block_n < best_block_n))) best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None