Refactor some code.

This commit is contained in:
Shangyan Zhou
2025-04-22 10:22:30 +08:00
parent c07fdd197c
commit 20b2aaaf9e
4 changed files with 90 additions and 61 deletions

View File

@@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
}
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
if (is_local_copy) {
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op(reinterpret_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
} else {
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
__be32 rkey;
uint64_t raddr;
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
}
}
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
int dst_pe, int qp_id, bool is_local_copy) {
if (is_local_copy) {
// Fallback to NVSHMEM legacy API
// TODO: rewrite local API copy with unrolling and vectorization
nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<uint8_t*>(req_lptr), bytes, dst_pe);
} else {
uint32_t num_wqes = 0;
uint64_t base_wqe_idx = 0;
auto qp = ibgda_get_rc(dst_pe, qp_id);
while (bytes > 0) {
__be32 lkey, rkey;
uint64_t laddr, raddr, chunk_size;
chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
bytes -= chunk_size;
auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
// Only the last WQE should send imm
ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx,&wqe_ptr);
req_lptr += chunk_size;
req_rptr += chunk_size;
if ((num_wqes ++) == 0)
base_wqe_idx = wqe_idx;
}
ibgda_submit_requests<true>(qp, base_wqe_idx, num_wqes);
}
}
} // namespace deep_ep