Fix the shifted buffer pointer

This commit is contained in:
Chenggang Zhao
2025-06-20 11:31:57 +08:00
parent cd5c57fb2a
commit 8da790e3f3

View File

@@ -743,6 +743,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx % num_max_nvl_chunked_recv_tokens, 0);
// Copy data
// The `shifted` should be restored
shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),