From 901cdf79bebbc0e78a954d9e68c8b8329c31d2df Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 23 Jun 2025 17:54:57 +0800
Subject: [PATCH] Fix bugs

---
 csrc/kernels/internode.cu | 80 ++++++++++++++-------------------------
 1 file changed, 29 insertions(+), 51 deletions(-)

diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 13792c8..c3f14bb 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -365,7 +365,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     const bool is_forwarder = sm_id % 2 == 0;
     const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
 
-    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms);
+    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);
 
     const auto role_meta = [=]() -> std::pair<WarpRole, int> {
         if (is_forwarder) {
@@ -423,8 +423,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
     __shared__ int rdma_send_channel_lock[kNumRDMARanks];
     __shared__ int rdma_send_channel_tail[kNumRDMARanks];
-    __shared__ uint32_t rdma_send_channel_wip_window[kNumRDMARanks];
-    __shared__ uint32_t rdma_send_channel_rdy_window[kNumRDMARanks];
+    __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
     auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
 
     // Forward warp synchronization
@@ -465,52 +464,32 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
         // Iterate over tokens and copy into buffer
        int64_t token_idx;
-        int cached_rdma_channel_head = 0;
+        int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0;
         auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id);
-        for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumDispatchRDMASenderWarps) {
+        for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) {
             // Read RDMA rank existence
             uint64_t is_token_in_rank_uint64 = 0;
-            if (lane_id < kNumRDMARanks)
+            if (lane_id < kNumRDMARanks) {
                 is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
+                global_rdma_tail_idx += (is_token_in_rank_uint64 != 0);
+            }
+            __syncwarp();
 
-            // Acquire a tail
-            int rdma_tail_idx = -1;
-            if (is_token_in_rank_uint64 != 0) {
-                while (true) {
-                    // Acquire lock first
-                    acquire_lock(rdma_send_channel_lock + lane_id);
+            // Skip the token which does not belong to this warp
+            if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id)
+                continue;
+            auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1;
 
-                    // If there is no remaining slot, continue
-                    auto wip_window = rdma_send_channel_wip_window[lane_id];
-                    auto rdy_window = rdma_send_channel_rdy_window[lane_id];
-                    auto window = wip_window | rdy_window;
-                    auto latest_tail = rdma_send_channel_tail[lane_id];
+            // Wait the remote buffer to be released
+            auto start_time = clock64();
+            while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
+                cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
 
-                    // The same effect with `EP_DEVICE_ASSERT(window != 0xffffffffu);`
-                    EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
-
-                    // Use the first available slot
-                    // NOTES: do not use `std::countr_one`, as it is buggy in CUDA C++
-                    auto offset = __ffs(~window) - 1;
-                    rdma_tail_idx = latest_tail + offset;
-                    rdma_send_channel_wip_window[lane_id] = wip_window | (1u << offset);
-
-                    // Release lock
-                    release_lock(rdma_send_channel_lock + lane_id);
-                    break;
-                }
-
-                // Wait the remote buffer to be released
-                auto start_time = clock64();
-                while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) {
-                    cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
-
-                    // Timeout check
-                    if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
-                        printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
-                               channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
-                        trap();
-                    }
+                // Timeout check
+                if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n",
+                           channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx);
+                    trap();
                 }
             }
             __syncwarp();
@@ -583,22 +562,22 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 acquire_lock(rdma_send_channel_lock + lane_id);
 
                 // Release the transaction slot
-                auto wip_window = rdma_send_channel_wip_window[lane_id];
-                auto rdy_window = rdma_send_channel_rdy_window[lane_id];
+                auto rdy_window = rdma_send_channel_window[lane_id];
                 auto latest_tail = rdma_send_channel_tail[lane_id];
                 auto offset = rdma_tail_idx - latest_tail;
 
-                // Erase bit and move the zeros if possible
-                wip_window ^= 1u << offset;
+                // The same effect with `EP_DEVICE_ASSERT(offset < 32);`
+                EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps");
+
+                // Erase bit and move the ones if possible
                 rdy_window ^= 1u << offset;
                 if (offset == 0) {
                     EP_DEVICE_ASSERT(rdy_window & 1);
                     auto num_empty_slots = __ffs(~rdy_window) - 1;
                     st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
-                    wip_window >>= num_empty_slots, rdy_window >>= num_empty_slots;
+                    rdy_window >>= num_empty_slots;
                 }
-                rdma_send_channel_wip_window[lane_id] = wip_window;
-                rdma_send_channel_rdy_window[lane_id] = rdy_window;
+                rdma_send_channel_window[lane_id] = rdy_window;
 
                 // Release lock
                 release_lock(rdma_send_channel_lock + lane_id);
@@ -613,8 +592,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
         (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
         (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
-        (lane_id < kNumRDMARanks) ? (rdma_send_channel_wip_window[lane_id] = 0) : 0;
-        (lane_id < kNumRDMARanks) ? (rdma_send_channel_rdy_window[lane_id] = 0) : 0;
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
 
         // Synchronize shared memory
         sync_rdma_sender_smem();
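
For reference, the release path in the last two hunks collapses the old wip/rdy pair into a single 32-bit ready window per RDMA rank: a finishing warp marks its slot in the window, and the shared tail is advanced only over the contiguous prefix of finished slots. Below is a minimal host-side sketch of that bookkeeping, not code from the repository; `ReadyWindow`, `release_slot`, and `__builtin_ffs` (standing in for the per-lane lock and the device `__ffs`) are illustrative assumptions.

// Standalone host-side sketch of the single ready-window bookkeeping.
// The real kernel keeps one window per RDMA rank in shared memory and
// updates it under `rdma_send_channel_lock`.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct ReadyWindow {
    uint32_t rdy_window = 0;  // bit i set => slot (tail + i) is finished but not yet published
    int tail = 0;             // every slot below `tail` has been published to the receiver side
};

// Mark slot `rdma_tail_idx` as finished and, if it sits at the head of the
// window, advance `tail` over the whole contiguous run of finished slots.
void release_slot(ReadyWindow& w, int rdma_tail_idx) {
    int offset = rdma_tail_idx - w.tail;
    assert(offset >= 0 && offset < 32);            // mirrors `EP_DEVICE_ASSERT(offset < 32)`
    assert((w.rdy_window & (1u << offset)) == 0);  // a slot is released exactly once
    w.rdy_window ^= 1u << offset;
    if (offset == 0) {
        // Number of trailing one-bits = length of the finished prefix
        // (never 32, since fewer than 32 sender warps have slots in flight).
        int num_ready = __builtin_ffs(static_cast<int>(~w.rdy_window)) - 1;
        w.tail += num_ready;
        w.rdy_window >>= num_ready;
    }
}

int main() {
    ReadyWindow w;
    release_slot(w, 2);  // finishes out of order: tail stays at 0
    release_slot(w, 0);  // prefix {0} finished: tail advances to 1
    release_slot(w, 1);  // prefix {1, 2} finished: tail advances to 3
    std::printf("tail = %d, window = 0x%x\n", w.tail, w.rdy_window);  // tail = 3, window = 0x0
    return 0;
}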