Fix several bugs

Chenggang Zhao
2025-06-20 10:57:56 +08:00
parent 177e491e92
commit 49b9084268

@@ -401,18 +401,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
     // NVL buffer layouts
     // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers"
-    void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr;
     int rs_wr_rank = 0, ws_rr_rank = 0;
-    if (warp_role == WarpRole::kRDMAAndNVLForwarder) {
-        rs_wr_buffer_ptr = buffer_ptrs[nvl_rank];
-        ws_rr_buffer_ptr = buffer_ptrs[lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0];
-        rs_wr_rank = nvl_rank, ws_rr_rank = target_rank;
-    }
-    if (warp_role == WarpRole::kNVLReceivers) {
-        rs_wr_buffer_ptr = buffer_ptrs[target_rank];
-        ws_rr_buffer_ptr = buffer_ptrs[nvl_rank];
-        rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
-    }
+    if (warp_role == WarpRole::kRDMAAndNVLForwarder)
+        rs_wr_rank = nvl_rank, ws_rr_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0;
+    if (warp_role == WarpRole::kNVLReceivers)
+        rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;
+    auto rs_wr_buffer_ptr = buffer_ptrs[rs_wr_rank];
+    auto ws_rr_buffer_ptr = buffer_ptrs[ws_rr_rank];
 
     // Allocate buffers
     auto nvl_channel_x = AsymBuffer<int4>(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr);
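This hunk makes the peer choice and the buffer pointer agree by construction: each role now selects only `rs_wr_rank`/`ws_rr_rank`, and both pointers are derived as `buffer_ptrs[rank]`. In the removed forwarder branch the two were assigned independently (`ws_rr_rank = target_rank`, yet `ws_rr_buffer_ptr` indexed by `lane_id`), so the rank passed to `AsymBuffer` and the pointer it advanced could refer to different peers. A minimal host-side sketch of the pattern; every name and value below is an illustrative stand-in, not the kernel's real state:

    #include <cstdio>

    int main() {
        // Stand-ins for the per-peer NVLink buffers
        int storage[8] = {};
        int* buffer_ptrs[8];
        for (int i = 0; i < 8; ++i)
            buffer_ptrs[i] = &storage[i];

        bool is_forwarder = true;  // stands in for the warp-role test
        int lane_id = 3, nvl_rank = 1, target_rank = 5;

        // Select the ranks once per role...
        int rs_wr_rank = 0, ws_rr_rank = 0;
        if (is_forwarder)
            rs_wr_rank = nvl_rank, ws_rr_rank = lane_id < 8 ? lane_id : 0;
        else
            rs_wr_rank = target_rank, ws_rr_rank = nvl_rank;

        // ...then derive every pointer from its rank, so the two can never drift apart
        int* rs_wr_buffer_ptr = buffer_ptrs[rs_wr_rank];
        int* ws_rr_buffer_ptr = buffer_ptrs[ws_rr_rank];
        std::printf("ranks: %d %d, ptrs: %p %p\n", rs_wr_rank, ws_rr_rank,
                    (void*)rs_wr_buffer_ptr, (void*)ws_rr_buffer_ptr);
        return 0;
    }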
@@ -665,17 +660,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1;
 
         // Notify NVL ranks
-        int *dst_start_ptr, *dst_end_ptr;
-        #pragma unroll
-        for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) {
-            auto start_ptr = nvl_channel_prefix_start.buffer_by_sync(i) + src_rdma_rank;
-            auto end_ptr = nvl_channel_prefix_end.buffer_by_sync(i) + src_rdma_rank;
-            dst_start_ptr = i == lane_id ? start_ptr : dst_start_ptr;
-            dst_end_ptr = i == lane_id ? end_ptr : dst_end_ptr;
-        }
         if (lane_id < NUM_MAX_NVL_PEERS) {
-            st_relaxed_sys_global(dst_start_ptr, -start_sum - 1);
-            st_relaxed_sys_global(dst_end_ptr, -end_sum - 1);
+            st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + src_rdma_rank, -start_sum - 1);
+            st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + src_rdma_rank, -end_sum - 1);
             EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum);
         }
         __syncwarp();
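With the refactor above, each lane's `nvl_channel_prefix_*.buffer()` already points at its own peer's buffer, so the removed loop that probed all peers with `buffer_by_sync(i)` and kept the pointer where `i == lane_id` is unnecessary. The `-x - 1` arithmetic around the stores is worth spelling out: judging from the decode at the top (`start_sum = -meta_0 - 1`) and the assert, prefix counts travel in negated-minus-one form so that a written value is always negative and cannot be confused with a zero-initialized, not-yet-written slot. A small self-contained sketch of that encoding (my reading of the code, not an authoritative spec):

    #include <cassert>

    // A count x >= 0 is stored as -x - 1: always negative, so a reader can
    // distinguish "value arrived" from the zeroed initial state
    inline int encode_prefix(int x) { return -x - 1; }
    inline int decode_prefix(int v) { return -v - 1; }

    int main() {
        for (int x = 0; x < 1000; ++x) {
            int v = encode_prefix(x);
            assert(v < 0);                  // never collides with the 0 init state
            assert(decode_prefix(v) == x);  // round-trips exactly
        }
        return 0;
    }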
@@ -701,14 +688,13 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     trap();
                 }
             }
-            __syncwarp();
 
             // Wait shared memory to be cleaned
             sync_forwarder_smem();
 
             // Forward tokens from RDMA buffer
             int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0;
-            int cached_nvl_channel_head = 0;
+            int cached_nvl_channel_head = 0; // This value is used only by lane 0
             while (cached_rdma_channel_tail < num_tokens_to_recv_from_rdma) {
                 // Wait data arrival
                 start_time = clock64();
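Two small changes: the `__syncwarp()` ahead of `sync_forwarder_smem()` is dropped, and `cached_nvl_channel_head` is documented as lane-0-private (it is only read and refreshed inside an `if (lane_id == 0)` block further down). The dropped barrier is presumably redundant because `sync_forwarder_smem()` already synchronizes the forwarder warps, which implies warp-level convergence for each participating warp. A sketch of how such a helper is commonly built from a PTX named barrier; the barrier id and thread count are assumptions for illustration, not DeepEP's actual values:

    // Hypothetical shape of a forwarder-warp barrier: any barrier that
    // includes every lane of the calling warp subsumes a bare __syncwarp()
    __device__ __forceinline__ void sync_forwarder_smem_sketch() {
        constexpr int kNumForwarderThreads = 256;  // assumed, multiple of 32
        asm volatile("bar.sync 1, %0;" :: "r"(kNumForwarderThreads));
    }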
@@ -716,9 +702,9 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                     cached_rdma_channel_tail = static_cast<int>(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank)));
 
                     // Timeout check
-                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
+                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                         printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, NVL: %d, src RDMA rank: %d, head: %d, tail: %d, expected: %d\n",
-                               channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
+                               channel_id, rdma_rank, nvl_rank, src_rdma_rank, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma);
                         trap();
                     }
                 }
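The timeout branch used to fire only for lanes with `lane_id < kNumRDMARanks` and reported `lane_id` as the source RDMA rank, which no longer matches how this loop polls: the tail is loaded for `src_rdma_rank`, not per lane. The fix lets any polling lane report and prints the actual `src_rdma_rank`. For reference, a self-contained sketch of the poll-with-timeout idiom (sm_70+ for the acquire load; names and the message are stand-ins):

    #include <cstdio>

    __device__ __forceinline__ int ld_acquire_sys(const int* ptr) {
        int v;
        asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(v) : "l"(ptr));
        return v;
    }

    __global__ void poll_with_timeout(const int* tail, int expected, long long timeout_cycles) {
        long long start_time = clock64();
        int cached_tail = 0;
        while (cached_tail < expected) {
            cached_tail = ld_acquire_sys(tail);
            // Abort with context instead of hanging the grid forever
            if (clock64() - start_time > timeout_cycles) {
                printf("timeout: tail %d, expected %d\n", cached_tail, expected);
                __trap();
            }
        }
    }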
@@ -740,7 +726,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 // Skip not selected tokens
                 if (not is_in_dst_nvl_rank) {
-                    kCachedMode ? (*send_nvl_head = -1) : 0;
+                    kCachedMode ? (*shifted_send_nvl_head = -1) : 0;
                     continue;
                 }
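The rename from `send_nvl_head` to `shifted_send_nvl_head` (here and in the next hunk) suggests the cached-mode head pointer is advanced per token elsewhere, and that the old code wrote through the unshifted base, recording only a single slot. That is an inference from the rename alone; a tiny host-side illustration of the failure mode under that assumption, with hypothetical names:

    #include <cstdio>

    int main() {
        int send_nvl_head[4] = {-2, -2, -2, -2};  // per-token head slots
        for (int token = 0; token < 4; ++token) {
            // The shifted pointer advances with the token; writing through the
            // unshifted base would clobber slot 0 on every iteration instead
            int* shifted_send_nvl_head = send_nvl_head + token;
            *shifted_send_nvl_head = (token % 2) ? -1 : token;  // -1 marks "not selected"
        }
        for (int i = 0; i < 4; ++i)
            std::printf("%d ", send_nvl_head[i]);  // prints: 0 -1 2 -1
        return 0;
    }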
@@ -748,7 +734,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 int dst_slot_idx;
                 if (lane_id == 0) {
                     dst_slot_idx = atomicAdd_block(const_cast<int*>(forward_channel_nvl_tail_allocator + dst_nvl_rank), 1);
-                    kCachedMode ? (*send_nvl_head = dst_slot_idx) : 0;
+                    kCachedMode ? (*shifted_send_nvl_head = dst_slot_idx) : 0;
                     while (dst_slot_idx - cached_nvl_channel_head >= num_max_nvl_chunked_recv_tokens)
                         cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer());
                 }
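Same rename on the allocation path. The surrounding lines also show the backpressure idiom this warp relies on: lane 0 takes a slot from a block-local allocator, then spins while that slot is a full ring ahead of the consumer, refreshing its cached head with a volatile load only when needed. A compact sketch of the idiom with stand-in names (`atomicAdd_block` is the real CUDA intrinsic, compute capability 6.0+):

    // Allocate a ring slot and wait until it fits within `ring_size` of the
    // consumer's published head; `cached_head` avoids re-reading global
    // memory on every check
    __device__ int alloc_slot(int* allocator, const volatile int* remote_head,
                              int ring_size, int& cached_head) {
        int slot = atomicAdd_block(allocator, 1);
        while (slot - cached_head >= ring_size)
            cached_head = *remote_head;  // volatile re-read of consumer progress
        return slot;
    }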
@@ -794,6 +780,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 __syncwarp();
                 if (lane_id == 0)
                     st_release_cta(forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank, dst_slot_idx + 1);
+                __syncwarp();
             }
         }
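The added trailing `__syncwarp()` keeps the warp converged after lane 0 publishes the new tail, so no lane can start the next token's work while the publication is still pending on lane 0's path. A sketch of the resulting publish pattern (simplified; the release store mirrors what a helper like `st_release_cta` plausibly wraps):

    __device__ void publish_tail(int* tail_ptr, int new_tail, int lane_id) {
        __syncwarp();  // every lane's stores for this token are complete
        if (lane_id == 0)
            asm volatile("st.release.cta.global.s32 [%0], %1;"
                         :: "l"(tail_ptr), "r"(new_tail) : "memory");
        __syncwarp();  // no lane begins the next token before the publish
    }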
@@ -832,13 +819,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             int min_tail = std::numeric_limits<int>::max();
             #pragma unroll
             for (int i = 0; i < kNumRDMARanks; ++ i) if (not forward_channel_retired[i])
-                min_tail = min(min_tail, forward_channel_nvl_tail[i][lane_id]);
+                min_tail = min(min_tail, forward_channel_nvl_tail[i][dst_nvl_rank]);
             if (__all_sync(0xffffffff, min_tail == std::numeric_limits<int>::max()))
                 break;
 
             // Update remote tail
-            if (min_tail != std::numeric_limits<int>::max() and min_tail >= last_tail + num_max_nvl_chunked_send_tokens)
+            // TODO: control update interval
+            if (lane_id < NUM_MAX_NVL_PEERS and min_tail != std::numeric_limits<int>::max() and min_tail > last_tail)
                 st_release_sys_global(nvl_channel_tail.buffer(), min_tail);
+            __syncwarp();
         }
     } else {
         // NVL consumers
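Three fixes land in the tail-update loop: the per-channel tails are indexed by `dst_nvl_rank` instead of `lane_id` (matching the `forward_channel_nvl_tail[src_rdma_rank] + dst_nvl_rank` writes above), the remote store is guarded by `lane_id < NUM_MAX_NVL_PEERS`, and the update now fires whenever `min_tail` advances at all rather than only after a full `num_max_nvl_chunked_send_tokens` chunk, with a TODO left to tune the interval. A condensed sketch of the aggregate-vote-publish structure, with stand-in types and a plain store in place of the release store:

    #include <climits>

    // One iteration of the coordinator loop: take the minimum unpublished
    // tail over non-retired RDMA channels, exit once every lane agrees
    // nothing is left, and push progress whenever the minimum advances
    __device__ bool forward_tails_step(const bool* retired, const int (*tails)[8],
                                       int num_rdma_ranks, int dst_nvl_rank,
                                       int* remote_tail, int& last_tail) {
        int min_tail = INT_MAX;
        for (int i = 0; i < num_rdma_ranks; ++i)
            if (!retired[i])
                min_tail = min(min_tail, tails[i][dst_nvl_rank]);
        if (__all_sync(0xffffffff, min_tail == INT_MAX))
            return false;  // all channels retired on all lanes
        if (min_tail != INT_MAX && min_tail > last_tail)
            last_tail = min_tail, *remote_tail = min_tail;
        __syncwarp();
        return true;
    }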