Use put_nbi_warp.

This commit is contained in:
Shangyan Zhou 2025-04-22 12:29:46 +08:00
parent 3b1045db43
commit e255d57bef
3 changed files with 13 additions and 52 deletions

View File

@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
template <bool kAlwaysDoPostSend = false>
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
// Submit
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
}
@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
}
}
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
int dst_pe, int qp_id, bool is_local_copy) {
if (is_local_copy) {
// Fallback to NVSHMEM legacy API
// TODO: rewrite local API copy with unrolling and vectorization
nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<uint8_t*>(req_lptr), bytes, dst_pe);
} else {
uint32_t num_wqes = 0;
uint64_t base_wqe_idx = 0;
auto qp = ibgda_get_rc(dst_pe, qp_id);
while (bytes > 0) {
__be32 lkey, rkey;
uint64_t laddr, raddr, chunk_size;
chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
bytes -= chunk_size;
auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
// Only the last WQE should send imm
ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx,&wqe_ptr);
req_lptr += chunk_size;
req_rptr += chunk_size;
if ((num_wqes ++) == 0)
base_wqe_idx = wqe_idx;
}
ibgda_submit_requests<true>(qp, base_wqe_idx, num_wqes);
}
}
} // namespace deep_ep

View File

@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
__syncwarp();
// Issue RDMA for non-local ranks
if (dst_rdma_rank != rdma_rank and lane_id == 0) {
nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id, false);
channel_id, lane_id, 0);
}
}
sync_rdma_sender_smem();
@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
if (lane_id == dst_rdma_rank) {
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
}
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
// Lighter fence for local RDMA rank
memory_fence();
@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
if (lane_id == 0) {
// TODO: use the full warp to do this
nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
}
nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
} else {
memory_fence();
}

View File

@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false);
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
}