diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 6a59c8c..13792c8 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -423,7 +423,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status __shared__ int rdma_send_channel_lock[kNumRDMARanks]; __shared__ int rdma_send_channel_tail[kNumRDMARanks]; - __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_wip_window[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_rdy_window[kNumRDMARanks]; auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; // Forward warp synchronization @@ -480,18 +481,19 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv acquire_lock(rdma_send_channel_lock + lane_id); // If there is no remaining slot, continue - auto window = rdma_send_channel_window[lane_id]; + auto wip_window = rdma_send_channel_wip_window[lane_id]; + auto rdy_window = rdma_send_channel_rdy_window[lane_id]; + auto window = wip_window | rdy_window; auto latest_tail = rdma_send_channel_tail[lane_id]; - if (window == 0xffffffffu) { - release_lock(rdma_send_channel_lock + lane_id); - continue; - } + + // The same effect with `EP_DEVICE_ASSERT(window != 0xffffffffu);` + EP_STATIC_ASSERT(kNumDispatchRDMASenderWarps < 32, "Invalid warps"); // Use the first available slot // NOTES: do not use `std::countr_one`, as it is buggy in CUDA C++ auto offset = __ffs(~window) - 1; rdma_tail_idx = latest_tail + offset; - rdma_send_channel_window[lane_id] = window | (1u << offset); + rdma_send_channel_wip_window[lane_id] = wip_window | (1u << offset); // Release lock release_lock(rdma_send_channel_lock + lane_id); @@ -581,21 +583,22 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv acquire_lock(rdma_send_channel_lock + lane_id); // Release the transaction slot - auto window = rdma_send_channel_window[lane_id]; + auto wip_window = rdma_send_channel_wip_window[lane_id]; + auto rdy_window = rdma_send_channel_rdy_window[lane_id]; auto latest_tail = rdma_send_channel_tail[lane_id]; auto offset = rdma_tail_idx - latest_tail; - // TODO: remove this assertion after debugging - EP_DEVICE_ASSERT((window >> offset) & 1 and (window & 1)); - // Erase bit and move the zeros if possible - window ^= 1u << offset; + wip_window ^= 1u << offset; + rdy_window ^= 1u << offset; if (offset == 0) { - auto num_empty_slots = window == 0 ? 1 : (__ffs(window) - 1); + EP_DEVICE_ASSERT(rdy_window & 1); + auto num_empty_slots = __ffs(~rdy_window) - 1; st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); - window >>= num_empty_slots; + wip_window >>= num_empty_slots, rdy_window >>= num_empty_slots; } - rdma_send_channel_window[lane_id] = window; + rdma_send_channel_wip_window[lane_id] = wip_window; + rdma_send_channel_rdy_window[lane_id] = rdy_window; // Release lock release_lock(rdma_send_channel_lock + lane_id); @@ -610,7 +613,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0; (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; - (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_wip_window[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_rdy_window[lane_id] = 0) : 0; // Synchronize shared memory sync_rdma_sender_smem();