Merge pull request #79 from yizhang2077/lru-cache-opt

Add lru_cache for get_best_configs to avoid repeated computation
Chenggang Zhao 2025-04-07 09:31:30 +08:00 committed by GitHub
commit b0868c9014


@@ -1,5 +1,6 @@
 import math
 import torch
+from functools import lru_cache
 from typing import Tuple
 
 from .tuner import jit_tuner
@@ -66,7 +67,8 @@ def get_smem_size(num_stages: int, k: int, block_m: int, block_n: int, block_k:
     smem_size += smem_barrier
     return smem_size
 
 
+@lru_cache(maxsize=None)
 def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                      is_grouped_contiguous: bool = False) -> Tuple[int, int, int, int, Tuple[int, bool], int]:
     if not is_grouped_contiguous:
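
The optimization works because get_best_configs is deterministic and takes only hashable arguments (ints and a bool), so functools.lru_cache can return a previously computed result whenever the same problem shape recurs. Below is a minimal sketch of the mechanism, not the DeepGEMM code itself; pick_tile is a hypothetical stand-in for an expensive config search:

```python
from functools import lru_cache
from typing import Tuple

@lru_cache(maxsize=None)  # unbounded cache; arguments must be hashable
def pick_tile(m: int, n: int) -> Tuple[int, int]:
    # Hypothetical placeholder for an expensive search over block shapes.
    print(f'computing for m={m}, n={n}')
    return (min(m, 128), min(n, 128))

pick_tile(4096, 4096)          # computed, prints once
pick_tile(4096, 4096)          # served from the cache, no recomputation
print(pick_tile.cache_info())  # e.g. CacheInfo(hits=1, misses=1, ...)
```

maxsize=None keeps every result forever, which is reasonable here since the set of distinct GEMM shapes seen in one run is small; a bounded cache would instead evict least-recently-used entries.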