From a0a6e22eff25291ed57031428b0d1bf1c1cce10d Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Thu, 19 Jun 2025 13:48:07 +0800
Subject: [PATCH] Fully remove forwarders' and NVL receivers' code

---
 csrc/kernels/buffer.cuh   |  25 ++-
 csrc/kernels/internode.cu | 360 +++++++-------------------------------
 csrc/kernels/utils.cuh    |   2 +-
 3 files changed, 78 insertions(+), 309 deletions(-)

diff --git a/csrc/kernels/buffer.cuh b/csrc/kernels/buffer.cuh
index 19400a5..6a64519 100644
--- a/csrc/kernels/buffer.cuh
+++ b/csrc/kernels/buffer.cuh
@@ -2,6 +2,7 @@
 
 #include "configs.cuh"
 #include "exception.cuh"
+#include "utils.cuh"
 
 namespace deep_ep {
 
@@ -45,25 +46,26 @@ public:
     int total_bytes;
 
     __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks,
-                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
+                                          int channel_id = 0, int num_channels = 1, int offset = 0) {
         EP_STATIC_ASSERT(kNumRanks == 1, "");
         num_bytes = num_elems * sizeof(dtype_t);
 
         int per_channel_bytes = num_bytes * num_ranks;
-        total_bytes = per_channel_bytes * num_sms;
-        ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset;
+        total_bytes = per_channel_bytes * num_channels;
+        ptrs[0] = static_cast<uint8_t*>(gbl_ptr) + per_channel_bytes * channel_id + num_bytes * offset;
         gbl_ptr = static_cast<uint8_t*>(gbl_ptr) + total_bytes;
     }
 
     __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks,
-                                          int sm_id = 0, int num_sms = 1, int offset = 0) {
+                                          int channel_id = 0, int num_channels = 1, int offset = 0) {
+        // TODO: use UR as much as possible
         EP_STATIC_ASSERT(kNumRanks > 1, "");
         num_bytes = num_elems * sizeof(dtype_t);
 
         int per_channel_bytes = num_bytes * num_ranks;
-        total_bytes = per_channel_bytes * num_sms;
+        total_bytes = per_channel_bytes * num_channels;
         for (int i = 0; i < kNumRanks; ++ i) {
-            ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset;
+            ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + per_channel_bytes * channel_id + num_bytes * offset;
             gbl_ptrs[i] = static_cast<uint8_t*>(gbl_ptrs[i]) + total_bytes;
         }
     }
@@ -86,15 +88,22 @@ public:
         return *this;
     }
 
-    __device__ __forceinline__ dtype_t* buffer(int idx = 0) {
+    __device__ __forceinline__ dtype_t* buffer(const int& idx = 0) {
         EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for the single-rank case");
         return reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx);
     }
 
-    __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) {
+    __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, const int& idx = 0) {
         EP_STATIC_ASSERT(kNumRanks > 1, "`buffer_by` is only available for the multi-rank case");
         return reinterpret_cast<dtype_t*>(ptrs[rank_idx] + num_bytes * idx);
     }
+
+    __device__ __forceinline__ dtype_t* buffer_by_sync(int rank_idx, const int& idx = 0) {
+        // Different lanes store different pointers
+        // NOTES: this function requires the whole warp
+        EP_STATIC_ASSERT(kNumRanks == 1, "Invalid number of ranks");
+        return broadcast(reinterpret_cast<dtype_t*>(ptrs[0] + num_bytes * idx), rank_idx);
+    }
 };
 
 template <typename dtype_t, bool kDecoupled = true>
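The new `AsymBuffer::buffer_by_sync` is the key primitive of this patch: each lane of a warp caches a different peer's base pointer in its private `ptrs[0]`, and the whole warp recovers lane `rank_idx`'s pointer through the warp-wide `broadcast` helper from utils.cuh. A minimal standalone sketch of that mechanism (the names `broadcast_by_lane` and `demo` are illustrative, not code from this patch):

// Sketch only: warp-level pointer broadcast via `__shfl_sync`, assuming all
// 32 lanes are active (full 0xffffffff mask), as `buffer_by_sync` requires.
template <typename T>
__device__ __forceinline__ T broadcast_by_lane(T value, int src_lane) {
    static_assert(sizeof(T) % sizeof(int) == 0, "");
    constexpr int kNumWords = sizeof(T) / sizeof(int);
    int words[kNumWords];
    auto src = reinterpret_cast<const int*>(&value);
    #pragma unroll
    for (int i = 0; i < kNumWords; ++ i)
        words[i] = __shfl_sync(0xffffffff, src[i], src_lane);  // word-wise shuffle
    return *reinterpret_cast<T*>(words);
}

__global__ void demo(float* const* peer_ptrs, int num_peers) {
    const int lane_id = static_cast<int>(threadIdx.x) % 32;
    // Each lane caches one peer's buffer pointer, as the forwarder warps do below
    float* my_ptr = lane_id < num_peers ? peer_ptrs[lane_id] : nullptr;
    for (int peer = 0; peer < num_peers; ++ peer) {
        // Every lane obtains lane `peer`'s pointer and can then write cooperatively
        if (float* dst = broadcast_by_lane(my_ptr, peer))
            dst[lane_id] = static_cast<float>(lane_id);
    }
}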
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index a49c430..1ed78c3 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -365,14 +365,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     const bool is_forwarder = sm_id % 2 == 0;
     const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
 
-    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || ibgda_get_state()->num_rc_per_pe >= num_sms);
+    EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms);
 
     const auto role_meta = [=]() -> std::pair<WarpRole, int> {
         if (is_forwarder) {
-            if (warp_id < NUM_MAX_NVL_PEERS) {
-                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS};
+            // TODO: a warp may be responsible for multiple RDMA ranks
+            if (warp_id < kNumRDMARanks) {
+                return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % kNumRDMARanks};
             } else {
-                return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS};
+                return {WarpRole::kForwarderCoordinator, warp_id - kNumRDMARanks};
             }
         } else if (warp_id < kNumDispatchRDMASenderWarps) {
             return {WarpRole::kRDMASender, -1};
@@ -398,26 +399,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
     auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
 
-    // NVL buffer layouts
-    // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
-    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
-    int rs_wr_rank = 0, ws_rr_rank = 0;
-    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
-        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
-    if (warp_role == WarpRole::kNVLReceivers)
-        rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
-
-    // Allocate buffers
-    auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_x_scales = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_topk_idx = AsymBuffer<int>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_topk_weights = AsymBuffer<float>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_prefix_start = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_prefix_end = AsymBuffer<int>(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-    auto nvl_channel_head = AsymBuffer<int>(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr);
-    auto nvl_channel_tail = AsymBuffer<int>(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
-
     // RDMA sender warp synchronization
     __shared__ volatile int rdma_send_next_token_idx;
     __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks];
@@ -639,301 +620,79 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             }
         }
     } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
         // RDMA consumers and NVL producers
-        const auto dst_nvl_rank = target_rank;
-        const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank;
-        const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks);
-        const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks);
+        // Each warp is responsible for a source RDMA rank
+        // TODO: responsible for multiple RDMA ranks
+        const auto src_rdma_rank = target_rank;
+
+        // NVL buffers
+        auto buffer_ptr = lane_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[lane_id] : nullptr;
+        auto nvl_channel_x = AsymBuffer<int4>(buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_src_meta = AsymBuffer<SourceMeta>(buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_x_scales = AsymBuffer<float>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_topk_idx = AsymBuffer<int>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_topk_weights = AsymBuffer<float>(buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_prefix_start = AsymBuffer<int>(buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        auto nvl_channel_prefix_end = AsymBuffer<int>(buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+
+        // TODO: use NVL head/tail for coordinators
+        // auto nvl_channel_head = AsymBuffer<int>(buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
+        // auto nvl_channel_tail = AsymBuffer<int>(buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank);
 
         // Wait counters to arrive
-        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
-        EP_DEVICE_ASSERT(kNumRDMARanks <= 32);
         auto start_time = clock64();
-        if (lane_id < kNumRDMARanks) {
-            while (true) {
-                auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank);
-                auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank);
-                auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2);
-                auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1);
-                if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) {
-                    // Notify NVL ranks
-                    int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
-                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
-                    st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1);
-                    st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1);
+        int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0;
+        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVLink peers");
+        while (true) {
+            auto meta_0 = lane_id < NUM_MAX_NVL_PEERS ? ld_volatile_global(rdma_channel_meta.recv_buffer(src_rdma_rank) + lane_id) : -1;
+            auto meta_1 = lane_id < NUM_MAX_NVL_PEERS ? ld_volatile_global(rdma_channel_meta.recv_buffer(src_rdma_rank) + NUM_MAX_NVL_PEERS + lane_id) : -1;
+            auto meta_2 = lane_id == 0 ? ld_volatile_global(rdma_channel_meta.recv_buffer(src_rdma_rank) + NUM_MAX_NVL_PEERS * 2) : -1;
+            auto meta_3 = lane_id == 0 ? ld_volatile_global(rdma_channel_meta.recv_buffer(src_rdma_rank) + NUM_MAX_NVL_PEERS * 2 + 1) : -1;
+            if (__all_sync(0xffffffff, meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0)) {
+                int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
 
-                    // Save RDMA channel received token count
+                // Notify NVL ranks
+                int *dst_start_ptr = nullptr, *dst_end_ptr = nullptr;
+                #pragma unroll
+                for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) {
+                    auto start_ptr = nvl_channel_prefix_start.buffer_by_sync(i) + src_rdma_rank;
+                    auto end_ptr = nvl_channel_prefix_end.buffer_by_sync(i) + src_rdma_rank;
+                    dst_start_ptr = i == lane_id ? start_ptr : dst_start_ptr;
+                    dst_end_ptr = i == lane_id ? end_ptr : dst_end_ptr;
+                }
+                if (lane_id < NUM_MAX_NVL_PEERS) {
+                    st_relaxed_sys_global(dst_start_ptr, -start_sum - 1);
+                    st_relaxed_sys_global(dst_end_ptr, -end_sum - 1);
+                    EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
+                }
+                __syncwarp();
+
+                // Save RDMA channel received token count
+                if (lane_id == 0) {
                     src_rdma_channel_prefix = -meta_2 - 1;
                     auto src_rdma_channel_prefix_1 = -meta_3 - 1;
                     num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix;
                     if (not kCachedMode)
-                        recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1;
-                    src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1];
+                        recv_rdma_channel_prefix_matrix[src_rdma_rank * num_channels + channel_id] = src_rdma_channel_prefix_1;
+                    src_rdma_channel_prefix += src_rdma_rank == 0 ? 0 : recv_rdma_rank_prefix_sum[src_rdma_rank - 1];
                     EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0);
-                    break;
                 }
-
-                // Timeout check
-                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
-                    printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n",
-                           channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3);
-                    trap();
-                }
-            }
-        }
-        __syncwarp();
-
-        // Shift cached head
-        send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank;
-
-        // Wait shared memory to be cleaned
-        sync_forwarder_smem();
-
-        // Forward tokens from RDMA buffer
-        // NOTES: always start from the local rank
-        int src_rdma_rank = sm_id % kNumRDMARanks;
-        int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
-        int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0;
-        while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) {
-            // Check destination queue emptiness, or wait a buffer to be released
-            start_time = clock64();
-            while (lane_id == 0) {
-                int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head;
-                if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens)
-                    break;
-                cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
-
-                // Timeout check
-                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
-                    printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n",
-                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail);
-                    trap();
-                }
-            }
-            __syncwarp();
-
-            // Find next source RDMA rank (round-robin)
-            start_time = clock64();
-            while (true) {
-                src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks;
-                if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) {
-                    if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail)
-                        cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
-                    if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank))
-                        break;
-                }
-
-                // Timeout check
-                if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
-                    printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n",
-                           channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
-                    trap();
-                }
-            }
-            auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank);
-            auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank);
-
-            // Iterate over every token from the RDMA buffer
-            for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
-                auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
-                void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
-                auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
-                lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
-                bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
-                if (lane_id == src_rdma_rank) {
-                    auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1;
-                    rdma_nvl_token_idx += is_in_dst_nvl_rank;
-                    if (not kCachedMode)
-                        send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head;
-                }
-                if (not is_in_dst_nvl_rank)
-                    continue;
-
-                // Get an empty slot
-                int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
-
-                // Copy data
-                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
-                                   nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
-                                   reinterpret_cast<int4*>(shifted),
-                                   ld_nc_global, st_na_global);
-                shifted = static_cast<int4*>(shifted) + hidden_int4;
-
-                // Copy source meta
-                if (lane_id == 0)
-                    st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
-                shifted = static_cast<SourceMeta*>(shifted) + 1;
-
-                // Copy `x_scales`
-                UNROLLED_WARP_COPY(1, lane_id, num_scales,
-                                   nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
-                                   reinterpret_cast<float*>(shifted),
-                                   ld_nc_global, st_na_global);
-                shifted = static_cast<float*>(shifted) + num_scales;
-
-                // Copy `topk_idx` and `topk_weights`
-                // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
-                if (lane_id < num_topk) {
-                    // Read
-                    auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
-                    shifted = static_cast<int*>(shifted) + num_topk;
-                    auto weight_value = ld_nc_global(static_cast<float*>(shifted) + lane_id);
-
-                    // Transform and write
-                    idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
-                    st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
-                    weight_value = idx_value >= 0 ? weight_value : 0.0f;
-                    st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
-                }
-
-                // In case of insufficient NVL buffers, early stopping
-                if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens)
-                    src_rdma_tail = i + 1;
-            }
-
-            // Sync head index
-            if (lane_id == src_rdma_rank)
-                forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail);
-
-            // Move tail index
-            __syncwarp();
-            if (lane_id == 0)
-                st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail);
-        }
-
-        // Retired
-        __syncwarp();
-        if (lane_id == 0)
-            forward_channel_retired[dst_nvl_rank] = true;
-    } else if (warp_role == WarpRole::kForwarderCoordinator) {
-        // Extra warps for forwarder coordinator should exit directly
-        if (target_rank > 0)
-            return;
-
-        // Forward warp coordinator
-        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
-
-        // Clean shared memory
-        EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers");
-        #pragma unroll
-        for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32)
-            forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0;
-        if (lane_id < NUM_MAX_NVL_PEERS)
-            forward_channel_retired[lane_id] = false;
-        sync_forwarder_smem();
-
-        int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0;
-        while (true) {
-            // Find minimum head
-            int min_head = std::numeric_limits<int>::max();
-            #pragma unroll
-            for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i])
-                min_head = min(min_head, forward_channel_head[i][target_rdma]);
-            if (__all_sync(0xffffffff, min_head == std::numeric_limits<int>::max()))
-                break;
-
-            // Update remote head
-            if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
-                                                translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id + num_channels, lane_id == rdma_rank);
-                last_head = min_head;
-            }
-
-            // Nanosleep and let other warps work
-            __nanosleep(NUM_WAIT_NANOSECONDS);
-        }
-    } else {
-        // NVL consumers
-        // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank)
-        int src_nvl_rank = target_rank, total_offset = 0;
-        EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers");
-        if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0)
-            total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1];
-
-        // Receive channel offsets
-        int start_offset = 0, end_offset = 0, num_tokens_to_recv;
-        auto start_time = clock64();
-        while (lane_id < kNumRDMARanks) {
-            start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id);
-            end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id);
-            if (start_offset < 0 and end_offset < 0) {
-                start_offset = -start_offset - 1, end_offset = -end_offset - 1;
-                total_offset += start_offset;
                 break;
             }
 
             // Timeout check
             if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
-                printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n",
-                       channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset);
+                printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, NVL: %d, source RDMA: %d, meta: %d, %d, %d, %d\n",
+                       channel_id, rdma_rank, nvl_rank, src_rdma_rank, meta_0, meta_1, meta_2, meta_3);
                 trap();
             }
         }
-        num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset);
-
-        // Save for combine usage
-        if (lane_id < kNumRDMARanks and not kCachedMode)
-            recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset;
         __syncwarp();
-
-        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
-        while (num_tokens_to_recv > 0) {
-            // Check channel status by lane 0
-            start_time = clock64();
-            while (lane_id == 0) {
-                // Ready to copy
-                if (cached_channel_head_idx != cached_channel_tail_idx)
-                    break;
-                cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer());
-
-                // Timeout check
-                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
-                    printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
-                           channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
-                    trap();
-                }
-            }
-
-            // Sync queue tail
-            cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0);
-
-            // Copy data
-            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
-            for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) {
-                int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens;
-                auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer);
-                int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank);
-                (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0;
-
-                // Copy data
-                UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
-                                   recv_x + recv_token_idx * hidden_int4,
-                                   nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4,
-                                   ld_nc_global, st_na_global);
-
-                // Copy source meta
-                if (lane_id == 0 and not kCachedMode)
-                    st_na_global(recv_src_meta + recv_token_idx, meta);
-
-                // Copy scales
-                UNROLLED_WARP_COPY(1, lane_id, num_scales,
-                                   recv_x_scales + recv_token_idx * num_scales,
-                                   nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales,
-                                   ld_nc_global, st_na_global);
-
-                // Copy `topk_idx` and `topk_weights`
-                if (lane_id < num_topk) {
-                    auto recv_idx = recv_token_idx * num_topk + lane_id;
-                    auto buffer_idx = token_idx_in_buffer * num_topk + lane_id;
-                    st_na_global(recv_topk_idx + recv_idx, static_cast<int64_t>(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx)));
-                    st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx));
-                }
-            }
-
-            // Move queue
-            __syncwarp();
-            if (lane_id == 0)
-                st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
-        }
+    } else if (warp_role == WarpRole::kForwarderCoordinator) {
+        // Extra warps for forwarder coordinator should exit directly
+        if (target_rank > 0)
+            return;
+    } else {
+        // NVL consumers
     }
 }
 
@@ -976,6 +735,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float*
     EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
     EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));
+    EP_HOST_ASSERT(num_ranks / NUM_MAX_NVL_PEERS <= NUM_MAX_NVL_PEERS);
 
     SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream);
     SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE);
 #undef DISPATCH_LAUNCH_CASE
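The new `EP_HOST_ASSERT(num_ranks / NUM_MAX_NVL_PEERS <= NUM_MAX_NVL_PEERS)` guards the role split above: forwarder SMs now dedicate one warp per source RDMA rank, but the launch configuration still reserves only `NUM_MAX_NVL_PEERS` forwarder warps per block. A small host-side sketch of the arithmetic (the constant values below mirror DeepEP's usual defaults and are assumptions, not stated in this patch):

#include <cassert>

// Assumed values (not from this patch): 8 NVLink peers per node and
// 7 RDMA sender warps, matching DeepEP's common configuration.
constexpr int NUM_MAX_NVL_PEERS = 8;
constexpr int kNumDispatchRDMASenderWarps = 7;

int main() {
    const int num_ranks = 64;  // e.g. 8 nodes x 8 GPUs per node
    const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;

    // One forwarder warp per source RDMA rank must fit into the
    // NUM_MAX_NVL_PEERS warps a forwarder SM provides:
    assert(num_rdma_ranks <= NUM_MAX_NVL_PEERS);

    // Block size from SETUP_LAUNCH_CONFIG: sender warps, one sender
    // coordinator, and the forwarder/receiver warps, 32 threads each.
    const int num_threads = (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32;
    return num_threads == 512 ? 0 : 1;  // (7 + 1 + 8) * 32
}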
diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh
index 796a6f9..b0173b9 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -366,7 +366,7 @@ __device__ __forceinline__ void unpack2(const dtype_b_t& packed, dtype_a_t& x, d
 }
 
 template <typename dtype_t>
-__device__ __forceinline__ dtype_t broadcast(dtype_t& ptr, int src_lane_idx) {
+__device__ __forceinline__ dtype_t broadcast(dtype_t ptr, int src_lane_idx) {
     EP_STATIC_ASSERT(sizeof(dtype_t) % sizeof(int) == 0, "");
     auto send_int_values = reinterpret_cast<int*>(&ptr);
     int recv_int_values[sizeof(dtype_t) / sizeof(int)];
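Changing `broadcast` to take `ptr` by value is what lets `buffer_by_sync` feed it a temporary: the old `dtype_t&` parameter cannot bind a prvalue such as the result of a `reinterpret_cast`, and by-value also makes `&ptr` point at a per-lane private copy. A compile-time illustration (sketch only, with hypothetical function names):

// Sketch: the binding rule behind the signature change.
__device__ int* take_by_value(int* p) { return p; }
__device__ int* take_by_ref(int*& p) { return p; }

__device__ void caller(unsigned char* base) {
    take_by_value(reinterpret_cast<int*>(base + 4));   // OK: prvalue is copied
    // take_by_ref(reinterpret_cast<int*>(base + 4));  // error: cannot bind a
    //                                                 // prvalue to `int*&`
}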