From 5ab80c28f3d6c3e4f88ce236f427ab7c81025172 Mon Sep 17 00:00:00 2001 From: moningchen Date: Mon, 21 Apr 2025 15:37:19 +0800 Subject: [PATCH 1/8] In the Internode Normal Kernel, when using nvshmem ibrc for RDMA data transmission, a single QP is used for data transfer between two GPUs, which limits kernel performance in network card dual-port and RoCE network scenarios. In our optimized Internode Normal Kernel, we implemented multiple QPs for data transmission between two GPUs, setting a different QP for each channel. Additionally, we modified the transmission method from IBRC to IBGAD. Through these optimizations, the Internode Normal Kernel achieves optimal performance in both H800 and H20 environments, with RDMA transmission performance nearly reaching the physical network performance limit. Using the current default statistical method, in 4-node H800 and H20 environments, RDMA performance can reach 60GB/s+. --- csrc/kernels/ibgda_device.cuh | 351 +++++++++++++++++++++++++++++++++- csrc/kernels/internode.cu | 74 +++++-- csrc/kernels/runtime.cu | 3 +- deep_ep/buffer.py | 3 +- tests/test_internode.py | 6 +- 5 files changed, 414 insertions(+), 23 deletions(-) diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index 1e8d871..ae71c5e 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -11,6 +11,10 @@ #include "exception.cuh" #include "utils.cuh" +// #define NVSHMEM_TIMEOUT_DEVICE_POLLING +// #define IBGDA_POLL_TIMEOUT 4000000000LLU +// #define NVSHMEM_IBGDA_DEBUG + namespace deep_ep { EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); @@ -242,15 +246,353 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) *out_rkey = device_key.key; } +#ifndef likely +#define likely(x) (__builtin_expect(!!(x), 1)) +#endif + +#ifndef unlikely +#define unlikely(x) (__builtin_expect(!!(x), 0)) +#endif + +#ifndef ACCESS_ONCE +#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x)) +#endif + +/** + * DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug. + * BSWAP is a pre-processor function. It will be unrolled to many READ_ONCE. + */ +#ifndef READ_ONCE +#define READ_ONCE(x) ACCESS_ONCE(x) +#endif + +#ifndef WRITE_ONCE +#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v)) +#endif + +#ifdef NVSHMEM_IBGDA_DEBUG +struct mlx5_err_cqe_ex { + uint8_t rsvd0[32]; + __be32 srqn; + uint8_t rsvd1[16]; + uint8_t hw_err_synd; + uint8_t hw_synd_type; + uint8_t vendor_err_synd; + uint8_t syndrome; + __be32 s_wqe_opcode_qpn; + __be16 wqe_counter; + uint8_t signature; + uint8_t op_own; +}; +typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t; +#else +typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t; +#endif + +__device__ static inline uint16_t BSWAP16(uint16_t x) { + uint16_t ret; + + uint32_t a = (uint32_t)x; + uint32_t d; + asm volatile( + "{\n\t" + ".reg .b32 mask;\n\t" + ".reg .b32 ign;\n\t" + "mov.b32 mask, 0x4401;\n\t" + "mov.b32 ign, 0x0;\n\t" + "prmt.b32 %0, %1, ign, mask;\n\t" + "}" + : "=r"(d) + : "r"(a)); + ret = (uint16_t)d; + return ret; +} + +/** + * DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug. + * See the comment near READ_ONCE. 
+ */ +__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) { +#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET + uint16_t ret; + asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr)); + return (uint8_t)ret; +#else + return READ_ONCE(*ptr); +#endif +} + +__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) { +#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET + uint16_t ret; + asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr)); + return ret; +#else + return READ_ONCE(*ptr); +#endif +} + +__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) { +#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET + uint32_t ret; + asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +#else + return READ_ONCE(*ptr); +#endif +} + +__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) { +#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET + uint64_t ret; + asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +#else + return READ_ONCE(*ptr); +#endif +} + +// Prevent code reordering from both compiler and GPU +__device__ static inline void IBGDA_MFENCE() { +#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE + asm volatile("fence.acq_rel.cta;" ::: "memory"); +#else + __threadfence_block(); +#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */ +} + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING +__device__ static inline uint64_t ibgda_query_globaltimer() { + uint64_t ret; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory"); + return ret; +} +#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */ + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING +__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now, + uint64_t start, uint64_t idx, int *error) { + int status = 0; + + struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe; + uint8_t opown; + uint8_t opcode; + uint16_t wqe_counter; + + if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) { + *error = -ETIME; + + opown = ibgda_atomic_read(&cqe64->op_own); + opcode = opown >> 4; + + wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter); + wqe_counter = BSWAP16(wqe_counter); + + printf( + "[%d] ibgda_poll_cq timeout:\n" + " cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n" + " wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n" + " while waiting for idx=%#lx.\n", + nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx), + ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter, + ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx); + status = -1; + } + return status; +} +#endif + +__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx, + int *error) { + int status = 0; + struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe; + + const uint32_t ncqes = cq->ncqes; + + uint8_t opown; + uint8_t opcode; + uint16_t wqe_counter; + uint16_t new_wqe_counter; + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING + uint64_t start = ibgda_query_globaltimer(); + uint64_t now; +#endif + + uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx); + uint64_t new_cons_idx; + + assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI || + cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC)); + + if (unlikely(cons_idx >= idx)) goto out; + +#ifdef NVSHMEM_IBGDA_DEBUG + // We can skip opcode == MLX5_CQE_INVALID check because we have already + // 
initialized the CQ buffer to 0xff. With the QP depth range we enforce, + // cons_idx cannot progress unless wqe_counter read from the CQ buffer is + // a valid value. + do { + opown = ibgda_atomic_read(&cqe64->op_own); + opcode = opown >> 4; + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING + // TODO: Integrate timeout handler with the core NVSHMEM + now = ibgda_query_globaltimer(); + status = ibgda_check_poll_timeout(cq, now, start, idx, error); + if (status != 0) goto check_opcode; +#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */ + } while (unlikely(opcode == MLX5_CQE_INVALID)); + + // Prevent reordering of the opcode wait above + IBGDA_MFENCE(); +#endif /* NVSHMEM_IBGDA_DEBUG */ + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING + start = ibgda_query_globaltimer(); +#endif + + // If idx is a lot greater than cons_idx, we might get incorrect result due + // to wqe_counter wraparound. We need to check prod_idx to be sure that idx + // has already been submitted. + while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx)) + ; + IBGDA_MFENCE(); + + do { + new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter); + new_wqe_counter = BSWAP16(new_wqe_counter); +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING + now = ibgda_query_globaltimer(); + status = ibgda_check_poll_timeout(cq, now, start, idx, error); + if (status != 0) goto check_opcode; + + // Observe progress. Reset the timer. + if (new_wqe_counter != wqe_counter) start = now; +#endif + wqe_counter = new_wqe_counter; + + // Another thread may have updated cons_idx. + cons_idx = ibgda_atomic_read(cq->cons_idx); + if (likely(cons_idx >= idx)) goto out; + } + // NOTE: This while loop is part of do while above. + // wqe_counter is the HW consumer index. However, we always maintain index + // + 1 in SW. To be able to compare with idx, we need to use wqe_counter + + // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for + // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than + // idx, and thus we need to wait. We don't need to wait when idx == + // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case + // wraparound. + while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes))); + + // new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the + // MSB from idx. We also need to take care of wraparound. + ++wqe_counter; + new_cons_idx = + (idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0); + atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx); + +#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING +check_opcode: +#endif + + // NVSHMEM always treats CQE errors as fatal. + // Even if this error doesn't belong to the CQE in cons_idx, + // we will just report and terminate the process. 
+ opown = ibgda_atomic_read(&cqe64->op_own); + opcode = opown >> 4; + + if (unlikely(opcode == MLX5_CQE_REQ_ERR)) { + ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64; + *error = cqe_err->syndrome; +#ifdef NVSHMEM_IBGDA_DEBUG + __be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter); + __be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn); + printf( + "[%d] got completion with err:\n" + " syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n" + " wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n" + " cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n", + nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd, + cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter), + BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx); +#endif /* NVSHMEM_IBGDA_DEBUG */ + status = -1; + } + +out: + // Prevent reordering of this function and subsequent instructions + IBGDA_MFENCE(); + + return status; +} + +__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) { + nvshmemi_ibgda_device_state_t *state = ibgda_get_state(); + uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx) + : ibgda_atomic_read(&qp->mvars.tx_wq.ready_head); + nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq; + + int err = 0; + int status = ibgda_poll_cq(&cq, prod_idx, &err); + // TODO: Integrate the error handler with the core NVSHMEM +#ifdef NVSHMEM_IBGDA_DEBUG + if (status) { + printf("ibgda_poll_cq failed with error=%d.\n", err); + } +#endif + assert(likely(status == 0)); + return prod_idx; +} + +__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) { + int status = 0; + int err = 0; + uint16_t nwqes = qp->tx_wq.nwqes; + nwqes = nwqes / 2; + + // We don't want wqe_idx - nwqes to wraparound. + if (likely(wqe_idx >= nwqes)) { + nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq; + status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err); + // TODO: Integrate the error handler with the core NVSHMEM + if (status) { + printf("ibgda_poll_cq failed with error=%d.\n", err); + } + assert(likely(status == 0)); + } + IBGDA_MFENCE(); +} + +template __device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) { auto mvars = &qp->mvars; - return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); + uint64_t wqe_idx; + wqe_idx = atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); + if (!nbi) { + uint64_t prod_idx = mvars->tx_wq.prod_idx; + uint64_t cons_idx = mvars->tx_wq.cons_idx; + uint64_t delta = prod_idx - cons_idx; + uint64_t cnt = qp->tx_wq.nwqes; + if (delta > cnt) { + printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta); + EP_DEVICE_ASSERT(delta <= cnt); + } + + // If last slot is available, all prior slots are also available. 
+ ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes); + } + + // return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); + return wqe_idx; } __device__ static __forceinline__ void* ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { uint16_t cnt = qp->tx_wq.nwqes; + EP_DEVICE_ASSERT(cnt != 0); uint16_t idx = wqe_idx & (cnt - 1); return reinterpret_cast(reinterpret_cast(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); } @@ -325,6 +667,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) { st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); } +template __device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) { // Get lkey and rkey, store them into lanes @@ -354,7 +697,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, auto qp = ibgda_get_rc(dst_pe, qp_id); uint64_t base_wqe_idx = 0; if (lane_id == 0) - base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); + base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); if (lane_id < num_wqes) { auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id); @@ -367,6 +710,10 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, if (lane_id == 0) ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx); __syncwarp(); + + // if (!nbi) { + // ibgda_quiet(qp); + // } } __device__ static __forceinline__ void ibgda_write_amo_add_wqe( diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index d6ad583..306dfe0 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -3,6 +3,7 @@ #include "exception.cuh" #include "launch.cuh" #include "utils.cuh" +#include "ibgda_device.cuh" namespace deep_ep { @@ -710,10 +711,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); - nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, - rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, - num_bytes_per_rdma_token * num_tokens_to_issue, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, + // rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, + // num_bytes_per_rdma_token * num_tokens_to_issue, + // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t); + const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); + const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); nvshmem_fence(); } else { // Lighter fence for local RDMA rank @@ -725,8 +731,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id == dst_rdma_rank) { last_issued_tail += 
num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, + // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + if (dst_rdma_rank != rdma_rank) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); + } else { + nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + } } } } @@ -926,8 +939,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Update remote head if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(lane_id, nvl_rank)); + // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, + // translate_dst_rdma_rank(lane_id, nvl_rank)); + if (lane_id != rdma_rank) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, + translate_dst_rdma_rank(lane_id, nvl_rank), channel_id); + } else { + nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, + translate_dst_rdma_rank(lane_id, nvl_rank)); + } last_head = min_head; } @@ -1558,10 +1578,15 @@ combine(int4* combined_x, float* combined_topk_weights, if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; - nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, - rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, - num_chunked_tokens * num_bytes_per_rdma_token, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, + // rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, + // num_chunked_tokens * num_bytes_per_rdma_token, + // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t); + const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); + const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); nvshmem_fence(); } else { memory_fence(); @@ -1569,9 +1594,17 @@ combine(int4* combined_x, float* combined_topk_weights, // Write new RDMA tail __syncwarp(); - if (lane_id == 0) - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + if (lane_id == 0) { + // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, + // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + if 
(dst_rdma_rank != rdma_rank) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); + } else { + nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + } + } } } @@ -1656,8 +1689,15 @@ combine(int4* combined_x, float* combined_topk_weights, for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, + // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + if (dst_rdma_rank != rdma_rank) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); + } else { + nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); + } last_rdma_head = min_head; } } else { diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index c9f5879..d008098 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -59,7 +59,8 @@ int init(const std::vector &root_unique_id_val, int rank, int num_ranks } // Normal operations use IBRC, while low-latency operations use IBGDA - if (low_latency_mode) { + bool internode_use_ibgda = true; + if (low_latency_mode or internode_use_ibgda) { nvshmemi_device_host_state_t* dev_state_ptr = nullptr; CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&dev_state_ptr), nvshmemi_device_state_d)); diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 831a2e6..0665eb7 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -65,9 +65,10 @@ class Buffer: # Synchronize NVSHMEM unique IDs root_unique_id = None + internode_use_ibgda = True if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode: # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA" - if low_latency_mode: + if low_latency_mode or internode_use_ibgda: assert num_qps_per_rank > 0 os.environ['NVSHMEM_DISABLE_P2P'] = '1' os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' diff --git a/tests/test_internode.py b/tests/test_internode.py index 5884a16..456ea64 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -218,17 +218,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in # noinspection PyUnboundLocalVariable def test_loop(local_rank: int, num_local_ranks: int): num_nodes = int(os.getenv('WORLD_SIZE', 1)) + num_sms = 24 + qp_num = num_sms // 2 rank, num_ranks, group = init_dist(local_rank, num_local_ranks) test_ll_compatibility = False if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num)) assert num_local_ranks == 8 and num_ranks > 
8 torch.manual_seed(rank) - for i in (24, ): + for i in (num_sms, ): test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) if local_rank == 0: print('', flush=True) From e2c578485cb36dc9e08958cf3213525a6ffd7d05 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Mon, 21 Apr 2025 17:44:32 +0800 Subject: [PATCH 2/8] Revert `ibgda_device.cuh` and remove some comments. --- csrc/kernels/ibgda_device.cuh | 351 +--------------------------------- csrc/kernels/internode.cu | 18 +- 2 files changed, 4 insertions(+), 365 deletions(-) diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index ae71c5e..1e8d871 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -11,10 +11,6 @@ #include "exception.cuh" #include "utils.cuh" -// #define NVSHMEM_TIMEOUT_DEVICE_POLLING -// #define IBGDA_POLL_TIMEOUT 4000000000LLU -// #define NVSHMEM_IBGDA_DEBUG - namespace deep_ep { EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); @@ -246,353 +242,15 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) *out_rkey = device_key.key; } -#ifndef likely -#define likely(x) (__builtin_expect(!!(x), 1)) -#endif - -#ifndef unlikely -#define unlikely(x) (__builtin_expect(!!(x), 0)) -#endif - -#ifndef ACCESS_ONCE -#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x)) -#endif - -/** - * DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug. - * BSWAP is a pre-processor function. It will be unrolled to many READ_ONCE. - */ -#ifndef READ_ONCE -#define READ_ONCE(x) ACCESS_ONCE(x) -#endif - -#ifndef WRITE_ONCE -#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v)) -#endif - -#ifdef NVSHMEM_IBGDA_DEBUG -struct mlx5_err_cqe_ex { - uint8_t rsvd0[32]; - __be32 srqn; - uint8_t rsvd1[16]; - uint8_t hw_err_synd; - uint8_t hw_synd_type; - uint8_t vendor_err_synd; - uint8_t syndrome; - __be32 s_wqe_opcode_qpn; - __be16 wqe_counter; - uint8_t signature; - uint8_t op_own; -}; -typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t; -#else -typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t; -#endif - -__device__ static inline uint16_t BSWAP16(uint16_t x) { - uint16_t ret; - - uint32_t a = (uint32_t)x; - uint32_t d; - asm volatile( - "{\n\t" - ".reg .b32 mask;\n\t" - ".reg .b32 ign;\n\t" - "mov.b32 mask, 0x4401;\n\t" - "mov.b32 ign, 0x0;\n\t" - "prmt.b32 %0, %1, ign, mask;\n\t" - "}" - : "=r"(d) - : "r"(a)); - ret = (uint16_t)d; - return ret; -} - -/** - * DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug. - * See the comment near READ_ONCE. 
- */ -__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) { -#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET - uint16_t ret; - asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr)); - return (uint8_t)ret; -#else - return READ_ONCE(*ptr); -#endif -} - -__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) { -#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET - uint16_t ret; - asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr)); - return ret; -#else - return READ_ONCE(*ptr); -#endif -} - -__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) { -#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET - uint32_t ret; - asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); - return ret; -#else - return READ_ONCE(*ptr); -#endif -} - -__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) { -#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET - uint64_t ret; - asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); - return ret; -#else - return READ_ONCE(*ptr); -#endif -} - -// Prevent code reordering from both compiler and GPU -__device__ static inline void IBGDA_MFENCE() { -#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE - asm volatile("fence.acq_rel.cta;" ::: "memory"); -#else - __threadfence_block(); -#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */ -} - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING -__device__ static inline uint64_t ibgda_query_globaltimer() { - uint64_t ret; - asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory"); - return ret; -} -#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */ - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING -__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now, - uint64_t start, uint64_t idx, int *error) { - int status = 0; - - struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe; - uint8_t opown; - uint8_t opcode; - uint16_t wqe_counter; - - if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) { - *error = -ETIME; - - opown = ibgda_atomic_read(&cqe64->op_own); - opcode = opown >> 4; - - wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter); - wqe_counter = BSWAP16(wqe_counter); - - printf( - "[%d] ibgda_poll_cq timeout:\n" - " cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n" - " wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n" - " while waiting for idx=%#lx.\n", - nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx), - ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter, - ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx); - status = -1; - } - return status; -} -#endif - -__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx, - int *error) { - int status = 0; - struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe; - - const uint32_t ncqes = cq->ncqes; - - uint8_t opown; - uint8_t opcode; - uint16_t wqe_counter; - uint16_t new_wqe_counter; - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING - uint64_t start = ibgda_query_globaltimer(); - uint64_t now; -#endif - - uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx); - uint64_t new_cons_idx; - - assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI || - cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC)); - - if (unlikely(cons_idx >= idx)) goto out; - -#ifdef NVSHMEM_IBGDA_DEBUG - // We can skip opcode == MLX5_CQE_INVALID check because we have already - // 
initialized the CQ buffer to 0xff. With the QP depth range we enforce, - // cons_idx cannot progress unless wqe_counter read from the CQ buffer is - // a valid value. - do { - opown = ibgda_atomic_read(&cqe64->op_own); - opcode = opown >> 4; - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING - // TODO: Integrate timeout handler with the core NVSHMEM - now = ibgda_query_globaltimer(); - status = ibgda_check_poll_timeout(cq, now, start, idx, error); - if (status != 0) goto check_opcode; -#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */ - } while (unlikely(opcode == MLX5_CQE_INVALID)); - - // Prevent reordering of the opcode wait above - IBGDA_MFENCE(); -#endif /* NVSHMEM_IBGDA_DEBUG */ - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING - start = ibgda_query_globaltimer(); -#endif - - // If idx is a lot greater than cons_idx, we might get incorrect result due - // to wqe_counter wraparound. We need to check prod_idx to be sure that idx - // has already been submitted. - while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx)) - ; - IBGDA_MFENCE(); - - do { - new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter); - new_wqe_counter = BSWAP16(new_wqe_counter); -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING - now = ibgda_query_globaltimer(); - status = ibgda_check_poll_timeout(cq, now, start, idx, error); - if (status != 0) goto check_opcode; - - // Observe progress. Reset the timer. - if (new_wqe_counter != wqe_counter) start = now; -#endif - wqe_counter = new_wqe_counter; - - // Another thread may have updated cons_idx. - cons_idx = ibgda_atomic_read(cq->cons_idx); - if (likely(cons_idx >= idx)) goto out; - } - // NOTE: This while loop is part of do while above. - // wqe_counter is the HW consumer index. However, we always maintain index - // + 1 in SW. To be able to compare with idx, we need to use wqe_counter + - // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for - // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than - // idx, and thus we need to wait. We don't need to wait when idx == - // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case - // wraparound. - while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes))); - - // new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the - // MSB from idx. We also need to take care of wraparound. - ++wqe_counter; - new_cons_idx = - (idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0); - atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx); - -#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING -check_opcode: -#endif - - // NVSHMEM always treats CQE errors as fatal. - // Even if this error doesn't belong to the CQE in cons_idx, - // we will just report and terminate the process. 
- opown = ibgda_atomic_read(&cqe64->op_own); - opcode = opown >> 4; - - if (unlikely(opcode == MLX5_CQE_REQ_ERR)) { - ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64; - *error = cqe_err->syndrome; -#ifdef NVSHMEM_IBGDA_DEBUG - __be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter); - __be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn); - printf( - "[%d] got completion with err:\n" - " syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n" - " wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n" - " cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n", - nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd, - cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter), - BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx); -#endif /* NVSHMEM_IBGDA_DEBUG */ - status = -1; - } - -out: - // Prevent reordering of this function and subsequent instructions - IBGDA_MFENCE(); - - return status; -} - -__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) { - nvshmemi_ibgda_device_state_t *state = ibgda_get_state(); - uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx) - : ibgda_atomic_read(&qp->mvars.tx_wq.ready_head); - nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq; - - int err = 0; - int status = ibgda_poll_cq(&cq, prod_idx, &err); - // TODO: Integrate the error handler with the core NVSHMEM -#ifdef NVSHMEM_IBGDA_DEBUG - if (status) { - printf("ibgda_poll_cq failed with error=%d.\n", err); - } -#endif - assert(likely(status == 0)); - return prod_idx; -} - -__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) { - int status = 0; - int err = 0; - uint16_t nwqes = qp->tx_wq.nwqes; - nwqes = nwqes / 2; - - // We don't want wqe_idx - nwqes to wraparound. - if (likely(wqe_idx >= nwqes)) { - nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq; - status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err); - // TODO: Integrate the error handler with the core NVSHMEM - if (status) { - printf("ibgda_poll_cq failed with error=%d.\n", err); - } - assert(likely(status == 0)); - } - IBGDA_MFENCE(); -} - -template __device__ static __forceinline__ uint64_t ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) { auto mvars = &qp->mvars; - uint64_t wqe_idx; - wqe_idx = atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); - if (!nbi) { - uint64_t prod_idx = mvars->tx_wq.prod_idx; - uint64_t cons_idx = mvars->tx_wq.cons_idx; - uint64_t delta = prod_idx - cons_idx; - uint64_t cnt = qp->tx_wq.nwqes; - if (delta > cnt) { - printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta); - EP_DEVICE_ASSERT(delta <= cnt); - } - - // If last slot is available, all prior slots are also available. 
- ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes); - } - - // return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); - return wqe_idx; + return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); } __device__ static __forceinline__ void* ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { uint16_t cnt = qp->tx_wq.nwqes; - EP_DEVICE_ASSERT(cnt != 0); uint16_t idx = wqe_idx & (cnt - 1); return reinterpret_cast(reinterpret_cast(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); } @@ -667,7 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) { st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); } -template __device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) { // Get lkey and rkey, store them into lanes @@ -697,7 +354,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, auto qp = ibgda_get_rc(dst_pe, qp_id); uint64_t base_wqe_idx = 0; if (lane_id == 0) - base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); + base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); if (lane_id < num_wqes) { auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id); @@ -710,10 +367,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, if (lane_id == 0) ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx); __syncwarp(); - - // if (!nbi) { - // ibgda_quiet(qp); - // } } __device__ static __forceinline__ void ibgda_write_amo_add_wqe( diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 306dfe0..5fcf7a3 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -711,14 +711,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); - // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, - // rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token, - // num_bytes_per_rdma_token * num_tokens_to_issue, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t); const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); nvshmem_fence(); } else { @@ -731,8 +727,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; - // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); if (dst_rdma_rank != rdma_rank) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, 
translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); @@ -939,8 +933,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Update remote head if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, - // translate_dst_rdma_rank(lane_id, nvl_rank)); if (lane_id != rdma_rank) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, translate_dst_rdma_rank(lane_id, nvl_rank), channel_id); @@ -1578,14 +1570,10 @@ combine(int4* combined_x, float* combined_topk_weights, if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; - // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, - // rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token, - // num_chunked_tokens * num_bytes_per_rdma_token, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t); const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); nvshmem_fence(); } else { @@ -1595,8 +1583,6 @@ combine(int4* combined_x, float* combined_topk_weights, // Write new RDMA tail __syncwarp(); if (lane_id == 0) { - // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); if (dst_rdma_rank != rdma_rank) { nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); From e0eaaf94fbedd700c83d3a635290ba49047c1404 Mon Sep 17 00:00:00 2001 From: moningchen Date: Mon, 21 Apr 2025 21:30:08 +0800 Subject: [PATCH 3/8] Add the performance data after internode optimization in the Readme file --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 435ccea..fbf3af2 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,16 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c | Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) | | Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) | +Through in-depth optimization, the following enhancements have been implemented in the Internode Normal Kernel: 1) Replacing IBRC with IBGDA, and 2) Utilizing distinct QPs (Queue Pairs) per channel for parallel data transmission. These improvements not only enhance the robustness of the Internode Normal Kernel in scenarios involving dual-port NICs and RoCE networks but also further elevate communication performance. 
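A minimal sketch of point 2 above (one QP per channel): the dispatch/combine kernels forward their `channel_id` as the IBGDA QP index, so a GPU pair talks over one RC QP per channel rather than a single shared QP. The wrapper function below is illustrative only and not part of this patch; `nvshmemi_ibgda_put_nbi_warp` and `nvshmemi_ibgda_amo_nonfetch_add` are the device helpers added in `csrc/kernels/ibgda_device.cuh`.

```cpp
// Illustrative sketch (not part of the patch): each channel passes its own
// channel_id as the QP index, so concurrent channels use distinct RC QPs.
__device__ __forceinline__ void send_and_signal(uint64_t remote_buf, uint64_t local_buf,
                                                size_t bytes, uint64_t* remote_tail,
                                                int num_tokens, int dst_pe,
                                                int channel_id, int lane_id) {
    // RDMA write of the token chunk on the QP owned by this channel
    nvshmemi_ibgda_put_nbi_warp(remote_buf, local_buf, bytes, dst_pe,
                                /*qp_id=*/channel_id, lane_id, /*message_idx=*/0);
    // Update the remote tail counter with an RDMA atomic add on the same QP
    if (lane_id == 0)
        nvshmemi_ibgda_amo_nonfetch_add(remote_tail, num_tokens, dst_pe, /*qp_id=*/channel_id);
}
```

Because each channel owns its own QP, transfers from different channels can proceed in parallel across both NIC ports instead of serializing on a single QP, which is what yields the bandwidth figures in the table below.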
+ +| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth | +|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:| +| Internode | 16 | 47 GB/s (RDMA) | 16 | 62 GB/s (RDMA) | +| Internode | 32 | 59 GB/s (RDMA) | 32 | 60 GB/s (RDMA) | +| Internode | 64 | 49 GB/s (RDMA) | 64 | 51 GB/s (RDMA) | + +The performance optimization solution for Internode Normal Kernel was jointly completed by our team and Tencent Network Platform Department. + ### Low-latency kernels with pure RDMA We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining). From 20b2aaaf9e47ddeb910e1a31affb19ffc2cb6022 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Tue, 22 Apr 2025 10:22:30 +0800 Subject: [PATCH 4/8] Refactor some code. --- csrc/kernels/ibgda_device.cuh | 60 ++++++++++++++++++++++----- csrc/kernels/internode.cu | 78 +++++++++++++++-------------------- csrc/kernels/internode_ll.cu | 4 +- tests/test_internode.py | 9 ++-- 4 files changed, 90 insertions(+), 61 deletions(-) diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index 1e8d871..eebdc59 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -410,20 +410,60 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe( st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); } -__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) { - nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); +__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) { + if (is_local_copy) { + // Fallback to NVSHMEM legacy API + nvshmemx_signal_op(reinterpret_cast(rptr), value, NVSHMEM_SIGNAL_ADD, pe); + } else { + nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); - __be32 rkey; - uint64_t raddr; - ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); + __be32 rkey; + uint64_t raddr; + ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); - uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); - void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); + uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); - ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), - qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); + ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), + qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); - ibgda_submit_requests(qp, my_wqe_idx, 1); + ibgda_submit_requests(qp, my_wqe_idx, 1); + } +} + +__device__ static __forceinline__ void +nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, + int dst_pe, int qp_id, bool is_local_copy) { + if (is_local_copy) { + // Fallback to NVSHMEM legacy API + // TODO: rewrite local API copy with unrolling and vectorization + nvshmem_uint8_put_nbi(reinterpret_cast(req_rptr), reinterpret_cast(req_lptr), bytes, dst_pe); + } else { + uint32_t num_wqes = 0; + uint64_t base_wqe_idx = 0; + + auto qp = ibgda_get_rc(dst_pe, qp_id); + while (bytes > 0) { + __be32 lkey, rkey; + uint64_t laddr, raddr, chunk_size; + + chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey)); 
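// NOTE (annotation, not in the original patch): ibgda_get_lkey_and_rkey() appears to
// resolve the local and remote addresses into an (lkey, rkey) pair and to return the
// number of contiguous bytes those keys cover; taking min() against the remaining
// byte count therefore yields the largest chunk one RDMA-write WQE can carry. Each
// loop iteration reserves a single WQE slot, and all reserved WQEs are posted
// together by the ibgda_submit_requests() call after the loop.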
+ bytes -= chunk_size; + + auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx); + + // Only the last WQE should send imm + ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx,&wqe_ptr); + + req_lptr += chunk_size; + req_rptr += chunk_size; + if ((num_wqes ++) == 0) + base_wqe_idx = wqe_idx; + } + + ibgda_submit_requests(qp, base_wqe_idx, num_wqes); + } } } // namespace deep_ep diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 5fcf7a3..0d29aec 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -480,6 +480,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels); + const auto role_meta = [=]() -> std::pair { if (is_forwarder) { if (warp_id < NUM_MAX_NVL_PEERS) { @@ -556,19 +558,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Send number of tokens in this channel by `-value - 1` EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); if (lane_id < NUM_MAX_NVL_PEERS) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; + dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + dst_ptr[lane_id] = -(channel_id == 0 ? 
0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { - rdma_channel_meta.send_buffer(dst_rdma_rank)[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + // Issue RDMA for non-local ranks + if (dst_rdma_rank != rdma_rank and lane_id == 0) { + nvshmemi_ibgda_put_nbi_thread(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)), + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), + channel_id, false); } - nvshmemx_int_put_nbi_warp(rdma_channel_meta.recv_buffer(rdma_rank), rdma_channel_meta.send_buffer(dst_rdma_rank), NUM_MAX_NVL_PEERS * 2 + 2, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); } - nvshmem_fence(); sync_rdma_sender_smem(); // Iterate over tokens and copy into buffer @@ -711,12 +721,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (dst_rdma_rank != rdma_rank) { auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); - const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t); + const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); - nvshmem_fence(); + if (lane_id == dst_rdma_rank) + nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false); } else { // Lighter fence for local RDMA rank memory_fence(); @@ -727,13 +737,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); } } } @@ -933,13 +938,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Update remote head if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - if (lane_id != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, - translate_dst_rdma_rank(lane_id, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(lane_id, nvl_rank)); - } 
+ nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, + translate_dst_rdma_rank(lane_id, nvl_rank), channel_id, lane_id == rdma_rank); last_head = min_head; } @@ -1570,12 +1570,12 @@ combine(int4* combined_x, float* combined_topk_weights, if (sub_warp_id == kNumWarpsPerForwarder - 1) { if (dst_rdma_rank != rdma_rank) { auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; - const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t); + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); - nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3); - nvshmem_fence(); + if (lane_id == 0) + nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false); } else { memory_fence(); } @@ -1583,13 +1583,8 @@ combine(int4* combined_x, float* combined_topk_weights, // Write new RDMA tail __syncwarp(); if (lane_id == 0) { - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); } } } @@ -1675,15 +1670,8 @@ combine(int4* combined_x, float* combined_topk_weights, for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, - // translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - if (dst_rdma_rank != rdma_rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id); - } else { - nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD, - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank)); - } + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); last_rdma_head = min_head; } } else { diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index c33e062..74fe0bd 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -167,7 +167,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, EP_DEVICE_ASSERT(num_sms > 1); if (sm_id == 0) { // The first SM is also responsible for checking QPs - EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_local_experts); + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts); // The first SM 
is also responsible for cleaning the next buffer #pragma unroll @@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Wait local sends issued and send expert counts while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); if (dst_rank != rank) { - nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); + nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false); } else { st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1); } diff --git a/tests/test_internode.py b/tests/test_internode.py index 456ea64..e9b3d57 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -218,15 +218,16 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in # noinspection PyUnboundLocalVariable def test_loop(local_rank: int, num_local_ranks: int): num_nodes = int(os.getenv('WORLD_SIZE', 1)) - num_sms = 24 - qp_num = num_sms // 2 rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - test_ll_compatibility = False + test_ll_compatibility = True if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 + num_sms = 24 + num_qps_per_rank = max(num_sms // 2, ll_num_experts // num_ranks if test_ll_compatibility else 0) + buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num)) + num_qps_per_rank=num_qps_per_rank) assert num_local_ranks == 8 and num_ranks > 8 torch.manual_seed(rank) From 3e54b78fd776108d04e959e6988e96f62d0314d8 Mon Sep 17 00:00:00 2001 From: Shangyan Zhou Date: Tue, 22 Apr 2025 10:36:24 +0800 Subject: [PATCH 5/8] Normal kernels always use IBGDA mode. 
---
 csrc/kernels/runtime.cu | 13 +++++--------
 deep_ep/buffer.py       | 22 ++++++++++------------
 2 files changed, 15 insertions(+), 20 deletions(-)

diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index d008098..8e536ca 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -58,15 +58,12 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
     }
 
-    // Normal operations use IBRC, while low-latency operations use IBGDA
-    bool internode_use_ibgda = true;
-    if (low_latency_mode or internode_use_ibgda) {
-        nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
-        CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
+    // TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
+    nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
+    CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
-        bool ibgda_is_initialized = false;
-        CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
-    }
+    bool ibgda_is_initialized = false;
+    CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
     nvshmem_barrier_all();
     return nvshmem_my_pe();
 }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 0665eb7..dd14838 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -65,19 +65,17 @@ class Buffer:
 
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
-        internode_use_ibgda = True
         if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
-            # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
-            if low_latency_mode or internode_use_ibgda:
-                assert num_qps_per_rank > 0
-                os.environ['NVSHMEM_DISABLE_P2P'] = '1'
-                os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
-                os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
-                os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
-                # Make sure QP depth is always larger than the number of in-flight WRs, so that we can skip WQ slot check
-                os.environ['NVSHMEM_QP_DEPTH'] = '1024'
-                # NOTES: NVSHMEM initialization requires at least 256 MiB
-                os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
+            # Enable IBGDA
+            assert num_qps_per_rank > 0
+            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
+            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
+            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
+            # Make sure QP depth is always larger than the number of in-flight WRs, so that we can skip WQ slot check
+            os.environ['NVSHMEM_QP_DEPTH'] = '1024'
+            # NOTES: NVSHMEM initialization requires at least 256 MiB
+            os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
 
         # Synchronize using the root ID
         nvshmem_unique_ids = [None, ] * self.group_size

From edbb1bc3ffc73d33a612ad254b6b75fc904ba222 Mon Sep 17 00:00:00 2001
From: Chenggang Zhao
Date: Tue, 22 Apr 2025 10:52:10 +0800
Subject: [PATCH 6/8] Several code lints

---
 README.md                     | 10 +---------
 csrc/kernels/ibgda_device.cuh |  2 +-
 csrc/kernels/internode.cu     | 19 +++++++++++--------
 deep_ep/buffer.py             |  2 +-
 4 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index fbf3af2..4a9367d 100644
--- a/README.md
+++ b/README.md
@@ -17,19 +17,11 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
 
 | Type | Dispatch #EP | Bottleneck 
bandwidth | Combine #EP | Bottleneck bandwidth | |:---------:|:------------:|:--------------------:|:-----------:|:--------------------:| | Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) | -| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) | -| Internode | 32 | 44 GB/s (RDMA) | 32 | 47 GB/s (RDMA) | -| Internode | 64 | 46 GB/s (RDMA) | 64 | 45 GB/s (RDMA) | - -Through in-depth optimization, the following enhancements have been implemented in the Internode Normal Kernel: 1) Replacing IBRC with IBGDA, and 2) Utilizing distinct QPs (Queue Pairs) per channel for parallel data transmission. These improvements not only enhance the robustness of the Internode Normal Kernel in scenarios involving dual-port NICs and RoCE networks but also further elevate communication performance. - -| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth | -|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:| | Internode | 16 | 47 GB/s (RDMA) | 16 | 62 GB/s (RDMA) | | Internode | 32 | 59 GB/s (RDMA) | 32 | 60 GB/s (RDMA) | | Internode | 64 | 49 GB/s (RDMA) | 64 | 51 GB/s (RDMA) | -The performance optimization solution for Internode Normal Kernel was jointly completed by our team and Tencent Network Platform Department. +**News (2025.04.22)**: the performance is optimized by 5-35% by Tencent Network Platform Department, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution! ### Low-latency kernels with pure RDMA diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh index eebdc59..2200a07 100644 --- a/csrc/kernels/ibgda_device.cuh +++ b/csrc/kernels/ibgda_device.cuh @@ -413,7 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe( __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) { if (is_local_copy) { // Fallback to NVSHMEM legacy API - nvshmemx_signal_op(reinterpret_cast(rptr), value, NVSHMEM_SIGNAL_ADD, pe); + nvshmemx_signal_op(static_cast(rptr), value, NVSHMEM_SIGNAL_ADD, pe); } else { nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 0d29aec..0b59d1f 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -573,10 +573,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv // Issue RDMA for non-local ranks if (dst_rdma_rank != rdma_rank and lane_id == 0) { nvshmemi_ibgda_put_nbi_thread(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)), - reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), - sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), - translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), - channel_id, false); + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), + channel_id, false); } } sync_rdma_sender_smem(); @@ -724,9 +724,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); - if (lane_id == dst_rdma_rank) + if (lane_id == dst_rdma_rank) { 
                nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                             translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
+                                             translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
+        }
     } else {
         // Lighter fence for local RDMA rank
         memory_fence();
     }
@@ -1573,9 +1574,11 @@ combine(int4* combined_x, float* combined_topk_weights,
             const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
             const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
             const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
-            if (lane_id == 0)
+            if (lane_id == 0) {
+                // TODO: use the full warp to do this
                 nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                              translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
+                                              translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
+            }
         } else {
             memory_fence();
         }
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index dd14838..feeb386 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -31,7 +31,7 @@ class Buffer:
 
     def __init__(self, group: dist.ProcessGroup,
                  num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
-                 low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
+                 low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
         """
         Initialize the communication buffer.

From 3b1045db43ed232691d32dde0517dfed571b0a65 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou
Date: Tue, 22 Apr 2025 11:23:42 +0800
Subject: [PATCH 7/8] Fix the performance data.

---
 README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 4a9367d..c9f9552 100644
--- a/README.md
+++ b/README.md
@@ -17,11 +17,11 @@ We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each c
 
 | Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth |
 |:---------:|:------------:|:--------------------:|:-----------:|:--------------------:|
 | Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) |
-| Internode | 16 | 47 GB/s (RDMA) | 16 | 62 GB/s (RDMA) |
-| Internode | 32 | 59 GB/s (RDMA) | 32 | 60 GB/s (RDMA) |
-| Internode | 64 | 49 GB/s (RDMA) | 64 | 51 GB/s (RDMA) |
+| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) |
+| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) |
+| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) |
 
-**News (2025.04.22)**: the performance is optimized by 5-35% by Tencent Network Platform Department, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
+**News (2025.04.22)**: with optimizations from the Tencent Network Platform Department, performance has been enhanced by up to 30%; see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution!
 
 ### Low-latency kernels with pure RDMA

From e255d57befba7452f4677c31677e03eb6ad96711 Mon Sep 17 00:00:00 2001
From: Shangyan Zhou
Date: Tue, 22 Apr 2025 12:29:46 +0800
Subject: [PATCH 8/8] Use `put_nbi_warp`.
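For context: `nvshmemi_ibgda_put_nbi_warp` is a warp-collective, so the per-lane gating and the trailing `is_local_copy` argument used with `put_nbi_thread` go away; all 32 lanes cooperate on splitting the message into WQEs and only lane 0 rings the doorbell. A hedged sketch of the calling convention after this patch (names come from the diffs; `dst_ptr`, `src_ptr`, and `num_bytes` stand in for the kernels' real buffer state):

```cuda
// Sketch: warp-collective RDMA put. Every lane of the warp must call this;
// WQE slots are reserved and filled cooperatively, lane 0 submits them, and
// the helper ends with __syncwarp() so all lanes leave together.
__device__ __forceinline__ void send_chunk(uint64_t dst_ptr, uint64_t src_ptr, size_t num_bytes,
                                           int dst_pe, int channel_id, int lane_id) {
    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes,
                                dst_pe, channel_id, lane_id, /*message_idx=*/ 0);
}
```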
---
 csrc/kernels/ibgda_device.cuh | 38 ++---------------------------------
 csrc/kernels/internode.cu     | 25 +++++++++----------------
 csrc/kernels/internode_ll.cu  |  2 +-
 3 files changed, 13 insertions(+), 52 deletions(-)

diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 2200a07..9f8c37c 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -325,6 +325,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
 }
 
+template <bool kAlwaysDoPostSend = false>
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -365,7 +366,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
 
     // Submit
     if (lane_id == 0)
-        ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx);
+        ibgda_submit_requests<kAlwaysDoPostSend>(qp, base_wqe_idx, num_wqes, message_idx);
     __syncwarp();
 }
 
@@ -431,39 +432,4 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
     }
 }
 
-__device__ static __forceinline__ void
-nvshmemi_ibgda_put_nbi_thread(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
-                              int dst_pe, int qp_id, bool is_local_copy) {
-    if (is_local_copy) {
-        // Fallback to NVSHMEM legacy API
-        // TODO: rewrite local API copy with unrolling and vectorization
-        nvshmem_uint8_put_nbi(reinterpret_cast<uint8_t*>(req_rptr), reinterpret_cast<const uint8_t*>(req_lptr), bytes, dst_pe);
-    } else {
-        uint32_t num_wqes = 0;
-        uint64_t base_wqe_idx = 0;
-
-        auto qp = ibgda_get_rc(dst_pe, qp_id);
-        while (bytes > 0) {
-            __be32 lkey, rkey;
-            uint64_t laddr, raddr, chunk_size;
-
-            chunk_size = min(bytes, ibgda_get_lkey_and_rkey(laddr = req_lptr, &lkey, req_rptr, dst_pe, &raddr, &rkey));
-            bytes -= chunk_size;
-
-            auto wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
-            auto wqe_ptr = ibgda_get_wqe_ptr(qp, wqe_idx);
-
-            // Only the last WQE should send imm
-            ibgda_write_rdma_write_wqe(qp, laddr, lkey, raddr, rkey, chunk_size, wqe_idx, &wqe_ptr);
-
-            req_lptr += chunk_size;
-            req_rptr += chunk_size;
-            if ((num_wqes ++) == 0)
-                base_wqe_idx = wqe_idx;
-        }
-
-        ibgda_submit_requests(qp, base_wqe_idx, num_wqes);
-    }
-}
-
 } // namespace deep_ep
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 0b59d1f..2e77460 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -571,12 +571,12 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             __syncwarp();
 
             // Issue RDMA for non-local ranks
-            if (dst_rdma_rank != rdma_rank and lane_id == 0) {
-                nvshmemi_ibgda_put_nbi_thread(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
-                                              reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
-                                              sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
-                                              translate_dst_rdma_rank(dst_rdma_rank, nvl_rank),
-                                              channel_id, false);
+            if (dst_rdma_rank != rdma_rank) {
+                nvshmemi_ibgda_put_nbi_warp(reinterpret_cast<uint64_t>(rdma_channel_meta.recv_buffer(rdma_rank)),
+                                            reinterpret_cast<uint64_t>(rdma_channel_meta.send_buffer(dst_rdma_rank)),
+                                            sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2),
+                                            translate_dst_rdma_rank(dst_rdma_rank, nvl_rank),
+                                            channel_id, lane_id, 0);
             }
         }
         sync_rdma_sender_smem();
@@ -724,10 +724,8 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue;
                 const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
                 const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
-                if (lane_id == dst_rdma_rank) {
-                    nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                                  translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
-                }
+                nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                            translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
             } else {
                 // Lighter fence for local RDMA rank
                 memory_fence();
             }
@@ -1574,11 +1572,8 @@ combine(int4* combined_x, float* combined_topk_weights,
             const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token;
             const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
             const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
-            if (lane_id == 0) {
-                // TODO: use the full warp to do this
-                nvshmemi_ibgda_put_nbi_thread(dst_ptr, src_ptr, num_bytes_per_msg,
-                                              translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, false);
-            }
+            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                        translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0);
             } else {
                 memory_fence();
             }
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 74fe0bd..8e0d9e4 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -215,7 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx, false);
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
        } else {
            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
        }
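For context: after this series, all head/tail and counter updates go through the single `nvshmemi_ibgda_amo_nonfetch_add` wrapper, whose trailing `is_local_copy` flag (default `false`, per the wrapper shown in the lint patch above) switches between the local `nvshmemx_signal_op` fallback and a remote IBGDA atomic add on the QP selected by `qp_id`. A minimal usage sketch (the helper name `bump_remote_counter` and its arguments are illustrative, not part of the patch):

```cuda
// Sketch: unified counter update. For a local destination the wrapper falls back
// to nvshmemx_signal_op(..., NVSHMEM_SIGNAL_ADD, ...); otherwise it posts a
// non-fetching RDMA atomic-add WQE on the RC QP chosen by `channel_id`.
__device__ __forceinline__ void bump_remote_counter(uint64_t* counter_rptr, int delta,
                                                    int dst_pe, int channel_id, bool dst_is_local) {
    nvshmemi_ibgda_amo_nonfetch_add(counter_rptr, delta, dst_pe, channel_id, dst_is_local);
}
```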