diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index cec83fb..14854e0 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -1,5 +1,6 @@ import math import torch +from functools import lru_cache from typing import Tuple from .tuner import jit_tuner @@ -66,7 +67,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k: smem_size += smem_barrier return smem_size - +@lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]: if not is_grouped_contiguous: