mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Refactor some code.
This commit is contained in:
@@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
|
||||
}
|
||||
|
||||
// Non-fetching atomic add of `value` to the word at `rptr` on PE `pe`.
//
// Parameters:
//   rptr           address of the target word (passed to the key lookup / signal op as-is)
//   value          addend
//   pe             destination PE
//   qp_id          index used to select the QP via `ibgda_get_rc(pe, qp_id)`; unused on the local path
//   is_local_copy  when true, bypass IBGDA and use the NVSHMEM signal API instead
//
// The default `is_local_copy = false` keeps existing four-argument call sites working.
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
    if (is_local_copy) {
        // Fallback to NVSHMEM legacy API
        // NOTE(review): the signal op addresses `rptr` as a `uint64_t`, while `value` is an `int`
        // that gets converted to the 64-bit signal argument — confirm callers on this path pass
        // 8-byte, 8-byte-aligned counters.
        nvshmemx_signal_op(reinterpret_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
    } else {
        nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);

        // Resolve the remote address and remote key for `rptr` on `pe`
        __be32 rkey;
        uint64_t raddr;
        ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);

        // Reserve exactly one WQE slot and locate it in the work queue
        uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
        void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);

        // Fill the atomic-add WQE; the QP's `ibuf` buffer and its lkey are passed as the local side
        ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
                                qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);

        // Hand the single WQE over for posting
        ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
    }
}
|
||||
|
||||
// Non-blocking put of `bytes` bytes from local address `req_lptr` to address `req_rptr`
// on PE `dst_pe`, issued by the calling thread.
//
// Parameters:
//   req_rptr       destination address (as an integer)
//   req_lptr       source address (as an integer)
//   bytes          number of bytes to transfer
//   dst_pe         destination PE
//   qp_id          index used to select the QP via `ibgda_get_rc(dst_pe, qp_id)`; unused on the local path
//   is_local_copy  when true, bypass IBGDA and use `nvshmem_uint8_put_nbi`
//
// On the IBGDA path the transfer is split into chunks no larger than the size returned by
// `ibgda_get_lkey_and_rkey`, one RDMA-write WQE per chunk, and all WQEs are submitted together.
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
                              int dst_pe, int qp_id, bool is_local_copy) {
    if (is_local_copy) {
        // Fallback to NVSHMEM legacy API
        // TODO: rewrite local API copy with unrolling and vectorization
        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<uint8_t*>(req_lptr), bytes, dst_pe);
    } else {
        uint32_t num_wqes = 0;
        uint64_t base_wqe_idx = 0;

        auto qp = ibgda_get_rc(dst_pe, qp_id);
        while (bytes > 0) {
            __be32 lkey, rkey;
            uint64_t laddr, raddr, chunk_size;

            // The key lookup also yields an upper bound for this chunk; never exceed the remaining bytes
            laddr = req_lptr;
            chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
            bytes -= chunk_size;

            // NOTE(review): slots are reserved one at a time, yet the final submit treats
            // [base_wqe_idx, base_wqe_idx + num_wqes) as one contiguous batch — confirm no other
            // thread can reserve slots on this QP in between when a transfer needs several chunks.
            auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
            auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);

            // One plain RDMA-write WQE per chunk
            ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);

            req_lptr += chunk_size;
            req_rptr += chunk_size;

            // Remember the index of the first WQE of the batch
            if ((num_wqes ++) == 0)
                base_wqe_idx = wqe_idx;
        }

        // A zero-byte request writes no WQE; do not submit an empty batch at the
        // placeholder index 0
        if (num_wqes > 0)
            ibgda_submit_requests<true>(qp, base_wqe_idx, num_wqes);
    }
}
|
||||
|
||||
} // namespace deep_ep
|
||||
|
||||
Reference in New Issue
Block a user