mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-04-29 22:40:22 +00:00
Better performance
This commit is contained in:
parent
1999d553e5
commit
25db8de345
@@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
@@ -90,7 +91,11 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
|||||||
# Always pick the longest one
|
# Always pick the longest one
|
||||||
# NOTES: for double B scales, the best number of stages may be reduced
|
# NOTES: for double B scales, the best number of stages may be reduced
|
||||||
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
best_num_stages, best_smem_size, sm90_capacity = None, None, 232448
|
||||||
for num_stages in (6, 5, 4) if 128 % best_block_n != 0 else (8, 7, 6, 5, 4):
|
stage_candidates = (8, 7, 6, 5, 4)
|
||||||
|
if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
|
||||||
|
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||||
|
stage_candidates = (4, )
|
||||||
|
for num_stages in stage_candidates:
|
||||||
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
best_smem_size = get_smem_size(num_stages, k, best_block_m, best_block_n)
|
||||||
if best_smem_size <= sm90_capacity:
|
if best_smem_size <= sm90_capacity:
|
||||||
best_num_stages = num_stages
|
best_num_stages = num_stages
|
||||||
|
Loading…
Reference in New Issue
Block a user