From 2d0cf41dd1900b105d74cb071f4cac35e3fb6f47 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou
Date: Fri, 14 Mar 2025 11:04:57 +0800
Subject: [PATCH] Low-latency kernels use RDMA atomics to support AR.

---
 csrc/kernels/ibgda_device.cuh | 120 ++++++++++++++++++----------------
 csrc/kernels/internode_ll.cu  |  26 ++------
 csrc/kernels/runtime.cu       |  26 +-------
 3 files changed, 70 insertions(+), 102 deletions(-)

diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 0ef360d..34f0cc4 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
 
 typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
 
+typedef struct {
+    uint32_t add_data;
+    uint32_t field_boundary;
+    uint64_t reserved;
+} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;
+
 __device__ static __forceinline__
 nvshmemi_ibgda_device_state_t* ibgda_get_state() {
     return &nvshmemi_ibgda_device_state_d;
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
     return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
 }
 
-// Wait until wqe `idx - 1` is completed.
-// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. It can only be used for polling recv.
-// Because we post recv and poll recv in the same thread, we don't need to maintain queue status.
-__device__ static __forceinline__ void
-nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
-    auto qp = ibgda_get_rc(dst_pe, qp_id);
-    auto cq = qp->rx_wq.cq;
-
-    const uint32_t ncqes = cq->ncqes;
-    auto *cqe64 = reinterpret_cast<mlx5_cqe64*>(cq->cqe);
-    auto old_cons_idx = *cq->cons_idx;
-    *cq->cons_idx = old_cons_idx + 1;
-
-    // Wait until `wqe_counter >= old_cons_idx`
-    while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
-}
-
 __device__ static __forceinline__ void
 nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
     // Get rkey
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
 }
 
-__device__ static __forceinline__ uint64_t
-nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
-    auto mvars = &qp->mvars;
-
-    // Allocate if not enough
-    constexpr int kMinIBGDARecvs = 32;
-    auto resv_head = mvars->rx_wq.resv_head;
-    auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
-    if (num_valid_slots < kMinIBGDARecvs) {
-        resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
-        mvars->rx_wq.resv_head = resv_head;
-
-        // Ensure WQE is written before `dbrec`
-        __be32 dbrec_val;
-        __be32 *dbrec_ptr = qp->rx_wq.dbrec;
-
-        // Compared to sending, for each QP, we only post recv in a single thread,
-        // so we don't need to do synchronization here
-        // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
-        asm("{\n\t"
-            ".reg .b32 dbrec_head_16b;\n\t"
-            ".reg .b32 ign;\n\t"
-            "and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
-            "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
-            "}" : "=r"(dbrec_val)
-            : "r"(static_cast<uint32_t>(resv_head)));
-        st_na_release(dbrec_ptr, dbrec_val);
-    }
-
-    // Return old number of slots
-    return num_valid_slots;
-}
-
-__device__ static __forceinline__ void
-nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
-    // NOTES: only one thread can run this function
-    EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
-}
-
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -419,4 +369,62 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
     __syncwarp();
 }
 
+__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
+        nvshmemi_ibgda_device_qp_t *qp, const int &value,
+        uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
+        uint16_t wqe_idx, void **out_wqes) {
+    ibgda_ctrl_seg_t ctrl_seg = {0};
+    struct mlx5_wqe_raddr_seg raddr_seg;
+    struct mlx5_wqe_atomic_seg atomic_seg_1;
+    struct mlx5_wqe_data_seg data_seg;
+
+    auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
+    auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
+    auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
+    auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));
+
+    raddr_seg.raddr = HtoBE64(raddr);
+    raddr_seg.rkey = rkey;
+    raddr_seg.reserved = 0;
+
+    // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
+    ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
+    auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
+    atomic_32_masked_fa_seg->add_data = HtoBE32(value);
+    atomic_32_masked_fa_seg->field_boundary = 0;
+
+    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
+    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
+
+    data_seg.byte_count = HtoBE32(sizeof(int));
+    data_seg.lkey = lkey;
+    data_seg.addr = HtoBE64(laddr);
+
+    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<const int4*>(&ctrl_seg));
+    st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<const int4*>(&raddr_seg));
+    st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<const int4*>(&atomic_seg_1));
+    st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
+}
+
+__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
+    nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
+
+    __be32 rkey;
+    uint64_t raddr;
+    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
+
+    void *wqe_ptrs[1];
+    uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+    wqe_ptrs[0] = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+
+    ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
+                            qp->ibuf.lkey, raddr, rkey, my_wqe_idx, wqe_ptrs);
+
+    ibgda_submit_requests(qp, my_wqe_idx, 1);
+}
+
 } // namespace deep_ep
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 426c7bc..76ae2e2 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
-                                 dst_rank, dst_expert_local_idx, 0);
-            nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
         } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
         }
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         int num_recv_tokens, recv_token_begin_idx;
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
         if (sub_warp_id == 1 and lane_id == 0) {
-            if (src_rank != rank) {
-                nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
-                num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
-                EP_DEVICE_ASSERT(num_recv_tokens != 0);
-            } else {
-                while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
-            }
+            while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
             num_recv_tokens = -num_recv_tokens - 1;
             recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
@@ -439,7 +431,7 @@ combine(void* combined_x,
         if (sub_warp_id == 1 and lane_id == 0) {
             while (ld_acquire_global(atomic_clean_flag) == 0);
             if (dst_rank != rank) {
-                nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
             } else {
                 st_na_release(rdma_recv_flag + global_expert_idx, 1);
             }
@@ -456,16 +448,8 @@ combine(void* combined_x,
     // Wait all ranks to arrive and notify PCIe usage
     if (responsible_expert_idx < num_experts) {
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
-        if (sub_warp_id == 0 and lane_id == 0) {
-            // TODO: refactor QP indices
-            auto src_rank = responsible_expert_idx / num_local_experts;
-            auto src_expert_idx = responsible_expert_idx % num_local_experts;
-            if (src_rank != rank) {
-                nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
-            } else {
-                while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
-            }
-        }
+        if (sub_warp_id == 0 and lane_id == 0)
+            while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
     }
     cg::this_grid().sync();
 
diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index 9519c03..c9f5879 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
     return result;
 }
 
-__global__ void ibgda_initialize_recv_queue(int rank) {
-    auto thread_idx = static_cast<int>(threadIdx.x);
-    auto num_threads = static_cast<int>(blockDim.x);
-
-    auto dst_rank = static_cast<int>(blockIdx.x);
-    if (dst_rank != rank) {
-        for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
-            auto qp = ibgda_get_rc(dst_rank, qp_id);
-
-            // Clean some necessary variables
-            for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
-                ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
-            qp->mvars.rx_wq.resv_head = 0;
-            qp->mvars.rx_wq.cons_idx = 0;
-
-            // Allocate receive slots
-            nvshmemi_ibgda_allocate_recvs(qp);
-        }
-    }
-}
-
 int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
     nvshmemx_uniqueid_t root_unique_id;
     nvshmemx_init_attr_t attr;
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
         CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
         bool ibgda_is_initialized = false;
-        cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
-
-        // Initialize recv queues for low-latency mode AR
-        ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
+        CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
     }
     nvshmem_barrier_all();
     return nvshmem_my_pe();
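Reviewer note (not part of the patch): the diff replaces the old completion path, `nvshmemi_ibgda_rma_p` on the sender plus `nvshmemi_ibgda_prepare_recvs`/`nvshmemi_ibgda_poll_recv` and the `ibgda_initialize_recv_queue` setup kernel on the receiver, with a single non-fetching RDMA atomic add (`nvshmemi_ibgda_amo_nonfetch_add`, an MLX5 masked fetch-add WQE). Because the flag arrives via an atomic rather than an ordered message, the receiver needs no recv-queue or completion-queue bookkeeping, and the remote and local cases collapse into the same `ld_acquire_global` spin, which is presumably what makes the kernels safe under the "AR" in the subject (adaptive routing). The sketch below is illustrative only: it shows just the signaling convention the patch relies on, encoding a count as `-num_tokens_sent - 1` so that even zero tokens produces a non-zero flag, then decoding it on the receiver. A plain device `atomicAdd` stands in for the IBGDA RDMA atomic, and the kernel and variable names (`ar_flag_demo`, `recv_count`, `out_num_tokens`) are hypothetical.

// ar_flag_demo.cu -- illustrative sketch, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ar_flag_demo(int* recv_count, int num_tokens_sent, int* out_num_tokens) {
    const int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;

    if (warp_id == 1 and lane_id == 0) {
        // "Sender": one atomic add carrying the encoded count. In the patch this is
        // nvshmemi_ibgda_amo_nonfetch_add() posting an MLX5_OPCODE_ATOMIC_MASKED_FA WQE.
        atomicAdd(recv_count, -num_tokens_sent - 1);
    }

    if (warp_id == 0 and lane_id == 0) {
        // "Receiver": spin on a plain load until the flag is non-zero, then decode.
        // No completion-queue polling or recv-queue bookkeeping is needed.
        int v;
        while ((v = *reinterpret_cast<volatile int*>(recv_count)) == 0);
        *out_num_tokens = -v - 1;
    }
}

int main() {
    int *recv_count = nullptr, *out_num_tokens = nullptr;
    cudaMalloc(&recv_count, sizeof(int));
    cudaMalloc(&out_num_tokens, sizeof(int));
    cudaMemset(recv_count, 0, sizeof(int));

    // Two warps in one block: warp 0 waits, warp 1 signals; zero tokens still works.
    ar_flag_demo<<<1, 64>>>(recv_count, /*num_tokens_sent=*/0, out_num_tokens);
    cudaDeviceSynchronize();

    int host_count = -1;
    cudaMemcpy(&host_count, out_num_tokens, sizeof(int), cudaMemcpyDeviceToHost);
    printf("decoded token count: %d\n", host_count); // prints 0

    cudaFree(recv_count);
    cudaFree(out_num_tokens);
    return 0;
}

Tying this back to the diff: in `dispatch()` the sender now issues `nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + ..., -num_tokens_sent - 1, ...)` while every receiver, remote or local, runs the same `while ((num_recv_tokens = ld_acquire_global(...)) == 0);` followed by `num_recv_tokens = -num_recv_tokens - 1;`, and `combine()` uses the same primitive with a constant `1` for its flag.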