mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix send heads
This commit is contained in:
@@ -432,7 +432,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
|
||||
|
||||
// Forward warp synchronization
|
||||
// TODO: allocate some shared memory buffers
|
||||
__shared__ volatile int forward_channel_nvl_tail_allocator[NUM_MAX_NVL_PEERS];
|
||||
__shared__ volatile int forward_channel_nvl_tail[kNumRDMARanks][NUM_MAX_NVL_PEERS];
|
||||
__shared__ volatile bool forward_channel_retired[kNumRDMARanks];
|
||||
@@ -736,18 +735,24 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Try to send to different NVL ranks
|
||||
// TODO: shuffle the destination ranks
|
||||
for (int dst_nvl_rank = 0; dst_nvl_rank < NUM_MAX_NVL_PEERS; ++ dst_nvl_rank) {
|
||||
if (not src_meta.is_token_in_nvl_rank(dst_nvl_rank))
|
||||
const bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
|
||||
auto shifted_send_nvl_head = send_nvl_head + ((src_rdma_channel_prefix + i) * NUM_MAX_NVL_PEERS + dst_nvl_rank);
|
||||
|
||||
// Skip tokens that are not selected for this destination NVL rank
|
||||
if (not is_in_dst_nvl_rank) {
|
||||
kCachedMode ? (*send_nvl_head = -1) : 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Allocate a slot from the tail allocator, then wait until that slot is safe to overwrite
|
||||
int dst_slot_idx;
|
||||
auto shifted_send_nvl_head = send_nvl_head + (src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank);
|
||||
if (lane_id == 0) {
|
||||
dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
|
||||
kCachedMode ? (*send_nvl_head = dst_slot_idx) : 0;
|
||||
while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
|
||||
cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
|
||||
}
|
||||
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, 0);
|
||||
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx % num_max_nvl_chunked_recv_tokens, 0);
|
||||
|
||||
// Copy data
|
||||
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
|
||||
|
||||
Reference in New Issue
Block a user