mirror of https://github.com/deepseek-ai/DeepEP (synced 2025-05-05 20:44:48 +00:00)

commit e255d57bef: "Use put_nbi_warp"
parent 3b1045db43
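In short: every remote put in the internode dispatch/combine kernels that previously went through the single-thread helper nvshmemi_ibgda_put_nbi_thread now uses the warp-cooperative nvshmemi_ibgda_put_nbi_warp, the thread-level helper is deleted, and the warp variant gains a kAlwaysDoPostSend template flag (default false) so call sites can force the doorbell post on submit. One nvshmemi_ibgda_amo_nonfetch_add call site in the low-latency dispatch also drops a trailing boolean argument.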
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
 }
 
+template <bool kAlwaysDoPostSend = false>
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
 
     // Submit
     if (lane_id == 0)
-        ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
+        ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
     __syncwarp();
 }
 
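Note that the new flag is forwarded verbatim into ibgda_submit_requests<kAlwaysDoPostSend>, so the posting policy is resolved at compile time rather than by a runtime branch, and the `= false` default keeps any caller that omits the flag on the old deferred-post behavior. A minimal, self-contained sketch of this compile-time-policy pattern (hypothetical names, not DeepEP code):

#include <cstdio>
#include <cuda_runtime.h>

template <bool kAlwaysDoPostSend>
__device__ __forceinline__ void submit_sketch(int wqe_idx) {
    if constexpr (kAlwaysDoPostSend)
        printf("post immediately: wqe %d\n", wqe_idx);  // always ring the doorbell
    else
        printf("defer posting: wqe %d\n", wqe_idx);     // leave room for batching
}

template <bool kAlwaysDoPostSend = false>
__global__ void put_sketch(int wqe_idx) {
    const int lane_id = static_cast<int>(threadIdx.x) % 32;
    if (lane_id == 0)  // one lane submits for the whole warp, as in the hunk above
        submit_sketch<kAlwaysDoPostSend>(wqe_idx);
    __syncwarp();
}

int main() {
    put_sketch<true><<<1, 32>>>(0);  // the migrated call sites below pass <true>
    put_sketch<<<1, 32>>>(1);        // default <false> mirrors the old behavior
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}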
@@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
     }
 }
 
-__device__ static __forceinline__ void
-nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
-                              int dst_pe, int qp_id, bool is_local_copy) {
-    if (is_local_copy) {
-        // Fallback to NVSHMEM legacy API
-        // TODO: rewrite local API copy with unrolling and vectorization
-        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<uint8_t*>(req_lptr), bytes, dst_pe);
-    } else {
-        uint32_t num_wqes = 0;
-        uint64_t base_wqe_idx = 0;
-
-        auto qp = ibgda_get_rc(dst_pe, qp_id);
-        while (bytes > 0) {
-            __be32 lkey, rkey;
-            uint64_t laddr, raddr, chunk_size;
-
-            chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
-            bytes -= chunk_size;
-
-            auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
-            auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
-
-            // Only the last WQE should send imm
-            ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);
-
-            req_lptr += chunk_size;
-            req_rptr += chunk_size;
-            if ((num_wqes ++) == 0)
-                base_wqe_idx = wqe_idx;
-        }
-
-        ibgda_submit_requests<true>(qp, base_wqe_idx, num_wqes);
-    }
-}
-
 } // namespace deep_ep
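For reference, the deleted helper did all of this from a single thread: each loop iteration translated as much of the remaining range as one lkey/rkey registration could cover (via ibgda_get_lkey_and_rkey), reserved one WQE slot, wrote one RDMA-write WQE, and posted to the QP only once after the loop. Below is a minimal sketch of a warp-cooperative equivalent, under the assumption (this diff shows only the warp helper's signature and submit tail) that chunks are spread round-robin over lanes:

#include <cstdint>
#include <cstddef>

// Hypothetical layout only -- NOT the actual nvshmemi_ibgda_put_nbi_warp body.
// Lanes claim chunks round-robin and lane 0 alone posts, matching the
// `if (lane_id == 0) ... __syncwarp()` tail visible in the hunk above.
__device__ void warp_put_sketch(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
                                int lane_id, size_t chunk_size) {
    const size_t num_chunks = (bytes + chunk_size - 1) / chunk_size;
    for (size_t i = lane_id; i < num_chunks; i += 32) {
        const size_t offset = i * chunk_size;
        const size_t len = (bytes - offset < chunk_size) ? bytes - offset : chunk_size;
        // Each lane would reserve a WQE slot here and write one RDMA-write WQE
        // covering [req_lptr + offset, req_lptr + offset + len).
        (void)req_rptr; (void)len;
    }
    if (lane_id == 0) {
        // Lane 0 posts every reserved WQE, i.e. the
        // ibgda_submit_requests<kAlwaysDoPostSend>(...) call shown above.
    }
    __syncwarp();  // all 32 lanes leave together
}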
@@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         __syncwarp();
 
         // Issue RDMA for non-local ranks
-        if (dst_rdma_rank != rdma_rank and lane_id == 0) {
-            nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
+        if (dst_rdma_rank != rdma_rank) {
+            nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
                                           reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
                                           sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
                                           translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
-                                          channel_id, false);
+                                          channel_id, lane_id, 0);
         }
     }
     sync_rdma_sender_smem();
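Two things change at this call site: the lane_id == 0 guard disappears because the warp helper must be entered by all 32 lanes (lane_id becomes an argument and the coordination moves inside), and the trailing arguments change meaning: the old false was the thread helper's is_local_copy, while the new lane_id, 0 passes the lane plus a message_idx of 0. The <true> template argument forces the post on submit.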
@@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
                 const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
-                if (lane_id == dst_rdma_rank) {
-                    nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
-                }
+                nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
             } else {
                 // Lighter fence for local RDMA rank
                 memory_fence();
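Same migration for the token-data put, with a twist: previously only the single lane whose lane_id equaled dst_rdma_rank issued the whole num_bytes_per_msg transfer; now the entire warp cooperates on that one message.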
@@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
             const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
             const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
             const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
-            if (lane_id == 0) {
-                // TODO: use the full warp to do this
-                nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                              translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
-            }
+            nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
+                                              translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
         } else {
             memory_fence();
         }
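This combine-path hunk also settles the old code's own "TODO: use the full warp to do this" (visible in the removed lines): the put guarded by if (lane_id == 0) becomes a warp-wide nvshmemi_ibgda_put_nbi_warp<true> call.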
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false);
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
        } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
         }
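This final hunk, in the low-latency dispatch, only drops the trailing false from the nvshmemi_ibgda_amo_nonfetch_add call, presumably tracking a signature change to the device-side AMO helper that is not visible in the hunks above.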