diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 88ca61d..688434e 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -432,7 +432,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; // Forward warp synchronization - // TODO: allocate some shared memory buffers __shared__ volatile int forward_channel_nvl_tail_allocator[NUM_MAX_NVL_PEERS]; __shared__ volatile int forward_channel_nvl_tail[kNumRDMARanks][NUM_MAX_NVL_PEERS]; __shared__ volatile bool forward_channel_retired[kNumRDMARanks]; @@ -736,18 +735,24 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Try to send to different NVL ranks // TODO: shuffle the destination ranks for (int dst_nvl_rank = 0; dst_nvl_rank < NUM_MAX_NVL_PEERS; ++ dst_nvl_rank) { - if (not src_meta.is_token_in_nvl_rank(dst_nvl_rank)) + const bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + auto shifted_send_nvl_head = send_nvl_head + ((src_rdma_channel_prefix + i) * NUM_MAX_NVL_PEERS + dst_nvl_rank); + + // Skip not selected tokens + if (not is_in_dst_nvl_rank) { + kCachedMode ? (*send_nvl_head = -1) : 0; continue; + } // Allocate a slot tail and wait until that we can ensure the slot is safe to overwrite int dst_slot_idx; - auto shifted_send_nvl_head = send_nvl_head + (src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank); if (lane_id == 0) { dst_slot_idx = atomicAdd_block(const_cast(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1); + kCachedMode ? (*send_nvl_head = dst_slot_idx) : 0; while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens) cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); } - dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, 0); + dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx % num_max_nvl_chunked_recv_tokens, 0); // Copy data UNROLLED_WARP_COPY(5, lane_id, hidden_int4,