From 48a5f071beca106a636e82f502ce7cd4d9201220 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 10:01:15 +0800 Subject: [PATCH] Clean up config heuristics --- deep_gemm/jit_kernels/gemm.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 9c6e0b2..a57d348 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -73,14 +73,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: - # TODO: for some cases, smaller M block is better, add them into tuning space - # block_ms = (64 if m <= 64 else 128, ) - if m <= 64: - block_ms = (64, ) - elif m <= 128: - block_ms = (64, 128, ) - else: - block_ms = (64, 128, 256, ) + block_ms = (64, 128, 256) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) @@ -103,7 +96,14 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Check last wave utilization util = get_last_wave_util(block_m, block_n) best_util = get_last_wave_util(best_block_m, best_block_n) - success = util > best_util or (util == best_util and (block_m > best_block_m or (block_m == best_block_m and block_n < best_block_n))) + success = util > best_util + if util == best_util: + # Case 1: same `block_m`, smaller `block_n` (wasted) + success |= block_m == best_block_m and block_n < best_block_n + # Case 2: same `block_n`, smaller `block_m` (wasted) + success |= block_n == best_block_n and block_m < best_block_m + # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better + success |= block_m != best_block_m and block_n > best_block_n best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None