From a086ac55365b39faec65067705dff63e6acc6336 Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Fri, 20 Jun 2025 16:25:49 +0800 Subject: [PATCH] Use correct buffer pointers --- csrc/kernels/internode.cu | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 9887a58..fc32a07 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -745,26 +745,31 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Copy data // The `shifted` should be restored shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + auto dst_nvl_channel_x = nvl_channel_x.buffer_by_sync(dst_nvl_rank); UNROLLED_WARP_COPY(5, lane_id, hidden_int4, - nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, + dst_nvl_channel_x + dst_slot_idx * hidden_int4, reinterpret_cast(shifted), ld_nc_global, st_na_global); shifted = static_cast(shifted) + hidden_int4; // Copy source meta + auto dst_nvl_channel_src_meta = nvl_channel_src_meta.buffer_by_sync(dst_nvl_rank); if (lane_id == 0) - st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); + st_na_global(dst_nvl_channel_src_meta + dst_slot_idx, src_meta); shifted = static_cast(shifted) + 1; // Copy `x_scales` + auto dst_nvl_channel_x_scales = nvl_channel_x_scales.buffer_by_sync(dst_nvl_rank); UNROLLED_WARP_COPY(1, lane_id, num_scales, - nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, + dst_nvl_channel_x_scales + dst_slot_idx * num_scales, reinterpret_cast(shifted), ld_nc_global, st_na_global); shifted = static_cast(shifted) + num_scales; // Copy `topk_idx` and `topk_weights` // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted + auto dst_nvl_channel_topk_idx = nvl_channel_topk_idx.buffer_by_sync(dst_nvl_rank); + auto dst_nvl_channel_topk_weights = nvl_channel_topk_weights.buffer_by_sync(dst_nvl_rank); if (lane_id < 
num_topk) { // Read auto idx_value = ld_nc_global(static_cast(shifted) + lane_id); @@ -775,9 +780,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv const auto dst_rank_expert_begin = (rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_experts_per_rank; const auto dst_rank_expert_end = dst_rank_expert_begin + num_experts_per_rank; idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1; - st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value); + st_na_global(dst_nvl_channel_topk_idx + dst_slot_idx * num_topk + lane_id, idx_value); weight_value = idx_value >= 0 ? weight_value : 0.0f; - st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value); + st_na_global(dst_nvl_channel_topk_weights + dst_slot_idx * num_topk + lane_id, weight_value); } // Move tail index @@ -883,8 +888,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { - printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", - channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d, remain: %d\n", + channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx, num_tokens_to_recv); trap(); } }