mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Code format
This commit is contained in:
parent
5272d40aaf
commit
159ba93ab3
@ -72,22 +72,17 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
|||||||
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
block_ms = (get_m_alignment_for_contiguous_layout(), )
|
||||||
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
block_ns = tuple(range(16, 129, 8)) + (144, 160, )
|
||||||
|
|
||||||
def fix_wave_saturate(x): return num_sms if x == 0 else x
|
fix_wave_saturate = lambda x: num_sms if x == 0 else x
|
||||||
|
get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
||||||
def get_num_waves(bm, bn): return (ceil_div(
|
get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
||||||
ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
|
|
||||||
|
|
||||||
def get_last_wave_util(bm, bn): return fix_wave_saturate(
|
|
||||||
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)
|
|
||||||
|
|
||||||
# Decide block sizes by waves
|
# Decide block sizes by waves
|
||||||
best_block_m, best_block_n = None, None
|
best_block_m, best_block_n = None, None
|
||||||
for block_m in block_ms:
|
for block_m in block_ms:
|
||||||
# NOTES: the block sizes can not be too large, so at least one dim less than 128
|
# NOTES: the block sizes cannot be too large, so at least one dim must be no larger than 128
|
||||||
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
|
for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
|
||||||
success = False
|
success = False
|
||||||
num_waves, best_num_waves = get_num_waves(
|
num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
||||||
block_m, block_n), get_num_waves(best_block_m, best_block_n)
|
|
||||||
if best_block_m is None or best_block_n is None:
|
if best_block_m is None or best_block_n is None:
|
||||||
success = True
|
success = True
|
||||||
elif num_waves < best_num_waves:
|
elif num_waves < best_num_waves:
|
||||||
@ -104,8 +99,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
|||||||
success |= block_n == best_block_n and block_m < best_block_m
|
success |= block_n == best_block_n and block_m < best_block_m
|
||||||
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
|
# Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
|
||||||
success |= block_m != best_block_m and block_n > best_block_n
|
success |= block_m != best_block_m and block_n > best_block_n
|
||||||
best_block_m, best_block_n = (block_m, block_n) if success else (
|
best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
|
||||||
best_block_m, best_block_n)
|
|
||||||
assert best_block_m is not None and best_block_n is not None
|
assert best_block_m is not None and best_block_n is not None
|
||||||
|
|
||||||
# Always pick the longest one
|
# Always pick the longest one
|
||||||
@ -116,8 +110,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
|||||||
# Unrolling both stages and `num_former_iters` will cause large code size
|
# Unrolling both stages and `num_former_iters` will cause large code size
|
||||||
stage_candidates = (4, 3)
|
stage_candidates = (4, 3)
|
||||||
for num_stages in stage_candidates:
|
for num_stages in stage_candidates:
|
||||||
best_smem_config = get_smem_config(
|
best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
|
||||||
num_stages, k, best_block_m, best_block_n)
|
|
||||||
if best_smem_config[0] <= sm90_capacity:
|
if best_smem_config[0] <= sm90_capacity:
|
||||||
best_num_stages = num_stages
|
best_num_stages = num_stages
|
||||||
break
|
break
|
||||||
@ -141,10 +134,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
|||||||
# Recompute the minimal number of SMs required
|
# Recompute the minimal number of SMs required
|
||||||
# NOTES: less L2 cache usage and less GPU frequency drop
|
# NOTES: less L2 cache usage and less GPU frequency drop
|
||||||
num_waves = get_num_waves(best_block_m, best_block_n)
|
num_waves = get_num_waves(best_block_m, best_block_n)
|
||||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) *
|
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
|
||||||
ceil_div(n, best_block_n) * num_groups, num_waves)
|
num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
||||||
num_min_sms = ceil_div(
|
|
||||||
num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
|
|
||||||
assert num_min_sms <= num_sms
|
assert num_min_sms <= num_sms
|
||||||
|
|
||||||
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user