From 3071a2cd37a0f4c993ba420d45f57276c83802b4 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou
Date: Thu, 26 Jun 2025 10:03:37 +0800
Subject: [PATCH] Move `__syncthreads` and fence into barrier.

---
 csrc/kernels/internode.cu |  9 ++-------
 csrc/kernels/intranode.cu | 14 +++-----------
 csrc/kernels/utils.cuh    |  5 ++++-
 3 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index da1d203..7e22bd9 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -99,7 +99,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
         EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
         if (thread_id == 32)
             nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
-        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
+        barrier_block<NUM_MAX_NVL_PEERS, false>(barrier_signal_ptrs, nvl_rank);
 
         // Send numbers of tokens per rank/expert to RDMA ranks
         auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
@@ -199,8 +199,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
             for (int i = 0; i < num_nvl_experts; ++ i)
                 nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
         }
-        memory_fence();
-        __syncthreads();
         barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
 
         // Reduce the number of tokens per rank/expert
@@ -227,7 +225,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
         }
 
         // Finally barrier
-        __syncthreads();
         if (thread_id == 32)
             nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
         barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
@@ -1039,15 +1036,13 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
         nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
     } else if (sm_id == 1) {
         // Barrier for NVL
-        barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
+        barrier_block<NUM_MAX_NVL_PEERS, false>(barrier_signal_ptrs, nvl_rank);
 
         // Clean
         auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
         #pragma unroll
         for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
             nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
-        memory_fence();
-        __syncthreads();
 
         // Barrier again
         barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu
index 23e02ba..58b003d 100644
--- a/csrc/kernels/intranode.cu
+++ b/csrc/kernels/intranode.cu
@@ -21,7 +21,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 
     if (sm_id == 0) {
         // Barrier first
-        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+        barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 
         int *per_rank_buffer, *per_expert_buffer;
         if (thread_id < kNumRanks) {
@@ -41,8 +41,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
             for (int i = 0; i < num_experts_per_rank; ++ i)
                 per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
         }
-        memory_fence();
-        __syncthreads();
 
         // Wait for all ranks to be finished
         barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
@@ -81,8 +79,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
             local_per_expert_buffer[i] = 0;
 
         // Barrier
-        memory_fence();
-        __syncthreads();
         barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
     } else {
         int dst_rank = sm_id - 1;
@@ -138,7 +134,7 @@ __global__ void
 cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
                        void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
     // A simplified version for cached handles
-    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+    barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 
     // Copy and clean
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -149,8 +145,6 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
     #pragma unroll
     for (int i = thread_id; i < num_memset_int; i += num_threads)
         ptr[kNumRanks * kNumRanks + i] = 0;
-    memory_fence();
-    __syncthreads();
 
     // Barrier after cleaning
     barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
@@ -521,7 +515,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
     const auto sm_id = static_cast<int>(blockIdx.x);
     if (sm_id == 0) {
         // Barrier before cleaning
-        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+        barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 
         // Clean
         auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -529,8 +523,6 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
         #pragma unroll
         for (int i = thread_id; i < num_memset_int; i += num_threads)
             ptr[i] = 0;
-        memory_fence();
-        __syncthreads();
 
         // Barrier after cleaning
         barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh
index 9b24f04..de8853f 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -438,11 +438,14 @@ __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value
     }
 }
 
-template <int kNumRanks>
+template <int kNumRanks, bool fence = true>
 __forceinline__ __device__ void
 barrier_block(int** barrier_signal_ptrs, int rank) {
     auto thread_id = static_cast<int>(threadIdx.x);
 
+    if (fence)
+        memory_fence();
+
     // Add self-ranks, sub other ranks
     if (thread_id < kNumRanks) {
         atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
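
A note on the resulting barrier shape. After this patch, every `barrier_block` call owns its own pre-signal fence: the repeated `memory_fence(); __syncthreads();` pairs at the call sites above disappear, and the first barrier of each kernel (which has no prior writes to publish) opts out through the second template argument. The sketch below is a self-contained illustration of that pattern, not the upstream code: the `fence` parameter name is reconstructed from the template change in utils.cuh, and `FINISHED_SUM_TAG`, `memory_fence`, and the simplified spin loop are stand-ins for DeepEP helpers this patch does not show in full (the real wait loop also carries a timeout check and a trap).

#include <cuda_runtime.h>

// Illustrative stand-ins for DeepEP's utils.cuh internals.
#define FINISHED_SUM_TAG 1024

__forceinline__ __device__ void memory_fence() {
    // System-scope fence (sm_70+): publish this thread's prior writes to
    // peer GPUs before the barrier signal is raised.
    asm volatile("fence.acq_rel.sys;" ::: "memory");
}

template <int kNumRanks, bool fence = true>
__forceinline__ __device__ void
barrier_block(int** barrier_signal_ptrs, int rank) {
    auto thread_id = static_cast<int>(threadIdx.x);

    // Previously done by every caller; the first barrier of a kernel
    // skips it by instantiating with fence = false.
    if (fence)
        memory_fence();

    // Add self-ranks, sub other ranks: slot i of our own signal array rises
    // by the tag when we arrive and returns to zero once peer i arrives.
    if (thread_id < kNumRanks) {
        atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
        atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
    }

    // Simplified wait: spin until every peer has checked in.
    while (thread_id < kNumRanks &&
           reinterpret_cast<volatile int*>(barrier_signal_ptrs[rank])[thread_id] > 0)
        ; // spin
    __syncthreads();
}

Call sites then collapse to one line each, mirroring the hunks above:

// Barrier first: nothing written yet, so skip the fence.
barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
// ... write shared buffers ...
// Later barriers: the default template argument restores the fence.
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);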