mirror of
https://github.com/deepseek-ai/DeepGEMM
synced 2025-06-26 23:15:49 +00:00
Several code lints x2
This commit is contained in:
@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
|
||||
"""
|
||||
|
||||
|
||||
def is_tma_multicast_legal(shape_dim: int, multicast_block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
|
||||
if num_tma_multicast == 1:
|
||||
return True
|
||||
return shape_dim % multicast_block_dim == 0 and num_sms % num_tma_multicast == 0
|
||||
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
|
||||
require_divisible: bool = False) -> bool:
|
||||
divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
|
||||
return divisible and num_sms % num_tma_multicast == 0
|
||||
|
||||
|
||||
def get_swizzle_mode(block_n: int) -> int:
|
||||
@@ -146,10 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
|
||||
best_tma_multicast_config = (1, True)
|
||||
|
||||
# Try to multicast on the larger block side first
|
||||
# NOTES: Currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the n-direction to be even.
|
||||
# NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
|
||||
is_multicast_legal = {
|
||||
'A': is_tma_multicast_legal(n, best_block_n * (2 if is_grouped_masked else 1), 2, num_sms),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
|
||||
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
|
||||
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
|
||||
}
|
||||
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
|
||||
if m >= 512 and is_multicast_legal[i]:
|
||||
|
||||
Reference in New Issue
Block a user