From 1c277c303eefa56f6d3d47ab31b1b9c0f9dfd734 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Mon, 23 Jun 2025 11:45:05 +0800
Subject: [PATCH] Add draft

---
 csrc/kernels/internode.cu | 110 ++++++++++++++++++++++++--------------
 csrc/kernels/utils.cuh    |  22 ++++++++
 2 files changed, 92 insertions(+), 40 deletions(-)

diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index a49c430..4329818 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -419,9 +419,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
 
     // RDMA sender warp synchronization
-    __shared__ volatile int rdma_send_next_token_idx;
-    __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
-    __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks];
+    // NOTES: `rdma_send_channel_tail` means the latest released tail
+    // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status
+    __shared__ int rdma_send_channel_lock[kNumRDMARanks];
+    __shared__ int rdma_send_channel_tail[kNumRDMARanks];
+    __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks];
     auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); };
 
     // Forward warp synchronization
@@ -434,12 +436,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         int token_start_idx, token_end_idx;
         get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
 
-        // Clean shared memory
-        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
-        (warp_id == 0 and lane_id == 0) ? (rdma_send_next_token_idx = token_start_idx) : 0;
-        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
-        (warp_id == 0 and lane_id < kNumRDMARanks) ? (rdma_send_channel_next_tail[lane_id] = 0) : 0;
-
         // Send number of tokens in this channel by `-value - 1`
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers");
         for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) {
@@ -476,14 +472,31 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             if (lane_id < kNumRDMARanks)
                 is_token_in_rank_uint64 = *reinterpret_cast<const uint64_t*>(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS);
 
-            // Acquire sequential lock
-            while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
-            __syncwarp();
-
-            // Acquire next tail
+            // Acquire a tail
             int rdma_tail_idx = -1;
             if (is_token_in_rank_uint64 != 0) {
-                rdma_tail_idx = rdma_send_channel_next_tail[lane_id] ++;
+                do {
+                    // Acquire lock first
+                    acquire_lock(rdma_send_channel_lock + lane_id);
+
+                    // If there is no remaining slot, continue
+                    auto window = rdma_send_channel_window[lane_id];
+                    auto latest_tail = rdma_send_channel_tail[lane_id];
+                    if (window == 0xffffffffu) {
+                        release_lock(rdma_send_channel_lock + lane_id);
+                        continue;
+                    }
+
+                    // Use the first available slot
+                    auto offset = std::__countr_one(window);
+                    rdma_tail_idx = latest_tail + offset;
+                    rdma_send_channel_window[lane_id] = window | (1u << offset);
+
+                    // Release lock
+                    release_lock(rdma_send_channel_lock + lane_id);
+                } while (rdma_tail_idx == -1);
+
+                // Wait for the remote buffer to be released
                 while (rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens)
                     cached_rdma_channel_head = static_cast<int>(ld_volatile_global(rdma_channel_head.buffer(lane_id)));
             }
@@ -493,14 +506,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             if (lane_id < kNumRDMARanks and not kCachedMode)
                 send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx;
 
-            // Update last token tail
-            if (last_rdma_tail_idx >= 0)
-                st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-            last_rdma_tail_idx = rdma_tail_idx;
-
-            // Release sequential lock
-            lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
-
             // Broadcast tails
             SourceMeta src_meta;
             int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks];
@@ -557,24 +562,46 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
                 st_na_global(reinterpret_cast<float*>(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value);
             }
+            __syncwarp();
+
+            // Release the transaction in the window
+            if (is_token_in_rank_uint64 != 0) {
+                // Acquire lock first
+                acquire_lock(rdma_send_channel_lock + lane_id);
+
+                // Release the transaction slot
+                auto window = rdma_send_channel_window[lane_id];
+                auto latest_tail = rdma_send_channel_tail[lane_id];
+                auto offset = rdma_tail_idx - latest_tail;
+
+                // TODO: remove this assertion after debugging
+                EP_DEVICE_ASSERT((window >> offset) & 1);
+
+                // Erase bit and move the zeros if possible
+                window ^= 1u << offset;
+                if (offset == 0) {
+                    auto num_empty_slots = std::__countr_zero(window);
+                    num_empty_slots = num_empty_slots == 32 ? 1 : num_empty_slots;
+                    st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots);
+                    window >>= num_empty_slots;
+                }
+                rdma_send_channel_window[lane_id] = window;
+
+                // Release lock
+                release_lock(rdma_send_channel_lock + lane_id);
+            }
+            __syncwarp();
         }
-
-        // Epilogue
-        // Acquire sequential lock
-        while (lane_id == 0 and rdma_send_next_token_idx != token_idx);
-        __syncwarp();
-
-        // Update last token tail
-        if (last_rdma_tail_idx >= 0)
-            st_release_cta(const_cast<int*>(rdma_send_channel_tail + lane_id), last_rdma_tail_idx + 1);
-        __syncwarp();
-
-        // Release sequential lock
-        lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
     } else if (warp_role == WarpRole::kRDMASenderCoordinator) {
         // NOTES: in case of splitting, the issued put at the end of the buffer
         EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
 
+        // Clean shared memory
+        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks");
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0;
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0;
+        (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0;
+
         // Synchronize shared memory
         sync_rdma_sender_smem();

@@ -592,10 +619,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
             // Timeout check
             if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
-                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n",
+                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n",
                        channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
                 trap();
             }
+
+            // TODO: try thread-level `put_nbi`?
             for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
                 // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
                 int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
@@ -603,9 +632,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (synced_num_tokens_to_send == 0)
                     continue;
 
-                // Read progress
-                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
+                // Read the latest progress
+                // NOTES: `rdma_send_channel_tail` does not need to be protected by lock
                 auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast<const int*>(rdma_send_channel_tail + dst_rdma_rank)), 0);
+                auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank);
                 auto num_tokens_processed = processed_tail - synced_last_issued_tail;
                 if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens)
                     continue;
@@ -625,9 +655,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     // Lighter fence for local RDMA rank
                     memory_fence();
                 }
+                __syncwarp();
 
                 // Update tails
-                __syncwarp();
                 if (lane_id == dst_rdma_rank) {
                     last_issued_tail += num_tokens_to_issue;
                     num_tokens_to_send -= num_tokens_to_issue;
diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh
index 796a6f9..ac97896 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -466,4 +466,26 @@ barrier_block(int** barrier_signal_ptrs, int rank) {
     __syncthreads();
 }
 
+__forceinline__ __device__ int atomic_cas_cta_acquire(int* addr, int x, int y) {
+    int ret;
+    asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;" : "=r"(ret) : "l"(addr), "r"(x), "r"(y) : "memory");
+    return ret;
+}
+
+__forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) {
+    int ret;
+    asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;" : "=r"(ret) : "l"(addr), "r"(x) : "memory");
+    return ret;
+}
+
+__forceinline__ __device__ void acquire_lock(int* mutex) {
+    // Use `acquire` memory semantics so that later memory operations cannot be reordered before the lock is taken
+    while (atomic_cas_cta_acquire(mutex, 0, 1) != 0);
+}
+
+__forceinline__ __device__ void release_lock(int* mutex) {
+    // Use `release` memory semantics so that previous memory operations are visible before the lock is freed
+    atomic_exch_cta_release(mutex, 0);
+}
+
 } // namespace deep_ep
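
The `acquire_lock`/`release_lock` helpers above form a CTA-scoped spinlock over a word in shared memory. The kernel below is a minimal usage sketch, not part of the patch: it assumes `utils.cuh` from this patch is on the include path and an architecture with independent thread scheduling (the dispatch kernel targets Hopper), and it only exercises the lock around a trivial shared-memory update.

    // Hypothetical illustration only; not part of the patch.
    #include "utils.cuh"   // the helpers added above (include path is illustrative)

    __global__ void locked_increment_example(int* out, int num_iters) {
        __shared__ int lock, counter;
        if (threadIdx.x == 0)
            lock = 0, counter = 0;
        __syncthreads();

        for (int i = 0; i < num_iters; ++ i) {
            // `acquire` semantics: the protected update cannot move before the lock is taken
            deep_ep::acquire_lock(&lock);
            ++ counter;
            // `release` semantics: the update is visible to the next thread that takes the lock
            deep_ep::release_lock(&lock);
        }
        __syncthreads();

        // Every thread incremented `counter` exactly `num_iters` times
        if (threadIdx.x == 0)
            *out = counter;
    }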
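
For reference, the slot bookkeeping that the sender warps perform on `rdma_send_channel_tail` and `rdma_send_channel_window` can be exercised outside the kernel. The sketch below is a single-threaded, host-side mirror of that arithmetic (C++20, using the public `std::countr_one`/`std::countr_zero` in place of the libstdc++ internals used in device code); the struct and function names are illustrative and do not exist in the repository. In the kernel the same state lives in shared memory and every access is guarded by `rdma_send_channel_lock`.

    // Hypothetical host-side mirror of one RDMA rank's sender-side slot state.
    #include <bit>
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    struct ChannelWindowSketch {
        int tail = 0;           // like `rdma_send_channel_tail[lane_id]`: slots < tail are fully written
        uint32_t window = 0;    // like `rdma_send_channel_window[lane_id]`: bit i set => slot `tail + i` acquired, not yet released

        // Acquire the first free slot, or return -1 if all 32 slots are in flight.
        int acquire() {
            if (window == 0xffffffffu)
                return -1;
            int offset = std::countr_one(window);       // index of the lowest zero bit
            window |= 1u << offset;
            return tail + offset;
        }

        // Release a previously acquired slot; advance the tail over a completed prefix.
        void release(int slot) {
            int offset = slot - tail;
            assert((window >> offset) & 1u);            // the slot must still be marked as acquired
            window ^= 1u << offset;                     // erase the bit
            if (offset == 0) {
                int advance = std::countr_zero(window); // contiguous completed slots at the bottom
                advance = advance == 32 ? 1 : advance;  // empty window: only this slot is known complete
                tail += advance;
                window >>= advance;
            }
        }
    };

    int main() {
        ChannelWindowSketch w;
        int a = w.acquire(), b = w.acquire(), c = w.acquire();   // slots 0, 1, 2
        w.release(b);                                            // out-of-order completion: tail stays at 0
        w.release(a);                                            // prefix complete: tail advances past slots 0 and 1
        std::printf("tail = %d, window = 0x%x\n", w.tail, w.window);
        w.release(c);
        std::printf("tail = %d, window = 0x%x\n", w.tail, w.window);
    }

The `advance == 32 ? 1 : advance` clamp mirrors the draft: when the window becomes empty, only the slot just released is known to be complete, so the tail moves by exactly one.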