diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 14854e0..cbd6fc2 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -67,6 +67,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: smem_size += smem_barrier return smem_size + @lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]: