From 59884211ea81d06d18b1915a11cff669bb9e2dfe Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Tue, 22 Apr 2025 17:37:34 +0800
Subject: [PATCH] Simplify

---
 deep_gemm/include/deep_gemm/fp8_gemm.cuh  |  7 +++----
 deep_gemm/include/deep_gemm/scheduler.cuh | 10 ----------
 2 files changed, 3 insertions(+), 14 deletions(-)

diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
index 038ff3e..aa39c16 100644
--- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh
+++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh
@@ -276,11 +276,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,

         // Empty barrier arrival
         auto empty_barrier_arrive = [&](int s) {
-            if constexpr (kNumTMAMulticast == 1) {
-                lane_idx == 0 ? empty_barriers[s]->arrive() : void();
+            if (kNumTMAMulticast == 1 or not scheduler.is_tma_multicast_valid(m_block_idx)) {
+                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive() : void();
             } else {
-                auto target_cta_idx = scheduler.is_block_in_complete_cluster(m_block_idx, n_block_idx) ? lane_idx : cute::block_rank_in_cluster();
-                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta_idx) : void();
+                lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void();
             }
         };

diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh
index 6a0b9dc..d41e72c 100644
--- a/deep_gemm/include/deep_gemm/scheduler.cuh
+++ b/deep_gemm/include/deep_gemm/scheduler.cuh
@@ -45,16 +45,6 @@ struct Scheduler {
         }
     }

-    __device__ __forceinline__ bool is_block_in_complete_cluster(const uint32_t& m_block_idx, const uint32_t& n_block_idx) {
-        // NOTES: For the case where the total number is an odd number of blocks, the last block requires special barrier processing.
-        // Here, we need each cluster to have exactly two blocks.
-        DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
-        if (kNumTMAMulticast == 2 and num_blocks_in_group == 1 and n_block_idx == kNumNBlocks - 1 and m_block_idx == num_aligned_m_blocks - 1
-            and num_aligned_m_blocks % 2 == 1 and kNumNBlocks % 2 == 1)
-            return false;
-        return true;
-    }
-
     __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) {
         if (num_blocks_in_group == 1)
             return false;
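
For reference, a minimal host-side C++ sketch of the arrival-target selection that the simplified lambda performs. The function empty_barrier_arrive_sketch and the printf stand-in for the mbarrier arrive are illustrative only and not part of the patch; kNumTMAMulticast, lane_idx, and the multicast-validity flag mirror the kernel's names.

#include <cstdio>

// Hypothetical stand-in for the patched lambda in fp8_gemm.cuh: with multicast
// disabled (kNumTMAMulticast == 1) or invalid for the current m-block, lanes
// [0, kNumTMAMulticast) arrive on the local barrier; otherwise each such lane
// arrives on the barrier of the cluster CTA whose rank equals its lane index.
// The printf replaces the actual mbarrier arrive.
static void empty_barrier_arrive_sketch(int kNumTMAMulticast, bool multicast_valid, int lane_idx) {
    if (lane_idx >= kNumTMAMulticast)
        return;  // Remaining lanes do not arrive at all
    if (kNumTMAMulticast == 1 or not multicast_valid)
        std::printf("lane %d: arrive on local barrier\n", lane_idx);
    else
        std::printf("lane %d: arrive on barrier of cluster CTA %d\n", lane_idx, lane_idx);
}

int main() {
    // 2-CTA multicast, valid for this block: lanes 0 and 1 target CTAs 0 and 1
    for (int lane = 0; lane < 4; ++ lane)
        empty_barrier_arrive_sketch(2, true, lane);
    // Multicast invalid for this block: lanes 0 and 1 arrive locally only
    for (int lane = 0; lane < 4; ++ lane)
        empty_barrier_arrive_sketch(2, false, lane);
    return 0;
}

As the diff suggests, the per-block fallback previously handled by is_block_in_complete_cluster (arriving on cute::block_rank_in_cluster() for the last incomplete cluster) is now folded into the single is_tma_multicast_valid(m_block_idx) check, so the else branch can target lane_idx unconditionally.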