add lru-cache to avoid repeated calculation

Author: Yi Zhang, 2025-04-04 12:44:26 +08:00 (committed by GitHub)
Parent: c187c23ba8
Commit: 776bd0cccc


@@ -1,5 +1,6 @@
 import math
 import torch
+from functools import lru_cache
 from typing import Tuple

 from .tuner import jit_tuner
@@ -66,7 +67,7 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_size += smem_barrier
     return smem_size
+@lru_cache(maxsize=None)
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                      is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
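
For context, a minimal sketch (not part of the commit) of what functools.lru_cache buys here, assuming get_best_configs is a pure function of its arguments: repeated calls with the same GEMM shape return the memoized result instead of re-running the config search. The names get_best_configs_sketch and _search_configs below are hypothetical stand-ins.

from functools import lru_cache
from typing import Tuple


# Hypothetical stand-in for the candidate search done inside get_best_configs;
# it depends only on its arguments, so caching by argument tuple is safe.
def _search_configs(m: int, n: int, k: int) -> Tuple[int, int]:
    return min(m, 128), min(n, 128)


@lru_cache(maxsize=None)  # memoize results, keyed by the (hashable) argument tuple
def get_best_configs_sketch(m: int, n: int, k: int) -> Tuple[int, int]:
    return _search_configs(m, n, k)


# The first call computes; the second call with the same shape is served from the cache.
print(get_best_configs_sketch(4096, 7168, 2048))
print(get_best_configs_sketch(4096, 7168, 2048))
print(get_best_configs_sketch.cache_info())  # CacheInfo(hits=1, misses=1, ...)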