mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Use correct buffer pointers
This commit is contained in:
@@ -745,26 +745,31 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
// Copy data
|
||||
// The `shifted` should be restored
|
||||
shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
|
||||
auto dst_nvl_channel_x = nvl_channel_x.buffer_by_sync(dst_nvl_rank);
|
||||
UNROLLED_WARP_COPY(5, lane_id, hidden_int4,
|
||||
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
|
||||
dst_nvl_channel_x + dst_slot_idx * hidden_int4,
|
||||
reinterpret_cast<int4*>(shifted),
|
||||
ld_nc_global, st_na_global);
|
||||
shifted = static_cast<int4*>(shifted) + hidden_int4;
|
||||
|
||||
// Copy source meta
|
||||
auto dst_nvl_channel_src_meta = nvl_channel_src_meta.buffer_by_sync(dst_nvl_rank);
|
||||
if (lane_id == 0)
|
||||
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
|
||||
st_na_global(dst_nvl_channel_src_meta + dst_slot_idx, src_meta);
|
||||
shifted = static_cast<SourceMeta*>(shifted) + 1;
|
||||
|
||||
// Copy `x_scales`
|
||||
auto dst_nvl_channel_x_scales = nvl_channel_x_scales.buffer_by_sync(dst_nvl_rank);
|
||||
UNROLLED_WARP_COPY(1, lane_id, num_scales,
|
||||
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
|
||||
dst_nvl_channel_x_scales + dst_slot_idx * num_scales,
|
||||
reinterpret_cast<float*>(shifted),
|
||||
ld_nc_global, st_na_global);
|
||||
shifted = static_cast<float*>(shifted) + num_scales;
|
||||
|
||||
// Copy `topk_idx` and `topk_weights`
|
||||
// NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
|
||||
auto dst_nvl_channel_topk_idx = nvl_channel_topk_idx.buffer_by_sync(dst_nvl_rank);
|
||||
auto dst_nvl_channel_topk_weights = nvl_channel_topk_weights.buffer_by_sync(dst_nvl_rank);
|
||||
if (lane_id < num_topk) {
|
||||
// Read
|
||||
auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
|
||||
@@ -775,9 +780,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
const auto dst_rank_expert_begin = (rdma_rank + NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_experts_per_rank;
|
||||
const auto dst_rank_expert_end = dst_rank_expert_begin + num_experts_per_rank;
|
||||
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
|
||||
st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value);
|
||||
st_na_global(dst_nvl_channel_topk_idx + dst_slot_idx * num_topk + lane_id, idx_value);
|
||||
weight_value = idx_value >= 0 ? weight_value : 0.0f;
|
||||
st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value);
|
||||
st_na_global(dst_nvl_channel_topk_weights + dst_slot_idx * num_topk + lane_id, weight_value);
|
||||
}
|
||||
|
||||
// Move tail index
|
||||
@@ -883,8 +888,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
|
||||
// Timeout check
|
||||
if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
|
||||
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n",
|
||||
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx);
|
||||
printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d, remain: %d\n",
|
||||
channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx, num_tokens_to_recv);
|
||||
trap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user