Several code lints x2

This commit is contained in:
Chenggang Zhao
2025-04-22 17:24:02 +08:00
parent 902208a17e
commit f4014953ad
3 changed files with 11 additions and 8 deletions

View File

@@ -16,7 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] Shared memory swizzling for output
- [ ] Larger block size on N (up to 256)
- [x] MoE scheduler with TMA multicast compatibility
- [ ] Fix TMA multicast compatibility for indivisible shapes
- [x] Fix TMA multicast compatibility for indivisible shapes
- [ ] Weight gradient kernels for dense models
- [ ] Weight gradient kernels for MoE models
- [ ] Utility kernels for MoE models (as a pre-built CUDA library)

View File

@@ -83,6 +83,8 @@ struct Scheduler {
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
// Fix unaligned TMA multicast
if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
@@ -93,6 +95,7 @@ struct Scheduler {
}
}
// Convert to final M/N block indices
if constexpr (kIsTMAMulticastOnA) {
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;

View File

@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
"""
def is_tma_multicast_legal(shape_dim: int, multicast_block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
if num_tma_multicast == 1:
return True
return shape_dim % multicast_block_dim == 0 and num_sms % num_tma_multicast == 0
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
require_divisible: bool = False) -> bool:
divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
return divisible and num_sms % num_tma_multicast == 0
def get_swizzle_mode(block_n: int) -> int:
@@ -146,10 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
best_tma_multicast_config = (1, True)
# Try to multicast on the larger block side first
# NOTES: Currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the n-direction to be even.
# NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
is_multicast_legal = {
'A': is_tma_multicast_legal(n, best_block_n * (2 if is_grouped_masked else 1), 2, num_sms),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
}
for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
if m >= 512 and is_multicast_legal[i]: