Move __syncthreads and fence into barrier.

Shangyan Zhou 2025-06-26 10:03:37 +08:00
parent f1d7a7c89f
commit 3071a2cd37
3 changed files with 9 additions and 19 deletions
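In short: barrier_block gains a bool fence template parameter (default true) and issues memory_fence() itself, so the explicit memory_fence(); __syncthreads(); pair before each barrier call disappears; only a kernel's very first barrier, which presumably has no prior writes of its own to publish, opts out with false. A rough before/after sketch of the call-site pattern, for illustration only and not part of the diff (names as in the kernels below):

    // Before this commit: each call site ordered its own writes, synchronized
    // the block, and then entered the inter-rank barrier.
    memory_fence();
    __syncthreads();
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

    // After: the default instantiation fences inside barrier_block ...
    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);

    // ... and a kernel's opening barrier skips the then-pointless fence.
    barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);

The standalone __syncthreads() calls are dropped on the assumption that barrier_block's own block-wide synchronization already provides them.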


@@ -99,7 +99,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
 EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
 if (thread_id == 32)
 nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
-barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
+barrier_block<NUM_MAX_NVL_PEERS, false>(barrier_signal_ptrs, nvl_rank);
 // Send numbers of tokens per rank/expert to RDMA ranks
 auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
@@ -199,8 +199,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
 for (int i = 0; i < num_nvl_experts; ++ i)
 nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i];
 }
-memory_fence();
-__syncthreads();
 barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
 // Reduce the number of tokens per rank/expert
@@ -227,7 +225,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
 }
 // Finally barrier
-__syncthreads();
 if (thread_id == 32)
 nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
 barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
@@ -1039,15 +1036,13 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
 nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
 } else if (sm_id == 1) {
 // Barrier for NVL
-barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
+barrier_block<NUM_MAX_NVL_PEERS, false>(barrier_signal_ptrs, nvl_rank);
 // Clean
 auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
 #pragma unroll
 for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
 nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
-memory_fence();
-__syncthreads();
 // Barrier again
 barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);


@@ -21,7 +21,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 if (sm_id == 0) {
 // Barrier first
-barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 int *per_rank_buffer, *per_expert_buffer;
 if (thread_id < kNumRanks) {
@@ -41,8 +41,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 for (int i = 0; i < num_experts_per_rank; ++ i)
 per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
 }
-memory_fence();
-__syncthreads();
 // Wait for all ranks to be finished
 barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
@@ -81,8 +79,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
 local_per_expert_buffer[i] = 0;
 // Barrier
-memory_fence();
-__syncthreads();
 barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
 } else {
 int dst_rank = sm_id - 1;
@@ -138,7 +134,7 @@ __global__ void
 cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
 // A simplified version for cached handles
-barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 // Copy and clean
 auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -149,8 +145,6 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 #pragma unroll
 for (int i = thread_id; i < num_memset_int; i += num_threads)
 ptr[kNumRanks * kNumRanks + i] = 0;
-memory_fence();
-__syncthreads();
 // Barrier after cleaning
 barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
@@ -521,7 +515,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
 const auto sm_id = static_cast<int>(blockIdx.x);
 if (sm_id == 0) {
 // Barrier before cleaning
-barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
+barrier_block<kNumRanks, false>(barrier_signal_ptrs, rank);
 // Clean
 auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -529,8 +523,6 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
 #pragma unroll
 for (int i = thread_id; i < num_memset_int; i += num_threads)
 ptr[i] = 0;
-memory_fence();
-__syncthreads();
 // Barrier after cleaning
 barrier_block<kNumRanks>(barrier_signal_ptrs, rank);


@@ -438,11 +438,14 @@ __forceinline__ __device__ out_dtype_t extract_required_scale_format(float value
 }
 }
-template <int kNumRanks>
+template <int kNumRanks, bool fence=true>
 __forceinline__ __device__ void
 barrier_block(int** barrier_signal_ptrs, int rank) {
 auto thread_id = static_cast<int>(threadIdx.x);
+if (fence)
+memory_fence();
 // Add self-ranks, sub other ranks
 if (thread_id < kNumRanks) {
 atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
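As background on why the fence sits right before the arrival signal: writes to peer-visible buffers made before the barrier must become visible to other ranks before those ranks observe the barrier signal, otherwise a rank released from the barrier could read stale data. A minimal, self-contained toy illustrating that ordering requirement (not DeepEP's helper; it assumes memory_fence() plays the role of a system-scope fence such as __threadfence_system()):

    // Toy sketch: publish data for a peer, then signal arrival at the barrier.
    __device__ void publish_then_signal(int* peer_visible_buf, int* barrier_signal, int value) {
        *peer_visible_buf = value;            // data a peer will read after the barrier
        __threadfence_system();               // order the write before the signal (the role memory_fence() is assumed to play)
        atomicAdd_system(barrier_signal, 1);  // arrival signal the peer waits on
    }

With this commit, that fence is issued inside barrier_block when fence == true, and call sites pass false only when there is nothing to publish, e.g. a kernel's opening barrier.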