mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Remove memory fence in NVLink barrier.
This commit is contained in:
parent
a15faa9ff0
commit
f1d7a7c89f
@@ -41,6 +41,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
for (int i = 0; i < num_experts_per_rank; ++ i)
|
||||
per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
|
||||
}
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Wait for all ranks to be finished
|
||||
|
@@ -446,7 +446,6 @@ barrier_block(int** barrier_signal_ptrs, int rank) {
|
||||
// Add self-ranks, sub other ranks
|
||||
if (thread_id < kNumRanks) {
|
||||
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
|
||||
memory_fence();
|
||||
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
|
||||
}
|
||||
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
|
||||
|
Loading…
Reference in New Issue
Block a user