From 55bbd8caafd7533290fb3979c137c5e530345613 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Thu, 19 Jun 2025 17:15:43 +0800
Subject: [PATCH] Add impl

---
 csrc/kernels/internode.cu | 269 +++++++++++++++++++++++++++++++++++---
 1 file changed, 252 insertions(+), 17 deletions(-)

diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 1ed78c3..88ca61d 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -399,6 +399,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
     auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
 
+    // NVL buffer layouts
+    // NOTES: `rs_wr_buffer_ptr` means "read by senders, written by receivers", `ws_rr_buffer_ptr` means "written by senders, read by receivers"; the coordinator shares the forwarder-side layout, since it publishes `nvl_channel_tail` on the forwarders' behalf
+    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
+    int rs_wr_rank = 0, ws_rr_rank = 0;
+    if (warp_role == WarpRole::kRDMAAndNVLForwarder or warp_role == WarpRole::kForwarderCoordinator) {
+        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank];
+        ws_rr_buffer_ptr = buffer_ptrs[lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0];
+        rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
+    }
+    if (warp_role == WarpRole::kNVLReceivers) {
+        rs_wr_buffer_ptr = buffer_ptrs[target_rank];
+        ws_rr_buffer_ptr = buffer_ptrs[nvl_rank];
+        rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
+    }
+
+    // Allocate buffers
+    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
+    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
+
     // RDMA sender warp synchronization
     __shared__ volatile int rdma_send_next_token_idx;
     __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
@@ -406,9 +432,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
 
     // Forward warp synchronization
-    __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks];
-    __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS];
-    auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); };
+    // Per-destination slot allocators, per-(source RDMA, destination NVL) published tails, and retirement flags
+    __shared__ volatile int forward_channel_nvl_tail_allocator[NUM_MAX_NVL_PEERS];
+    __shared__ volatile int forward_channel_nvl_tail[kNumRDMARanks][NUM_MAX_NVL_PEERS];
+    __shared__ volatile bool forward_channel_retired[kNumRDMARanks];
+    auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMARanks + 1) * 32)); };
 
     if (warp_role == WarpRole::kRDMASender) {
         // Get tasks
@@ -623,20 +651,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         // Each warp is responsible for a source RDMA rank
         // TODO: responsible for multiple RDMA ranks
         const auto src_rdma_rank = target_rank;
-
-        // NVL buffers
-        auto buffer_ptr = lane_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[lane_id] : nullptr;
-        auto nvl_channel_x = AsymBuffer<int4>(buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_x_scales = AsymBuffer<float>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_topk_idx = AsymBuffer<int>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_topk_weights = AsymBuffer<float>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_prefix_start = AsymBuffer<int>(buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        auto nvl_channel_prefix_end = AsymBuffer<int>(buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-
-        // TODO: use NVL head/tail for coordinators
-        // auto nvl_channel_head = AsymBuffer<int>(buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
-        // auto nvl_channel_tail = AsymBuffer<int>(buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        const auto num_experts_per_rank = num_experts / num_ranks;
 
         // Wait counters to arrive
         auto start_time = clock64();
@@ -676,6 +691,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             src_rdma_channel_prefix += src_rdma_rank == 0 ? 0 : recv_rdma_rank_prefix_sum[src_rdma_rank - 1];
             EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
         }
+        num_tokens_to_recv_from_rdma = __shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, 0);
         break;
     }
@@ -687,12 +703,243 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             }
         }
         __syncwarp();
+
+        // Wait for the shared memory to be cleaned
+        sync_forwarder_smem();
+
+        // Forward tokens from the RDMA buffer to the NVL peers
+        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
+        int cached_nvl_channel_head = 0;
+        while (cached_rdma_channel_tail < num_tokens_to_recv_from_rdma) {
+            // Wait for data to arrive
+            start_time = clock64();
+            while (lane_id == 0 and cached_rdma_channel_head == cached_rdma_channel_tail) {
+                cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
+
+                // Timeout check
+                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, NVL: %d, src RDMA rank: %d, head: %d, tail: %d, expected: %d\n",
+                           channel_id, rdma_rank, nvl_rank, src_rdma_rank, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
+                    trap();
+                }
+            }
+            cached_rdma_channel_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, 0);
+
+            // Iterate over every token from the RDMA buffer
+            for (int i = cached_rdma_channel_head; i < cached_rdma_channel_tail; ++ i) {
+                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
+                auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
+                auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(shifted + hidden_bytes));
+
+                // TODO: load into shared memory (only read once)
+
+                // Try to send to different NVL ranks
+                // TODO: shuffle the destination ranks
+                for (int dst_nvl_rank = 0; dst_nvl_rank < NUM_MAX_NVL_PEERS; ++ dst_nvl_rank) {
+                    if (not src_meta.is_token_in_nvl_rank(dst_nvl_rank))
+                        continue;
+
+                    // Allocate a queue slot and wait until it is safe to overwrite
+                    int dst_slot_idx;
+                    auto shifted_send_nvl_head = send_nvl_head + (src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank);
+                    if (lane_id == 0) {
+                        dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
+                        while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
+                            cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
+                    }
+                    dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, 0);
+                    auto dst_slot_in_buffer = dst_slot_idx % num_max_nvl_chunked_recv_tokens;
+
+                    // Copy data (via a per-destination pointer, so `shifted` stays at the token base)
+                    auto src_ptr = shifted;
+                    UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
+                                       nvl_channel_x.buffer() + dst_slot_in_buffer * hidden_int4,
+                                       reinterpret_cast<int4*>(src_ptr),
+                                       ld_nc_global, st_na_global);
+                    src_ptr += hidden_bytes;
+
+                    // Copy source meta
+                    if (lane_id == 0)
+                        st_na_global(nvl_channel_src_meta.buffer() + dst_slot_in_buffer, src_meta);
+                    src_ptr += sizeof(SourceMeta);
+
+                    // Copy `x_scales`
+                    UNROLLED_WARP_COPY(1, lane_id, num_scales,
+                                       nvl_channel_x_scales.buffer() + dst_slot_in_buffer * num_scales,
+                                       reinterpret_cast<float*>(src_ptr),
+                                       ld_nc_global, st_na_global);
+                    src_ptr += num_scales * sizeof(float);
+
+                    // Copy `topk_idx` and `topk_weights`
+                    // NOTES: `src_ptr` is not advanced here, so it stays uniform across lanes
+                    if (lane_id < num_topk) {
+                        // Read
+                        auto idx_value = ld_nc_global(reinterpret_cast<int*>(src_ptr) + lane_id);
+                        auto weight_value = ld_nc_global(reinterpret_cast<float*>(src_ptr + num_topk * sizeof(int)) + lane_id);
+
+                        // Transform and write
+                        const auto dst_rank_expert_begin = (rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_experts_per_rank;
+                        const auto dst_rank_expert_end = dst_rank_expert_begin + num_experts_per_rank;
+                        idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
+                        st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_in_buffer * num_topk + lane_id, idx_value);
+                        weight_value = idx_value >= 0 ? weight_value : 0.0f;
+                        st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_in_buffer * num_topk + lane_id, weight_value);
+                    }
+
+                    // Move the tail index
+                    __syncwarp();
+                    if (lane_id == 0)
+                        st_release_cta(const_cast<const int*>(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank), dst_slot_idx + 1);
+                }
+            }
+
+            // Update the remote head
+            // TODO: this part should be moved into the coordinator warp to overlap with TMA
+            if (lane_id == 0)
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank),
+                                                cached_rdma_channel_tail - cached_rdma_channel_head,
+                                                translate_dst_rdma_rank<kLowLatencyMode>(src_rdma_rank, nvl_rank),
+                                                channel_id + num_channels, src_rdma_rank == rdma_rank);
+            cached_rdma_channel_head = cached_rdma_channel_tail;
+            __syncwarp();
+        }
+
+        // Retire this forwarder, so the coordinator can exclude it from the minimum-tail reduction
+        __syncwarp();
+        if (lane_id == 0)
+            forward_channel_retired[src_rdma_rank] = true;
     } else if (warp_role == WarpRole::kForwarderCoordinator) {
+        // Forward warp coordinator
+        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
+
         // Extra warps for forwarder coordinator should exit directly
         if (target_rank > 0)
             return;
+
+        // Clean shared memory
+        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
+            forward_channel_nvl_tail[i / NUM_MAX_NVL_PEERS][i % NUM_MAX_NVL_PEERS] = 0;
+        if (lane_id < NUM_MAX_NVL_PEERS)
+            forward_channel_nvl_tail_allocator[lane_id] = 0;
+        if (lane_id < kNumRDMARanks)
+            forward_channel_retired[lane_id] = false;
+        sync_forwarder_smem();
+
+        // Reduce the minimum published tail across forwarders and expose it to the NVL receivers
+        int last_tail = 0, dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
+        while (true) {
+            // Find the minimum tail over all forwarder warps that are still active
+            int min_tail = std::numeric_limits<int>::max(), max_tail = 0;
+            #pragma unroll
+            for (int i = 0; i < kNumRDMARanks; ++ i) {
+                max_tail = max(max_tail, forward_channel_nvl_tail[i][dst_nvl_rank]);
+                if (not forward_channel_retired[i])
+                    min_tail = min(min_tail, forward_channel_nvl_tail[i][dst_nvl_rank]);
+            }
+
+            // All forwarders have retired: flush the final (possibly partial) chunk and exit
+            if (__all_sync(0xffffffff, min_tail == std::numeric_limits<int>::max())) {
+                if (lane_id < NUM_MAX_NVL_PEERS and max_tail > last_tail)
+                    st_release_sys_global(nvl_channel_tail.buffer(), max_tail);
+                break;
+            }
+
+            // Publish the tail once a full chunk is ready
+            if (lane_id < NUM_MAX_NVL_PEERS and min_tail != std::numeric_limits<int>::max() and min_tail >= last_tail + num_max_nvl_chunked_send_tokens)
+                st_release_sys_global(nvl_channel_tail.buffer(), last_tail = min_tail);
+        }
     } else {
         // NVL consumers
+        // Retrieve the rank offsets from the barrier results (each lane holds the prefix for one source RDMA rank)
+        int src_nvl_rank = target_rank, total_offset = 0;
+        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
+        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
+            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
+
+        // Receive channel offsets
+        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
+        auto start_time = clock64();
+        while (lane_id < kNumRDMARanks) {
+            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
+            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
+            if (start_offset < 0 and end_offset < 0) {
+                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
+                total_offset += start_offset;
+                break;
+            }
+
+            // Timeout check
+            if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
+                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
+                       channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
+                trap();
+            }
+        }
+        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
+
+        // Save for combine usage
+        if (lane_id < kNumRDMARanks and not kCachedMode)
+            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
+        __syncwarp();
+
+        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
+        while (num_tokens_to_recv > 0) {
+            // Check the channel status (lane 0 only)
+            start_time = clock64();
+            while (lane_id == 0) {
+                // Ready to copy
+                if (cached_channel_head_idx != cached_channel_tail_idx)
+                    break;
+                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
+
+                // Timeout check
+                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
+                           channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
+                    trap();
+                }
+            }
+
+            // Sync the queue tail
+            cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
+
+            // Copy data
+            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
+            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
+                int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
+                auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
+                int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
+                total_offset += (lane_id == meta.src_rdma_rank) ? 1 : 0;
+
+                // Copy data
+                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
+                                   recv_x + recv_token_idx * hidden_int4,
+                                   nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
+                                   ld_nc_global, st_na_global);
+
+                // Copy source meta
+                if (lane_id == 0 and not kCachedMode)
+                    st_na_global(recv_src_meta + recv_token_idx, meta);
+
+                // Copy scales
+                UNROLLED_WARP_COPY(1, lane_id, num_scales,
+                                   recv_x_scales + recv_token_idx * num_scales,
+                                   nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
+                                   ld_nc_global, st_na_global);
+
+                // Copy `topk_idx` and `topk_weights`
+                if (lane_id < num_topk) {
+                    auto recv_idx = recv_token_idx * num_topk + lane_id;
+                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
+                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
+                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
+                }
+            }
+
+            // Move the queue
+            __syncwarp();
+            if (lane_id == 0)
+                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
+        }
     }
 }
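
The forwarder/receiver pair above is a chunked single-producer, single-consumer ring: the producer spins while `tail - head >= capacity` (back-pressure), writes slots modulo the capacity, and publishes the tail with release semantics; the consumer acquires the tail, drains `[head, tail)`, then posts the new head back. Below is a minimal, self-contained CUDA sketch of that handshake, not DeepEP code: all names are illustrative, plain `volatile` accesses plus `__threadfence()` stand in for the `st_release_sys_global`/`ld_acquire_sys_global` helpers, and the two blocks are assumed to be co-resident (as with DeepEP's persistent channel warps).

    #include <cstdio>

    constexpr int kQueueSize = 8;    // ring capacity, cf. num_max_nvl_chunked_recv_tokens
    constexpr int kChunkSize = 4;    // publish granularity, cf. num_max_nvl_chunked_send_tokens
    constexpr int kNumTokens = 64;

    struct Channel {
        int data[kQueueSize];
        int head;    // written by the consumer, read by the producer
        int tail;    // written by the producer, read by the consumer
    };

    __global__ void ring_demo(Channel* ch, long long* sum) {
        if (blockIdx.x == 0) {
            // Producer: single thread, like lane 0 of a forwarder warp
            int cached_head = 0;
            for (int i = 0; i < kNumTokens; ++ i) {
                // Back-pressure: wait until the slot is safe to overwrite
                while (i - cached_head >= kQueueSize)
                    cached_head = *reinterpret_cast<volatile int*>(&ch->head);
                ch->data[i % kQueueSize] = i * i;

                // Publish a full chunk (or the final partial chunk) with release semantics
                if ((i + 1) % kChunkSize == 0 or i + 1 == kNumTokens) {
                    __threadfence();
                    *reinterpret_cast<volatile int*>(&ch->tail) = i + 1;
                }
            }
        } else {
            // Consumer: drain [head, tail), then move the head forward
            int head = 0;
            long long acc = 0;
            while (head < kNumTokens) {
                int tail;
                do
                    tail = *reinterpret_cast<volatile int*>(&ch->tail);
                while (tail == head);
                __threadfence();    // pairs with the producer's release fence
                while (head < tail)
                    acc += ch->data[(head ++) % kQueueSize];
                *reinterpret_cast<volatile int*>(&ch->head) = head;
            }
            *sum = acc;
        }
    }

    int main() {
        Channel* ch;
        long long* sum;
        cudaMallocManaged(&ch, sizeof(Channel));
        cudaMallocManaged(&sum, sizeof(long long));
        *ch = {};
        ring_demo<<<2, 1>>>(ch, sum);
        cudaDeviceSynchronize();
        printf("sum of squares: %lld\n", *sum);    // expect 85344 for 0..63
        return 0;
    }

The final partial-chunk publish in the producer is the same obligation the forwarder coordinator discharges when all forwarders retire: without it, the consumer would wait forever for tokens that were written but never exposed.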
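The `topk_idx` transform in the forwarder assumes DeepEP's rank layout: the destination's global rank is `rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank`, and experts are sharded contiguously with `num_experts / num_ranks` experts per rank. A small host-side sketch of the remap, with illustrative sizes, keeping only indices that land on the destination rank and masking the rest to -1 (the kernel zeroes the matching weights):

    #include <cstdio>

    int main() {
        constexpr int NUM_MAX_NVL_PEERS = 8;
        const int num_experts = 256, num_ranks = 32;               // illustrative sizes
        const int num_experts_per_rank = num_experts / num_ranks;  // 8
        const int rdma_rank = 2, dst_nvl_rank = 5;                 // dst global rank = 21

        const int begin = (rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_experts_per_rank;  // 168
        const int end = begin + num_experts_per_rank;                                             // 176
        for (int idx : {167, 168, 175, 176}) {
            // In-range global indices become rank-local [0, num_experts_per_rank); others are masked
            int local = (idx >= begin and idx < end) ? idx - begin : -1;
            printf("global expert %d -> local %d\n", idx, local);
        }
        return 0;
    }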
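The receivers' `start_offset < 0 and end_offset < 0` test works because channel offsets are published in a `-value - 1` encoding: the buffer is zero-initialized, so a negative value is proof that the offset was actually written, while an offset of 0 still has a distinct encoded form (-1). A tiny sketch of the round-trip (function names are illustrative):

    #include <cassert>

    constexpr int encode(int v) { return -v - 1; }   // 0 -> -1, 5 -> -6
    constexpr int decode(int v) { return -v - 1; }   // the same involution decodes

    int main() {
        for (int v : {0, 1, 5, 1024}) {
            assert(encode(v) < 0);           // always negative: "written" is detectable
            assert(decode(encode(v)) == v);  // round-trips losslessly
        }
        return 0;
    }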