Merge remote-tracking branch 'upstream/main' into nvrtc

This commit is contained in:
Zihua Wu
2025-04-23 00:17:28 -07:00
5 changed files with 73 additions and 53 deletions

View File

@@ -9,10 +9,10 @@ from .tuner import jit_tuner
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
return True
return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
require_divisible: bool = False) -> bool:
divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
return divisible and num_sms % num_tma_multicast == 0
def get_swizzle_mode(block_n: int) -> int:
@@ -126,9 +126,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
# NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i]: