Indivisible TMA (#90)

Fix indivisible shapes for TMA multicast

---------

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
Authored by yukuai26 on 2025-04-23 14:55:14 +08:00, committed by GitHub
parent 891f35adf5
commit 95e81b3dd6
5 changed files with 73 additions and 53 deletions


@@ -16,7 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
-- [ ] Fix TMA multicast compatibility for indivisible shapes
+- [x] Fix TMA multicast compatibility for indivisible shapes
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)


@@ -192,8 +192,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
 
     // Assign TMA multicast number into A and B
-    constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
-    constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
+    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
+    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+    DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
 
     // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
     // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
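
The compile-time `kNumTMAMulticastOnA`/`kNumTMAMulticastOnB` constants become runtime widths that collapse to 1 whenever the scheduler reports that the current block has no usable multicast peer. A minimal host-side sketch of that selection logic (a Python stand-in with illustrative names, not the kernel's API):

```python
def pick_multicast_widths(multicast_on_a: bool, num_tma_multicast: int,
                          is_tma_multicast_valid: bool) -> tuple:
    # Only the chosen side keeps the cluster-wide width, and only when the
    # scheduler says the current block has a usable peer; otherwise both
    # sides degrade to 1, i.e. a plain non-multicast TMA load.
    a = num_tma_multicast if (multicast_on_a and is_tma_multicast_valid) else 1
    b = num_tma_multicast if (not multicast_on_a and is_tma_multicast_valid) else 1
    return a, b

for on_a in (True, False):
    for valid in (True, False):
        print(f"on_A={on_a!s:<5} valid={valid!s:<5} -> (A, B) widths = {pick_multicast_widths(on_a, 2, valid)}")
```

Only one side ever multicasts, and an invalid peer silently degrades both sides to ordinary TMA loads.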
@@ -205,23 +208,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 // Issue TMA A
                 auto& full_barrier = *full_barriers[s];
                 int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                              smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                              smem_scales_a[s], m_block_idx * BLOCK_M,
-                                              scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
+                tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
+                         num_tma_multicast_a);
+                tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_scales_a[s], m_block_idx * BLOCK_M,
+                         scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
+                         num_tma_multicast_a);
 
                 // Issue TMA B
-                if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
-                    // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
-                    // requiring additional checks before multicast operations.
-                    DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
-                    tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                  smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                } else {
-                    tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                             smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                }
+                tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                         smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
+                         num_tma_multicast_b);
                 full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
             }
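
For reference, the indices fed to these copies follow the K pipeline: each stage `s` loads the tile at `k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K`, and the per-128-channel scales row for A is `k_idx / BLOCK_K`. A small sketch with assumed, illustrative values for the block size and pipeline depth (in the kernel these are template parameters):

```python
BLOCK_K = 128
K_NUM_STAGES = 4                                  # hypothetical pipeline depth
K_FULL_K_OF_ALL_STAGES = BLOCK_K * K_NUM_STAGES

def stage_indices(k_iter: int):
    # For each pipeline stage `s`, return (k_idx for the A/B tiles,
    # row of the per-128-channel A scales), mirroring
    # `k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K` and `k_idx / BLOCK_K`.
    return [(k_iter * K_FULL_K_OF_ALL_STAGES + s * BLOCK_K,
             (k_iter * K_FULL_K_OF_ALL_STAGES + s * BLOCK_K) // BLOCK_K)
            for s in range(K_NUM_STAGES)]

print(stage_indices(k_iter=0))   # [(0, 0), (128, 1), (256, 2), (384, 3)]
print(stage_indices(k_iter=1))   # [(512, 4), (640, 5), (768, 6), (896, 7)]
```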
@@ -281,7 +279,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
             if constexpr (kNumTMAMulticast == 1) {
                 lane_idx == 0 ? empty_barriers[s]->arrive() : void();
             } else {
-                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
+                auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
+                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
             }
         };
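
Releasing the `empty_barrier` must now respect whether the peer CTA actually participates: with a live peer, lane `i` arrives on CTA `i`'s barrier across the cluster; without one, every arriving lane targets the local CTA so the barrier still receives its expected `kNumTMAMulticast` arrivals. A host-side sketch of which (lane, target CTA) arrivals are issued, with illustrative names:

```python
def empty_barrier_arrivals(num_tma_multicast: int, peer_cta_alive: bool,
                           my_cluster_rank: int, warp_lanes: int = 32):
    # Return the (lane, target CTA rank) pairs that issue an arrive(),
    # mirroring the lane-predicated barrier release above.
    arrivals = []
    for lane_idx in range(warp_lanes):
        if num_tma_multicast == 1:
            if lane_idx == 0:                      # single CTA: lane 0 releases its own barrier
                arrivals.append((lane_idx, my_cluster_rank))
        elif lane_idx < num_tma_multicast:
            # With a live peer, lane i releases CTA i's barrier (cluster-wide);
            # otherwise all arrivals are redirected to this CTA's own barrier.
            target = lane_idx if peer_cta_alive else my_cluster_rank
            arrivals.append((lane_idx, target))
    return arrivals

print(empty_barrier_arrivals(2, peer_cta_alive=True,  my_cluster_rank=0))  # [(0, 0), (1, 1)]
print(empty_barrier_arrivals(2, peer_cta_alive=False, my_cluster_rank=1))  # [(0, 1), (1, 1)]
```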


@@ -23,9 +23,12 @@ struct Scheduler {
     // For normal GEMM
     // Maybe not used in the masked grouped GEMM
     uint32_t num_blocks;
+    uint32_t num_blocks_in_group;
+    bool is_peer_cta_alive = true;
 
     // For grouped GEMM
     int* grouped_layout;
 
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;
@@ -43,15 +46,20 @@ struct Scheduler {
         }
     }
 
-    __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
-        if constexpr (kGemmType == GemmType::Normal) {
-            return true;
-        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
-            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
-            return group_idx == peer_group_idx;
-        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+    __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
+        if (num_blocks_in_group == 1)
+            return false;
+        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
+            return true;
+        } else {
+            DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
+            if constexpr (kIsTMAMulticastOnA) {
+                return true;
+            } else {
+                auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+                auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+                return group_idx == peer_group_idx;
+            }
+        }
     }
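
A host-side mirror of this check can help reason about when multicast is allowed (GEMM types and the layout are modeled in plain Python; `BLOCK_M`, the helper name, and the sample layout are illustrative):

```python
BLOCK_M = 128  # illustrative block size

def is_tma_multicast_valid(gemm_type: str, multicast_on_a: bool,
                           num_blocks_in_group: int, m_block_idx: int,
                           grouped_layout=None) -> bool:
    # Python mirror of Scheduler::is_tma_multicast_valid above.
    if num_blocks_in_group == 1:
        return False                      # peeled single-block group: no peer to pair with
    if gemm_type in ("Normal", "GroupedMasked"):
        return True
    assert gemm_type == "GroupedContiguous"
    if multicast_on_a:
        return True                       # peer CTAs share the same M block, so the same group
    # Multicasting B pairs M blocks m and m^1; both must belong to the same group
    group_idx = grouped_layout[m_block_idx * BLOCK_M]
    peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M]
    return group_idx == peer_group_idx

# Example: rows [0, 128) belong to group 0, rows [128, 512) to group 1
layout = [0] * 128 + [1] * 384
print(is_tma_multicast_valid("GroupedContiguous", False, 2, 0, layout))  # False: blocks 0 and 1 span groups
print(is_tma_multicast_valid("GroupedContiguous", False, 2, 2, layout))  # True: blocks 2 and 3 are both group 1
```

The `num_blocks_in_group == 1` early-out is what handles the block peeled off by the swizzling fix below.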
@@ -59,23 +67,32 @@ struct Scheduler {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
 
         // Swizzle for better L2 usages
-        // TODO: unify these 2 branches
+        auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
+        auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
+        auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
+        auto group_idx = block_idx / num_blocks_per_group;
+        auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
+        auto in_group_idx = block_idx % num_blocks_per_group;
+        num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
+
+        // Fix unaligned TMA multicast
+        if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
+            if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
+                num_blocks_in_group = num_blocks_in_group ^ 1;
+            } else {
+                in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
+                first_block_idx += num_blocks_in_group ^ 1;
+                num_blocks_in_group = 1;
+            }
+        }
+
+        // Convert to final M/N block indices
         if constexpr (kIsTMAMulticastOnA) {
-            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = in_group_idx / num_n_blocks_in_group;
-            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+            m_block_idx = in_group_idx / num_blocks_in_group;
+            n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
         } else {
-            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
-            n_block_idx = in_group_idx / num_m_blocks_in_group;
+            m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
+            n_block_idx = in_group_idx / num_blocks_in_group;
         }
     }
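
This is the core of the fix: when the tail group along the primary (multicast) dimension has an odd number of blocks, the last block is peeled into its own group of size 1, so every remaining group stays even and pairs of peer CTAs keep sharing a tile. A Python mirror with a small odd-shaped example (the parameter values are illustrative; in the kernel they are template parameters):

```python
def swizzled_block_idx(block_idx: int, num_m_blocks: int, num_n_blocks: int,
                       blocks_per_group: int = 16, num_multicast: int = 2,
                       multicast_on_a: bool = True):
    # Python mirror of Scheduler::get_swizzled_block_idx above.
    # Returns (m_block_idx, n_block_idx, num_blocks_in_group).
    primary = num_n_blocks if multicast_on_a else num_m_blocks
    secondary = num_m_blocks if multicast_on_a else num_n_blocks
    per_group = secondary * blocks_per_group
    group_idx = block_idx // per_group
    first_block_idx = group_idx * blocks_per_group
    in_group_idx = block_idx % per_group
    num_blocks_in_group = min(blocks_per_group, primary - first_block_idx)

    # Fix unaligned TMA multicast: if the tail group is odd-sized, peel its
    # last primary block into a standalone group of size 1 (x ^ 1 == x - 1 for odd x)
    if num_multicast > 1 and num_blocks_in_group % 2 != 0:
        if in_group_idx < (num_blocks_in_group ^ 1) * secondary:
            num_blocks_in_group ^= 1
        else:
            in_group_idx -= (num_blocks_in_group ^ 1) * secondary
            first_block_idx += num_blocks_in_group ^ 1
            num_blocks_in_group = 1

    if multicast_on_a:
        m = in_group_idx // num_blocks_in_group
        n = first_block_idx + in_group_idx % num_blocks_in_group
    else:
        m = first_block_idx + in_group_idx % num_blocks_in_group
        n = in_group_idx // num_blocks_in_group
    return m, n, num_blocks_in_group

# 5 N blocks with a group size of 4: N blocks 0..3 pair up normally, while the
# 5th N block lands in a peeled group of size 1 and multicast is skipped there
for idx in range(2 * 5):  # num_m_blocks = 2, num_n_blocks = 5
    print(idx, swizzled_block_idx(idx, num_m_blocks=2, num_n_blocks=5, blocks_per_group=4))
```

In this example, linear block indices 8 and 9 fall into the peeled group (`num_blocks_in_group == 1`), so `is_tma_multicast_valid` turns multicast off for them while every other block still has an even-sized group and a live peer.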
@@ -102,7 +119,7 @@ struct Scheduler {
             if (curr_group_idx == kNumGroups)
                 return false;
 
-            // Within current group
+            // Within the current group
             num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
             auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
             if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
@@ -117,6 +134,10 @@ struct Scheduler {
             if (next_block_idx >= num_blocks)
                 return false;
 
+            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+            is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or          // Always aligned on N (constant bypass)
+                                num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
+                                (next_block_idx ^ 1) < num_blocks;              // Peer CTA in bound
             get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
         }
         return true;
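
The peer CTA (the block index with its lowest bit flipped) can only be missing when both block counts are odd and the scheduler is at the very last block; otherwise one of the alignment terms short-circuits the check. A quick sketch with illustrative parameter names:

```python
def peer_cta_alive(num_n_blocks: int, num_aligned_m_blocks: int,
                   num_blocks: int, next_block_idx: int,
                   num_tma_multicast: int = 2) -> bool:
    # Mirror of the `is_peer_cta_alive` update above.
    return (num_n_blocks % num_tma_multicast == 0 or          # always aligned on N
            num_aligned_m_blocks % num_tma_multicast == 0 or  # always aligned on M
            (next_block_idx ^ 1) < num_blocks)                # peer block index in bounds

# 3 x 3 blocks -> 9 blocks total: only the last block (index 8) has no peer
for idx in range(9):
    print(idx, peer_cta_alive(num_n_blocks=3, num_aligned_m_blocks=3, num_blocks=9, next_block_idx=idx))
```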


@@ -80,15 +80,14 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }
 
-template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void
 tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
-         int32_t const& crd_0, int32_t const& crd_1) {
+         int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
     constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
-    if constexpr (kNumTMAMulticast == 1) {
+    if (num_tma_multicast == 1) {
         cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
     } else if (cute::block_rank_in_cluster() == 0) {
-        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
+        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
     }
 }
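
With the multicast width now a runtime value, the CTA mask is built as `(1 << num_tma_multicast) - 1` at run time, and only the cluster leader issues the multicast load. A sketch of the dispatch with the CUTE calls replaced by labels (an illustrative helper, not part of the library):

```python
def tma_copy_dispatch(num_tma_multicast: int, block_rank_in_cluster: int) -> str:
    # Which TMA path a CTA would take, mirroring the runtime branch above.
    if num_tma_multicast == 1:
        return "SM90_TMA_LOAD_2D (every CTA loads its own tile)"
    if block_rank_in_cluster == 0:
        mask = (1 << num_tma_multicast) - 1   # e.g. 2-way multicast -> 0b11: CTAs 0 and 1 receive the tile
        return f"SM90_TMA_LOAD_MULTICAST_2D with cta_mask=0b{mask:b}"
    return "no issue (non-leader CTA just waits on the barrier)"

print(tma_copy_dispatch(1, 0))
print(tma_copy_dispatch(2, 0))
print(tma_copy_dispatch(2, 1))
```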


@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
 """
 
 
-def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
-    if num_tma_multicast == 1:
-        return True
-    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
+                           require_divisible: bool = False) -> bool:
+    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
+    return divisible and num_sms % num_tma_multicast == 0
 
 
 def get_swizzle_mode(block_n: int) -> int:
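
Compared with the old check, the new one no longer requires `shape_dim` to divide exactly by `block_dim * num_tma_multicast`; it only asks for an even block count, and only when the caller requires it (the grouped masked case). A usage sketch with assumed shapes and SM count:

```python
def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
                           require_divisible: bool = False) -> bool:
    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
    return divisible and num_sms % num_tma_multicast == 0

# n = 4096, BLOCK_N = 128 -> 32 N blocks: legal either way (SM count is illustrative)
print(is_tma_multicast_legal(4096, 128, 2, 132))                          # True
print(is_tma_multicast_legal(4096, 128, 2, 132, require_divisible=True))  # True
# n = 2112, BLOCK_N = 128 -> 17 N blocks (odd): allowed when the scheduler can
# peel the odd block, but rejected when strict divisibility is required
print(is_tma_multicast_legal(2112, 128, 2, 132))                          # True
print(is_tma_multicast_legal(2112, 128, 2, 132, require_divisible=True))  # False
```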
@@ -146,9 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     best_tma_multicast_config = (1, True)
 
     # Try to multicast on the larger block side first
+    # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
     is_multicast_legal = {
-        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
-        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
     }
     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
         if m >= 512 and is_multicast_legal[i]:
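
The diff cuts off inside the selection loop, so the following is only a hedged sketch of how the legality table plausibly feeds the final choice, assuming the loop body simply records a 2-way multicast on the first legal side (the helper name and the loop body are assumptions, not the file's actual code):

```python
def choose_tma_multicast_config(m: int, best_block_m: int, best_block_n: int,
                                is_multicast_legal: dict) -> tuple:
    # Returns (num_tma_multicast, is_multicast_on_a), assuming the truncated
    # loop body takes the first legal side as a 2-way multicast and stops.
    best_tma_multicast_config = (1, True)          # default: no multicast
    # Prefer multicasting along the side with the larger block, and only for large enough M
    for side in (('A', 'B') if best_block_m > best_block_n else ('B', 'A')):
        if m >= 512 and is_multicast_legal[side]:
            best_tma_multicast_config = (2, side == 'A')
            break
    return best_tma_multicast_config

print(choose_tma_multicast_config(4096, 128, 128, {'A': True, 'B': False}))  # (2, True)
print(choose_tma_multicast_config(256, 128, 128, {'A': True, 'B': True}))    # (1, True): M too small
```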