Fix AR bugs for normal kernels

This commit is contained in:
Chenggang Zhao
2025-03-05 17:13:35 +08:00
parent 680e424bdc
commit 458cdcb22a
5 changed files with 17 additions and 21 deletions

View File

@@ -372,7 +372,6 @@ nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
__device__ static __forceinline__ void
nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
// NOTES: only one thread can run this function
// TODO: consider this assertion for normal AR
EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
}

View File

@@ -925,9 +925,11 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
break;
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks)
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_head = min_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
if (min_head != std::numeric_limits<int>::max() and min_head > last_head and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
last_head = min_head;
}
// Nanosleep and let other warps work
__nanosleep(NUM_WAIT_NANOSECONDS);
@@ -1653,9 +1655,11 @@ combine(int4* combined_x, float* combined_topk_weights,
#pragma unroll
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks)
nvshmem_uint64_p(rdma_channel_head.buffer(rdma_rank), last_rdma_head = min_head,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if (min_head != std::numeric_limits<int>::max() and min_head > last_rdma_head and lane_id < kNumRDMARanks) {
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
last_rdma_head = min_head;
}
} else {
// Find minimum head for NVL ranks
#pragma unroll