From d2369e8f30e48d95efb21087cdc9416297d0a5cd Mon Sep 17 00:00:00 2001 From: yukuai Date: Tue, 22 Apr 2025 15:19:29 +0800 Subject: [PATCH] fix typo2 --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 2 +- deep_gemm/include/deep_gemm/scheduler.cuh | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 5fb5b66..c9b3300 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -292,7 +292,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout, if constexpr (kNumTMAMulticast == 1) { lane_idx == 0 ? empty_barriers[s]->arrive() : void(); } else { - auto target_cta_idx = scheduler.num_blocks_in_group != 1 ? lane_idx : cute::block_rank_in_cluster(); + auto target_cta_idx = scheduler.is_block_in_complete_cluster(m_block_idx, n_block_idx) ? lane_idx : cute::block_rank_in_cluster(); lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta_idx) : void(); } }; diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 6a595b8..477c6eb 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -44,6 +44,13 @@ struct Scheduler { } } + __device__ __forceinline__ bool is_block_in_complete_cluster(const uint32_t& m_block_idx, uint32_t& n_block_idx) { + if (num_blocks_in_group == 1 and n_block_idx == kNumNBlocks - 1 and m_block_idx == num_aligned_m_blocks - 1 + and num_aligned_m_blocks % 2 == 1 and kNumNBlocks % 2 == 1) + return false; + return true; + } + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) { if (num_blocks_in_group == 1) return false;