From 95e81b3dd6704e279e5f4757c5b94776ac988a8d Mon Sep 17 00:00:00 2001
From: yukuai26 <93142162+yukuai26@users.noreply.github.com>
Date: Wed, 23 Apr 2025 14:55:14 +0800
Subject: [PATCH] Indivisible TMA (#90)

Fix indivisible shapes for TMA multicast

---------

Co-authored-by: yukuai
Co-authored-by: Chenggang Zhao
---
 README.md                                 |  2 +-
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  | 35 ++++++------
 deep_gemm/include/deep_gemm/scheduler.cuh | 69 +++++++++++++++--------
 deep_gemm/include/deep_gemm/tma_utils.cuh |  7 +--
 deep_gemm/jit_kernels/gemm.py             | 13 +++--
 5 files changed, 73 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index ccc46e0..5d925da 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
 - [x] Shared memory swizzling for output
 - [ ] Larger block size on N (up to 256)
 - [x] MoE scheduler with TMA multicast compatibility
-- [ ] Fix TMA multicast compatibility for indivisible shapes
+- [x] Fix TMA multicast compatibility for indivisible shapes
 - [ ] Weight gradient kernels for dense models
 - [ ] Weight gradient kernels for MoE models
 - [ ] Utility kernels for MoE models (as a pre-built CUDA library)
diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index c2934b8..e8370af 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -192,8 +192,11 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                     DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");

                     // Assign TMA multicast number into A and B
-                    constexpr int kNumTMAMulticastOnA = kIsTMAMulticastOnA ? kNumTMAMulticast : 1;
-                    constexpr int kNumTMAMulticastOnB = kIsTMAMulticastOnA ? 1 : kNumTMAMulticast;
+                    // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
+                    const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
+                    const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+                    const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
+                    DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

                     // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
                     // shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
@@ -205,23 +208,18 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                         // Issue TMA A
                         auto& full_barrier = *full_barriers[s];
                         int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
-                        tma_copy<kNumTMAMulticastOnA>(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                      smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
-                        tma_copy<kNumTMAMulticastOnA>(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                      smem_scales_a[s], m_block_idx * BLOCK_M,
-                                                      scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
+                        tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
+                                 num_tma_multicast_a);
+                        tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_scales_a[s], m_block_idx * BLOCK_M,
+                                 scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
+                                 num_tma_multicast_a);

                         // Issue TMA B
-                        if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
-                            // NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
-                            // requiring additional checks before multicast operations.
-                            DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
-                            tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                                          smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                        } else {
-                            tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
-                                     smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
-                        }
+                        tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
+                                 smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
+                                 num_tma_multicast_b);

                         full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
                     }
@@ -281,7 +279,8 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
                 if constexpr (kNumTMAMulticast == 1) {
                     lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                 } else {
-                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
+                    auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
+                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
                 }
             };

diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 95dcd33..9743871 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -23,9 +23,12 @@ struct Scheduler {
     // For normal GEMM
     // Maybe not used in the masked grouped GEMM
     uint32_t num_blocks;
+    uint32_t num_blocks_in_group;
+    bool is_peer_cta_alive = true;

     // For grouped GEMM
     int* grouped_layout;
+
     // Only used for masked layout
     uint32_t curr_group_idx, curr_cumsum;

@@ -43,15 +46,20 @@ struct Scheduler {
         }
     }

-    __device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
-        if constexpr (kGemmType == GemmType::Normal) {
-            return true;
-        } else if constexpr (kGemmType == GemmType::GroupedContiguous) {
-            auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
-            auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
-            return group_idx == peer_group_idx;
-        } else if constexpr (kGemmType == GemmType::GroupedMasked) {
+    __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
+        if (num_blocks_in_group == 1) return false;
+        if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
+            return true;
+        } else {
+            DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
+            if constexpr (kIsTMAMulticastOnA) {
+                return true;
+            } else {
+                auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
+                auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
+                return group_idx == peer_group_idx;
+            }
         }
     }

@@ -59,23 +67,32 @@ struct Scheduler {
         DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

         // Swizzle for better L2 usages
-        // TODO: unify these 2 branches
+        auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
+        auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
+        auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
+        auto group_idx = block_idx / num_blocks_per_group;
+        auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
+        auto in_group_idx = block_idx % num_blocks_per_group;
+        num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
+
+        // Fix unaligned TMA multicast
+        if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
+            if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
+                num_blocks_in_group = num_blocks_in_group ^ 1;
+            } else {
+                in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
+                first_block_idx += num_blocks_in_group ^ 1;
+                num_blocks_in_group = 1;
+            }
+        }
+
+        // Convert to final M/N block indices
         if constexpr (kIsTMAMulticastOnA) {
-            auto num_blocks_per_group = num_m_blocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_n_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_n_blocks_in_group = min(kNum1DBlocksPerGroup, kNumNBlocks - first_n_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = in_group_idx / num_n_blocks_in_group;
-            n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
+            m_block_idx = in_group_idx / num_blocks_in_group;
+            n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
         } else {
-            auto num_blocks_per_group = kNumNBlocks * kNum1DBlocksPerGroup;
-            auto group_idx = block_idx / num_blocks_per_group;
-            auto first_m_block_idx = group_idx * kNum1DBlocksPerGroup;
-            auto num_m_blocks_in_group = min(kNum1DBlocksPerGroup, num_m_blocks - first_m_block_idx);
-            auto in_group_idx = block_idx % num_blocks_per_group;
-            m_block_idx = first_m_block_idx + in_group_idx % num_m_blocks_in_group;
-            n_block_idx = in_group_idx / num_m_blocks_in_group;
+            m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
+            n_block_idx = in_group_idx / num_blocks_in_group;
         }
     }

@@ -102,7 +119,7 @@ struct Scheduler {
                 if (curr_group_idx == kNumGroups)
                     return false;

-                // Within current group
+                // Within the current group
                 num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
                 auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
                 if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
@@ -117,6 +134,10 @@ struct Scheduler {
             if (next_block_idx >= num_blocks)
                 return false;

+            // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
+            is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or          // Always aligned on N (constant bypass)
+                                num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
+                                (next_block_idx ^ 1) < num_blocks;              // Peer CTA in bound
             get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
         }
         return true;
diff --git a/deep_gemm/include/deep_gemm/tma_utils.cuh b/deep_gemm/include/deep_gemm/tma_utils.cuh
index 22731a6..18cdb58 100644
--- a/deep_gemm/include/deep_gemm/tma_utils.cuh
+++ b/deep_gemm/include/deep_gemm/tma_utils.cuh
@@ -80,15 +80,14 @@ CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2],
     return tensor_map;
 }

-template <uint32_t kNumTMAMulticast = 1>
 __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
-                                         int32_t const& crd_0, int32_t const& crd_1) {
+                                         int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
     constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
-    if constexpr (kNumTMAMulticast == 1) {
+    if (num_tma_multicast == 1) {
         cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
     } else if (cute::block_rank_in_cluster() == 0) {
-        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
+        cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1, cache_hint, smem_ptr, crd_0, crd_1);
     }
 }

 }  // namespace deep_gemm
diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index f52dc48..cb438b7 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -38,10 +38,10 @@ gemm_t::run(out, rhs_scales, nullptr,
 """


-def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int) -> bool:
-    if num_tma_multicast == 1:
-        return True
-    return (shape_dim % (block_dim * num_tma_multicast) == 0) and num_sms % num_tma_multicast == 0
+def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
+                           require_divisible: bool = False) -> bool:
+    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
+    return divisible and num_sms % num_tma_multicast == 0


 def get_swizzle_mode(block_n: int) -> int:
@@ -146,9 +146,10 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     best_tma_multicast_config = (1, True)

     # Try to multicast on the larger block side first
+    # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
     is_multicast_legal = {
-        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms),
-        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and (not is_grouped_masked),
+        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
+        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
     }
     for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
         if m >= 512 and is_multicast_legal[i]:
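
Notes on the scheduler change: the core of the fix is the unaligned-group handling in
`get_swizzled_block_idx`. When 2-CTA multicast is enabled and a group ends up with an odd number
of blocks along the primary dimension, the group is split into an even part plus a single trailing
line whose `num_blocks_in_group` becomes 1, so `is_tma_multicast_valid` turns multicast off for
exactly those blocks while every output block is still visited once. The standalone Python sketch
below is not part of the patch; the `swizzle` helper name, the toy 3 x 7 block grid, the group size
of 4 and the assertions are illustrative assumptions that mirror the new scheduler arithmetic so it
can be sanity-checked on the host.

def swizzle(block_idx: int, num_m_blocks: int, num_n_blocks: int,
            multicast_on_a: bool, num_tma_multicast: int, blocks_per_group: int = 16):
    # Primary dimension = the one swizzled into groups of `blocks_per_group`
    # (N when multicasting on A, M when multicasting on B); secondary = the other one
    primary = num_n_blocks if multicast_on_a else num_m_blocks
    secondary = num_m_blocks if multicast_on_a else num_n_blocks
    num_blocks_per_group = secondary * blocks_per_group
    group_idx = block_idx // num_blocks_per_group
    first_block_idx = group_idx * blocks_per_group
    in_group_idx = block_idx % num_blocks_per_group
    num_blocks_in_group = min(blocks_per_group, primary - first_block_idx)

    # Fix unaligned TMA multicast: split an odd-sized group into an even part plus
    # a single trailing line of blocks that must not multicast
    if num_tma_multicast > 1 and num_blocks_in_group % 2 != 0:
        if in_group_idx < (num_blocks_in_group ^ 1) * secondary:
            num_blocks_in_group ^= 1
        else:
            in_group_idx -= (num_blocks_in_group ^ 1) * secondary
            first_block_idx += num_blocks_in_group ^ 1
            num_blocks_in_group = 1

    # Convert to final M/N block indices; multicast is only valid in groups with > 1 block
    can_multicast = num_blocks_in_group > 1
    primary_idx = first_block_idx + in_group_idx % num_blocks_in_group
    secondary_idx = in_group_idx // num_blocks_in_group
    if multicast_on_a:
        return secondary_idx, primary_idx, can_multicast
    return primary_idx, secondary_idx, can_multicast


if __name__ == '__main__':
    # Toy case: 3 x 7 blocks, 2-CTA multicast on A, group size 4 (so the last group is odd-sized)
    num_m_blocks, num_n_blocks = 3, 7
    blocks = [swizzle(i, num_m_blocks, num_n_blocks, True, 2, blocks_per_group=4)
              for i in range(num_m_blocks * num_n_blocks)]

    # Every (m, n) block is still produced exactly once
    assert sorted((m, n) for m, n, _ in blocks) == \
           [(m, n) for m in range(num_m_blocks) for n in range(num_n_blocks)]

    # Peer CTAs (2i, 2i + 1) either both keep multicast and share the same m block
    # (so they load the same A tile), or both have it disabled (the split-off trailing line)
    for i in range(0, len(blocks) - 1, 2):
        (m0, _, valid0), (m1, _, valid1) = blocks[i], blocks[i + 1]
        assert valid0 == valid1 and (not valid0 or m0 == m1)

For multicast on B the same invariants hold with the roles of m and n swapped; grouped contiguous
GEMM additionally checks that the two peer blocks belong to the same group before multicasting B.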
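
Notes on the host-side check: `is_tma_multicast_legal` now accepts indivisible shapes by default
and only enforces block-count divisibility when `require_divisible` is set, which `get_best_configs`
does for grouped masked GEMM (multicast on A). A minimal sketch of the new behaviour follows; the
local `ceil_div` helper and the 132-SM figure are assumptions for illustration, not values taken
from this patch.

def ceil_div(x: int, y: int) -> int:
    # Rounding-up division, redefined here so that the sketch is standalone
    return (x + y - 1) // y


def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
                           require_divisible: bool = False) -> bool:
    # Same body as the patched `is_tma_multicast_legal` above
    divisible = ceil_div(shape_dim, block_dim) % num_tma_multicast == 0 or not require_divisible
    return divisible and num_sms % num_tma_multicast == 0


# N = 4097 with BLOCK_N = 128 gives 33 blocks, indivisible by 2: the old check rejected 2-CTA
# multicast here, the new one accepts it and lets the scheduler split the odd group at runtime
assert is_tma_multicast_legal(4097, 128, 2, 132)
# Grouped masked GEMM still insists on divisibility
assert not is_tma_multicast_legal(4097, 128, 2, 132, require_divisible=True)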