From 49b9084268d858399eea476fb825242d1140dca8 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Fri, 20 Jun 2025 10:57:56 +0800
Subject: [PATCH] Fix several bugs

---
 csrc/kernels/internode.cu | 45 +++++++++++++++------------------------
 1 file changed, 17 insertions(+), 28 deletions(-)

diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 688434e..5327f8f 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -401,18 +401,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
     // NVL buffer layouts
     // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
-    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
     int rs_wr_rank = 0, ws_rr_rank = 0;
-    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
-        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank];
-        ws_rr_buffer_ptr = buffer_ptrs[lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0];
-        rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
-    }
-    if (warp_role == WarpRole::kNVLReceivers) {
-        rs_wr_buffer_ptr = buffer_ptrs[target_rank];
-        ws_rr_buffer_ptr = buffer_ptrs[nvl_rank];
+    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
+        rs_wr_rank = nvl_rank, ws_rr_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
+    if (warp_role == WarpRole::kNVLReceivers)
         rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
-    }
+    auto rs_wr_buffer_ptr = buffer_ptrs[rs_wr_rank];
+    auto ws_rr_buffer_ptr = buffer_ptrs[ws_rr_rank];
 
     // Allocate buffers
     auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
@@ -665,17 +660,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
 
         // Notify NVL ranks
-        int *dst_start_ptr, *dst_end_ptr;
-        #pragma unroll
-        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) {
-            auto start_ptr = nvl_channel_prefix_start.buffer_by_sync(i) + src_rdma_rank;
-            auto end_ptr = nvl_channel_prefix_end.buffer_by_sync(i) + src_rdma_rank;
-            dst_start_ptr = i == lane_id ? start_ptr : dst_start_ptr;
-            dst_end_ptr = i == lane_id ? end_ptr : dst_end_ptr;
-        }
         if (lane_id < NUM_MAX_NVL_PEERS) {
-            st_relaxed_sys_global(dst_start_ptr, -start_sum - 1);
-            st_relaxed_sys_global(dst_end_ptr, -end_sum - 1);
+            st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + src_rdma_rank, -start_sum - 1);
+            st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + src_rdma_rank, -end_sum - 1);
             EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
         }
         __syncwarp();
@@ -701,14 +688,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 trap();
             }
         }
-        __syncwarp();
 
         // Wait shared memory to be cleaned
        sync_forwarder_smem();
 
         // Forward tokens from RDMA buffer
         int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
-        int cached_nvl_channel_head = 0;
+        int cached_nvl_channel_head = 0; // This value is used only by lane 0
         while (cached_rdma_channel_tail < num_tokens_to_recv_from_rdma) {
             // Wait data arrival
             start_time = clock64();
@@ -716,9 +702,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
 
                 // Timeout check
-                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
+                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                     printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, NVL: %d, src RDMA rank: %d, head: %d, tail: %d, expected: %d\n",
-                           channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
+                           channel_id, rdma_rank, nvl_rank, src_rdma_rank, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
                     trap();
                 }
             }
@@ -740,7 +726,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
                 // Skip not selected tokens
                 if (not is_in_dst_nvl_rank) {
-                    kCachedMode ? (*send_nvl_head = -1) : 0;
+                    kCachedMode ? (*shifted_send_nvl_head = -1) : 0;
                     continue;
                 }
 
@@ -748,7 +734,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 int dst_slot_idx;
                 if (lane_id == 0) {
                     dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
-                    kCachedMode ? (*send_nvl_head = dst_slot_idx) : 0;
+                    kCachedMode ? (*shifted_send_nvl_head = dst_slot_idx) : 0;
                     while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
                         cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
                 }
@@ -794,6 +780,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 __syncwarp();
                 if (lane_id == 0)
                     st_release_cta(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank, dst_slot_idx + 1);
+                __syncwarp();
             }
         }
 
@@ -832,13 +819,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             int min_tail = std::numeric_limits<int>::max();
             #pragma unroll
             for (int i = 0; i < kNumRDMARanks; ++ i) if (not forward_channel_retired[i])
-                min_tail = min(min_tail, forward_channel_nvl_tail[i][lane_id]);
+                min_tail = min(min_tail, forward_channel_nvl_tail[i][dst_nvl_rank]);
             if (__all_sync(0xffffffff, min_tail == std::numeric_limits<int>::max()))
                 break;
 
             // Update remote tail
-            if (min_tail != std::numeric_limits<int>::max() and min_tail >= last_tail + num_max_nvl_chunked_send_tokens)
+            // TODO: control update interval
+            if (lane_id < NUM_MAX_NVL_PEERS and min_tail != std::numeric_limits<int>::max() and min_tail > last_tail)
                 st_release_sys_global(nvl_channel_tail.buffer(), min_tail);
+            __syncwarp();
         }
     } else {
         // NVL consumers
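
Note (not part of the patch): several hunks above touch the NVL channels, which behave as single-producer / single-consumer ring buffers coordinated by head/tail counters (`nvl_channel_head`, `nvl_channel_tail`, `cached_nvl_channel_head`): the producer polls the remote head for free space and publishes new slots by storing the tail with release semantics, while the consumer acquires the tail before reading and then advances the head. The sketch below is a minimal, self-contained CUDA illustration of that protocol only, not DeepEP code; `load_acquire_sys`, `store_release_sys`, `RingChannel`, `ring_push`, and `ring_pop` are hypothetical stand-ins for the kernel's `ld_acquire_sys_global` / `ld_volatile_global` / `st_release_sys_global` helpers.

// Minimal, illustrative sketch only (not DeepEP code): a single-producer /
// single-consumer ring buffer driven by head/tail counters placed in memory
// that both sides can see (e.g. peer-mapped NVLink buffers).
#include <cuda_runtime.h>

// Hypothetical stand-ins for the kernel's acquire/release helpers, built from
// volatile accesses plus system-wide memory fences.
__device__ __forceinline__ int load_acquire_sys(const volatile int* ptr) {
    int value = *ptr;          // volatile load always reaches memory
    __threadfence_system();    // order later accesses after this load
    return value;
}

__device__ __forceinline__ void store_release_sys(volatile int* ptr, int value) {
    __threadfence_system();    // make earlier writes visible system-wide first
    *ptr = value;              // then publish the counter
}

struct RingChannel {
    int* slots;          // payload buffer with `capacity` entries
    volatile int* head;  // written by the consumer: number of slots released
    volatile int* tail;  // written by the producer: number of slots published
    int capacity;
};

// Producer: wait for room, write the payload, then publish the new tail.
__device__ void ring_push(RingChannel ch, int value, int& cached_head, int& next_tail) {
    // Poll the consumer-owned head until the ring has a free slot
    // (the dispatch kernel polls `nvl_channel_head` the same way).
    while (next_tail - cached_head >= ch.capacity)
        cached_head = load_acquire_sys(ch.head);

    // Write the slot first, then release the tail, so a consumer that
    // acquires the tail is guaranteed to observe the payload.
    ch.slots[next_tail % ch.capacity] = value;
    store_release_sys(ch.tail, ++next_tail);
}

// Consumer: wait for a published slot, read it, then release it via the head.
__device__ int ring_pop(RingChannel ch, int& cached_tail, int& next_head) {
    while (cached_tail == next_head)
        cached_tail = load_acquire_sys(ch.tail);

    int value = ch.slots[next_head % ch.capacity];
    store_release_sys(ch.head, ++next_head);   // let the producer reuse the slot
    return value;
}

In the patch, the forwarder warps play the producer role toward the NVL receivers: `cached_nvl_channel_head` is the producer's cached copy of the remote head polled with `ld_volatile_global`, and the coordinator's guarded `st_release_sys_global(nvl_channel_tail.buffer(), min_tail)` is the release store that publishes newly forwarded tokens.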