From 159ba93ab39dd0893ffbb5afafbf2d110323f842 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 7 May 2025 10:13:19 +0800 Subject: [PATCH] Code format --- deep_gemm/jit_kernels/gemm.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index fa15418..d8023fe 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -72,22 +72,17 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (get_m_alignment_for_contiguous_layout(), ) block_ns = tuple(range(16, 129, 8)) + (144, 160, ) - def fix_wave_saturate(x): return num_sms if x == 0 else x - - def get_num_waves(bm, bn): return (ceil_div( - ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) - - def get_last_wave_util(bm, bn): return fix_wave_saturate( - (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) + fix_wave_saturate = lambda x: num_sms if x == 0 else x + get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None) + get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms) # Decide block sizes by waves best_block_m, best_block_n = None, None for block_m in block_ms: - # NOTES: the block sizes can not be too large, so at least one dim less than 128 + # NOTES: the block sizes cannot be too large, so at least one dim less than 128 for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns): success = False - num_waves, best_num_waves = get_num_waves( - block_m, block_n), get_num_waves(best_block_m, best_block_n) + num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n) if best_block_m is None or best_block_n is None: success = True elif num_waves < best_num_waves: @@ -104,8 +99,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, success |= block_n == best_block_n and block_m < best_block_m # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better success |= block_m != best_block_m and block_n > best_block_n - best_block_m, best_block_n = (block_m, block_n) if success else ( - best_block_m, best_block_n) + best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n) assert best_block_m is not None and best_block_n is not None # Always pick the longest one @@ -116,8 +110,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = (4, 3) for num_stages in stage_candidates: - best_smem_config = get_smem_config( - num_stages, k, best_block_m, best_block_n) + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break @@ -141,10 +134,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Recompute the minimal number of SMs required # NOTES: less L2 cache usage and less GPU frequency drop num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * - ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div( - num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] assert num_min_sms <= num_sms return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config