mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fix AR bugs for normal kernels
This commit is contained in:
@@ -372,7 +372,6 @@ nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
|
||||
// Device-side helper: looks up a queue pair with ibgda_get_rc(dst_rank, qp_id),
// hands it to nvshmemi_ibgda_allocate_recvs, and device-asserts on the result.
// Parameters:
//   dst_rank - forwarded unchanged as the first argument of ibgda_get_rc
//   qp_id    - forwarded unchanged as the second argument of ibgda_get_rc
// Returns nothing; the only observable outcome here is the assertion.
// NOTE(review): this text comes from a rendered diff whose +/- markers were lost;
// the hunk header says one line of this function was removed by the commit, so
// confirm against the real file which of the lines below still exist.
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
|
||||
// NOTES: only one thread can run this function
|
||||
// TODO: consider this assertion for normal AR
|
||||
// The value returned by nvshmemi_ibgda_allocate_recvs must be strictly greater than 16.
// NOTE(review): presumably that value is a count of receive slots made available on
// the QP, and 16 is a minimum headroom - confirm against nvshmemi_ibgda_allocate_recvs.
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
|
||||
}
|
||||
|
||||
|
||||
@@ -925,9 +925,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
break;
|
||||
|
||||
// Update remote head
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)
|
||||
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
last_head = min_head;
|
||||
}
|
||||
|
||||
// Nanosleep and let other warps work
|
||||
__nanosleep(NUM_WAIT_NANOSECONDS);
|
||||
@@ -1653,9 +1655,11 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
|
||||
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks)
|
||||
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_rdma_head = min_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
last_rdma_head = min_head;
|
||||
}
|
||||
} else {
|
||||
// Find minimum head for NVL ranks
|
||||
#pragma unroll
|
||||
|
||||
Reference in New Issue
Block a user