diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 2200a07..9f8c37c 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
 }
 
+template <bool kAlwaysDoPostSend = false>
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
 
     // Submit
     if (lane_id == 0)
-        ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx);
+        ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
     __syncwarp();
 }
 
@@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
     }
 }
 
-__device__ static __forceinline__ void
-nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
-                              int dst_pe, int qp_id, bool is_local_copy) {
-    if (is_local_copy) {
-        // Fallback to NVSHMEM legacy API
-        // TODO: rewrite local API copy with unrolling and vectorization
-        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<const uint8_t*>(req_lptr), bytes, dst_pe);
-    } else {
-        uint32_t num_wqes = 0;
-        uint64_t base_wqe_idx = 0;
-
-        auto qp = ibgda_get_rc(dst_pe, qp_id);
-        while (bytes > 0) {
-            __be32 lkey, rkey;
-            uint64_t laddr, raddr, chunk_size;
-
-            chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
-            bytes -= chunk_size;
-
-            auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
-            auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
-
-            // Only the last WQE should send imm
-            ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);
-
-            req_lptr += chunk_size;
-            req_rptr += chunk_size;
-            if ((num_wqes ++) == 0)
-                base_wqe_idx = wqe_idx;
-        }
-
-        ibgda_submit_requests(qp, base_wqe_idx, num_wqes);
-    }
-}
-
 } // namespace deep_ep
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 0b59d1f..2e77460 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         __syncwarp();
 
         // Issue RDMA for non-local ranks
-        if (dst_rdma_rank != rdma_rank and lane_id == 0) {
-            nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
-                                          reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
-                                          sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
-                                          translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
-                                          channel_id, false);
+        if (dst_rdma_rank != rdma_rank) {
+            nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
+                                              reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
+                                              sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
+                                              translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
+                                              channel_id, lane_id, 0);
         }
     }
     sync_rdma_sender_smem();
@@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
                 const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
-                if (lane_id == dst_rdma_rank) {
-                    nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
-                }
+                nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
             } else {
                 // Lighter fence for local RDMA rank
                 memory_fence();
@@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
                 const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
                 const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
-                if (lane_id == 0) {
-                    // TODO: use the full warp to do this
-                    nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, false);
-                }
+                nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
             } else {
                 memory_fence();
             }
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 74fe0bd..8e0d9e4 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false);
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
         } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
         }
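For reviewers unfamiliar with the warp-level put pattern this change adopts, here is a minimal, self-contained CUDA sketch of the idea: each lane of a warp writes the work-queue entry (WQE) for its own chunk of the message, and a single lane then submits the whole batch once, with a compile-time flag deciding whether the doorbell is rung immediately. All names below (`put_nbi_warp_sketch`, `post_chunk_wqe`, `kChunkBytes`) are hypothetical stand-ins, not DeepEP or NVSHMEM APIs; the real `nvshmemi_ibgda_put_nbi_warp` derives chunk sizes from `ibgda_get_lkey_and_rkey` and writes MLX5 WQEs via `ibgda_write_rdma_write_wqe`.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical fixed chunk size; the real code chunks by memory-key coverage
constexpr size_t kChunkBytes = 4096;

// Stand-in for ibgda_write_rdma_write_wqe(): just trace what would be posted
__device__ void post_chunk_wqe(int lane, uint64_t laddr, uint64_t raddr, size_t len) {
    printf("lane %2d: WQE laddr=0x%llx raddr=0x%llx len=%llu\n", lane,
           (unsigned long long) laddr, (unsigned long long) raddr,
           (unsigned long long) len);
}

template <bool kAlwaysDoPostSend = false>
__device__ void put_nbi_warp_sketch(uint64_t req_rptr, uint64_t req_lptr,
                                    size_t bytes, int lane_id, int message_idx) {
    // One WQE per chunk; each lane owns at most one chunk here (a longer
    // message would loop over chunks with stride 32)
    auto num_wqes = uint32_t((bytes + kChunkBytes - 1) / kChunkBytes);
    if (uint32_t(lane_id) < num_wqes) {
        size_t offset = size_t(lane_id) * kChunkBytes;
        size_t len = bytes - offset < kChunkBytes ? bytes - offset : kChunkBytes;
        post_chunk_wqe(lane_id, req_lptr + offset, req_rptr + offset, len);
    }
    __syncwarp();

    // Exactly one lane submits the whole batch. The template flag mirrors
    // kAlwaysDoPostSend: ring the doorbell now, or let a later message on the
    // same queue batch-post it (here we only print the decision)
    if (lane_id == 0)
        printf("submit %u WQEs: %s\n", num_wqes,
               kAlwaysDoPostSend ? "post send now" : "defer unless batch full");
    __syncwarp();
    (void) message_idx;  // used by the real code to decide batching
}

__global__ void demo() {
    // A 10000-byte message -> 3 chunks written by lanes 0..2, one submit by lane 0
    put_nbi_warp_sketch<true>(0x10000, 0x20000, 10000, int(threadIdx.x) % 32, 0);
}

int main() {
    demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}
```

Under these assumptions the design trade-off is visible: the removed `nvshmemi_ibgda_put_nbi_thread` had one thread reserve and write WQEs serially in a loop, while the warp version writes the chunks' WQEs in parallel and pays for only one submission, which is why the call sites above can drop their `lane_id == 0` guards and pass `lane_id` through instead.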