From 25db8de3454800bd40b93860b99c303570d09632 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Tue, 25 Mar 2025 17:34:06 +0800 Subject: [PATCH] Better performance --- deep_gemm/jit_kernels/gemm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 31d1a2e..65b44ff 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -1,3 +1,4 @@ +import math import torch from typing import Tuple @@ -90,7 +91,11 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Always pick the longest one # NOTES: for double B scales, the best number of stages may be reduced best_num_stages, best_smem_size, sm90_capacity = None, None, 232448 - for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4): + stage_candidates = (8, 7, 6, 5, 4) + if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4: + # Unrolling both stages and `num_former_iters` will cause large code size + stage_candidates = (4, ) + for num_stages in stage_candidates: best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n) if best_smem_size <= sm90_capacity: best_num_stages = num_stages