diff --git a/csrc/kernels/ibgda_device.cuh b/csrc/kernels/ibgda_device.cuh
index 1e8d871..ae71c5e 100644
--- a/csrc/kernels/ibgda_device.cuh
+++ b/csrc/kernels/ibgda_device.cuh
@@ -11,6 +11,10 @@
 #include "exception.cuh"
 #include "utils.cuh"
 
+// #define NVSHMEM_TIMEOUT_DEVICE_POLLING
+// #define IBGDA_POLL_TIMEOUT 4000000000LLU
+// #define NVSHMEM_IBGDA_DEBUG
+
 namespace deep_ep {
 
 EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
@@ -242,15 +246,353 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey)
     *out_rkey = device_key.key;
 }
 
+#ifndef likely
+#define likely(x) (__builtin_expect(!!(x), 1))
+#endif
+
+#ifndef unlikely
+#define unlikely(x) (__builtin_expect(!!(x), 0))
+#endif
+
+#ifndef ACCESS_ONCE
+#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x))
+#endif
+
+/**
+ * DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug.
+ * BSWAP is a pre-processor function. It will be unrolled to many READ_ONCE.
+ */
+#ifndef READ_ONCE
+#define READ_ONCE(x) ACCESS_ONCE(x)
+#endif
+
+#ifndef WRITE_ONCE
+#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v))
+#endif
+
+#ifdef NVSHMEM_IBGDA_DEBUG
+struct mlx5_err_cqe_ex {
+    uint8_t rsvd0[32];
+    __be32 srqn;
+    uint8_t rsvd1[16];
+    uint8_t hw_err_synd;
+    uint8_t hw_synd_type;
+    uint8_t vendor_err_synd;
+    uint8_t syndrome;
+    __be32 s_wqe_opcode_qpn;
+    __be16 wqe_counter;
+    uint8_t signature;
+    uint8_t op_own;
+};
+typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t;
+#else
+typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t;
+#endif
+
+__device__ static inline uint16_t BSWAP16(uint16_t x) {
+    uint16_t ret;
+
+    uint32_t a = (uint32_t)x;
+    uint32_t d;
+    asm volatile(
+        "{\n\t"
+        ".reg .b32 mask;\n\t"
+        ".reg .b32 ign;\n\t"
+        "mov.b32 mask, 0x4401;\n\t"
+        "mov.b32 ign, 0x0;\n\t"
+        "prmt.b32 %0, %1, ign, mask;\n\t"
+        "}"
+        : "=r"(d)
+        : "r"(a));
+    ret = (uint16_t)d;
+    return ret;
+}
+
+/**
+ * DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug.
+ * See the comment near READ_ONCE.
+ */
+__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) {
+#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
+    uint16_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
+    return (uint8_t)ret;
+#else
+    return READ_ONCE(*ptr);
+#endif
+}
+
+__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) {
+#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
+    uint16_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
+    return ret;
+#else
+    return READ_ONCE(*ptr);
+#endif
+}
+
+__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) {
+#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
+    uint32_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
+    return ret;
+#else
+    return READ_ONCE(*ptr);
+#endif
+}
+
+__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) {
+#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
+    uint64_t ret;
+    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
+    return ret;
+#else
+    return READ_ONCE(*ptr);
+#endif
+}
+
+// Prevent code reordering from both compiler and GPU
+__device__ static inline void IBGDA_MFENCE() {
+#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE
+    asm volatile("fence.acq_rel.cta;" ::: "memory");
+#else
+    __threadfence_block();
+#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */
+}
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+__device__ static inline uint64_t ibgda_query_globaltimer() {
+    uint64_t ret;
+    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory");
+    return ret;
+}
+#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now,
+                                                      uint64_t start, uint64_t idx, int *error) {
+    int status = 0;
+
+    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
+    uint8_t opown;
+    uint8_t opcode;
+    uint16_t wqe_counter;
+
+    if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) {
+        *error = -ETIME;
+
+        opown = ibgda_atomic_read(&cqe64->op_own);
+        opcode = opown >> 4;
+
+        wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
+        wqe_counter = BSWAP16(wqe_counter);
+
+        printf(
+            "[%d] ibgda_poll_cq timeout:\n"
+            "    cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n"
+            "    wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n"
+            "    while waiting for idx=%#lx.\n",
+            nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx),
+            ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter,
+            ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx);
+        status = -1;
+    }
+    return status;
+}
+#endif
+
+__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx,
+                                           int *error) {
+    int status = 0;
+    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
+
+    const uint32_t ncqes = cq->ncqes;
+
+    uint8_t opown;
+    uint8_t opcode;
+    uint16_t wqe_counter;
+    uint16_t new_wqe_counter;
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+    uint64_t start = ibgda_query_globaltimer();
+    uint64_t now;
+#endif
+
+    uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx);
+    uint64_t new_cons_idx;
+
+    assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI ||
+                  cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC));
+
+    if (unlikely(cons_idx >= idx)) goto out;
+
+#ifdef NVSHMEM_IBGDA_DEBUG
+    // We can skip opcode == MLX5_CQE_INVALID check because we have already
+    // initialized the CQ buffer to 0xff. With the QP depth range we enforce,
+    // cons_idx cannot progress unless wqe_counter read from the CQ buffer is
+    // a valid value.
+    do {
+        opown = ibgda_atomic_read(&cqe64->op_own);
+        opcode = opown >> 4;
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+        // TODO: Integrate timeout handler with the core NVSHMEM
+        now = ibgda_query_globaltimer();
+        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
+        if (status != 0) goto check_opcode;
+#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
+    } while (unlikely(opcode == MLX5_CQE_INVALID));
+
+    // Prevent reordering of the opcode wait above
+    IBGDA_MFENCE();
+#endif /* NVSHMEM_IBGDA_DEBUG */
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+    start = ibgda_query_globaltimer();
+#endif
+
+    // If idx is a lot greater than cons_idx, we might get an incorrect result due
+    // to wqe_counter wraparound. We need to check prod_idx to be sure that idx
+    // has already been submitted.
+    while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx))
+        ;
+    IBGDA_MFENCE();
+
+    do {
+        new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
+        new_wqe_counter = BSWAP16(new_wqe_counter);
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+        now = ibgda_query_globaltimer();
+        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
+        if (status != 0) goto check_opcode;
+
+        // Observed progress. Reset the timer.
+        if (new_wqe_counter != wqe_counter) start = now;
+#endif
+        wqe_counter = new_wqe_counter;
+
+        // Another thread may have updated cons_idx.
+        cons_idx = ibgda_atomic_read(cq->cons_idx);
+        if (likely(cons_idx >= idx)) goto out;
+    }
+    // NOTE: This while condition belongs to the do-while loop above.
+    // wqe_counter is the HW consumer index. However, we always maintain index
+    // + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
+    // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
+    // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
+    // idx, and thus we need to wait. We don't need to wait when idx ==
+    // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
+    // wraparound.
+    while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)));
+
+    // new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the
+    // MSB from idx. We also need to take care of wraparound.
+    ++wqe_counter;
+    new_cons_idx =
+        (idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
+    atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx);
+
+#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
+check_opcode:
+#endif
+
+    // NVSHMEM always treats CQE errors as fatal.
+    // Even if this error doesn't belong to the CQE in cons_idx,
+    // we will just report and terminate the process.
+    opown = ibgda_atomic_read(&cqe64->op_own);
+    opcode = opown >> 4;
+
+    if (unlikely(opcode == MLX5_CQE_REQ_ERR)) {
+        ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64;
+        *error = cqe_err->syndrome;
+#ifdef NVSHMEM_IBGDA_DEBUG
+        __be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter);
+        __be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn);
+        printf(
+            "[%d] got completion with err:\n"
+            "    syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n"
+            "    wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n"
+            "    cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n",
+            nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd,
+            cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter),
+            BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx);
+#endif /* NVSHMEM_IBGDA_DEBUG */
+        status = -1;
+    }
+
+out:
+    // Prevent reordering of this function and subsequent instructions
+    IBGDA_MFENCE();
+
+    return status;
+}
+
+__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) {
+    nvshmemi_ibgda_device_state_t *state = ibgda_get_state();
+    uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx)
+                                                  : ibgda_atomic_read(&qp->mvars.tx_wq.ready_head);
+    nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
+
+    int err = 0;
+    int status = ibgda_poll_cq(&cq, prod_idx, &err);
+    // TODO: Integrate the error handler with the core NVSHMEM
+#ifdef NVSHMEM_IBGDA_DEBUG
+    if (status) {
+        printf("ibgda_poll_cq failed with error=%d.\n", err);
+    }
+#endif
+    assert(likely(status == 0));
+    return prod_idx;
+}
+
+__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) {
+    int status = 0;
+    int err = 0;
+    uint16_t nwqes = qp->tx_wq.nwqes;
+    nwqes = nwqes / 2;
+
+    // We don't want wqe_idx - nwqes to wraparound.
+    if (likely(wqe_idx >= nwqes)) {
+        nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
+        status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err);
+        // TODO: Integrate the error handler with the core NVSHMEM
+        if (status) {
+            printf("ibgda_poll_cq failed with error=%d.\n", err);
+        }
+        assert(likely(status == 0));
+    }
+    IBGDA_MFENCE();
+}
+
+template <bool nbi = true>
 __device__ static __forceinline__ uint64_t
 ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
     auto mvars = &qp->mvars;
-    return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
+    uint64_t wqe_idx;
+    wqe_idx = atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
+    if (!nbi) {
+        uint64_t prod_idx = mvars->tx_wq.prod_idx;
+        uint64_t cons_idx = mvars->tx_wq.cons_idx;
+        uint64_t delta = prod_idx - cons_idx;
+        uint64_t cnt = qp->tx_wq.nwqes;
+        if (delta > cnt) {
+            printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta);
+            EP_DEVICE_ASSERT(delta <= cnt);
+        }
+
+        // If the last slot is available, all prior slots are also available.
+        ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes);
+    }
+
+    // return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
+    return wqe_idx;
 }
 
 __device__ static __forceinline__ void*
 ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
     uint16_t cnt = qp->tx_wq.nwqes;
+    EP_DEVICE_ASSERT(cnt != 0);
     uint16_t idx = wqe_idx & (cnt - 1);
     return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
 }
@@ -325,6 +667,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
 }
 
+template <bool nbi = true>
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
@@ -354,7 +697,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
     auto qp = ibgda_get_rc(dst_pe, qp_id);
     uint64_t base_wqe_idx = 0;
     if (lane_id == 0)
-        base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
+        base_wqe_idx = ibgda_reserve_wqe_slots<nbi>(qp, num_wqes);
     base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
     if (lane_id < num_wqes) {
         auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
@@ -367,6 +710,10 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
     if (lane_id == 0)
         ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx);
     __syncwarp();
+
+    // if (!nbi) {
+    //     ibgda_quiet(qp);
+    // }
 }
 
 __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index d6ad583..306dfe0 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -3,6 +3,7 @@
 #include "exception.cuh"
 #include "launch.cuh"
 #include "utils.cuh"
+#include "ibgda_device.cuh"
 
 namespace deep_ep {
 
@@ -710,10 +711,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (dst_rdma_rank != rdma_rank) {
                     auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
                     EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
-                    nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
-                                               rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
-                                               num_bytes_per_rdma_token * num_tokens_to_issue,
-                                               translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
+                    //                            rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
+                    //                            num_bytes_per_rdma_token * num_tokens_to_issue,
+                    //                            translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
+                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
+                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
+                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
                     nvshmem_fence();
                 } else {
                     // Lighter fence for local RDMA rank
@@ -725,8 +731,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
                 if (lane_id == dst_rdma_rank) {
                     last_issued_tail += num_tokens_to_issue;
                     num_tokens_to_send -= num_tokens_to_issue;
-                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
+                    //                    translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    if (dst_rdma_rank != rdma_rank) {
+                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
+                                                        translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id);
+                    } else {
+                        nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
+                                           translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    }
                 }
             }
         }
@@ -926,8 +939,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
             // Update remote head
             if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
-                                   translate_dst_rdma_rank(lane_id, nvl_rank));
+                // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
+                //                    translate_dst_rdma_rank(lane_id, nvl_rank));
+                if (lane_id != rdma_rank) {
+                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
+                                                    translate_dst_rdma_rank(lane_id, nvl_rank), channel_id);
+                } else {
+                    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
+                                       translate_dst_rdma_rank(lane_id, nvl_rank));
+                }
                 last_head = min_head;
             }
 
@@ -1558,10 +1578,15 @@ combine(int4* combined_x, float* combined_topk_weights,
             if (sub_warp_id == kNumWarpsPerForwarder - 1) {
                 if (dst_rdma_rank != rdma_rank) {
                     auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
-                    nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
-                                               rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
-                                               num_chunked_tokens * num_bytes_per_rdma_token,
-                                               translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
+                    //                            rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
+                    //                            num_chunked_tokens * num_bytes_per_rdma_token,
+                    //                            translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
+                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
+                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
+                    nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
+                                                translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
                     nvshmem_fence();
                 } else {
                     memory_fence();
@@ -1569,9 +1594,17 @@ combine(int4* combined_x, float* combined_topk_weights,
 
                 // Write new RDMA tail
                 __syncwarp();
-                if (lane_id == 0)
-                    nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                if (lane_id == 0) {
+                    // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
+                    //                    translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    if (dst_rdma_rank != rdma_rank) {
+                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
+                                                        translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id);
+                    } else {
+                        nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
+                                           translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    }
+                }
             }
         }
 
@@ -1656,8 +1689,15 @@ combine(int4* combined_x, float* combined_topk_weights,
                 for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
                     min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
                 if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
-                    nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
-                                       translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
+                    //                    translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    if (dst_rdma_rank != rdma_rank) {
+                        nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
+                                                        translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id);
+                    } else {
+                        nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
+                                           translate_dst_rdma_rank(dst_rdma_rank, nvl_rank));
+                    }
                     last_rdma_head = min_head;
                 }
             } else {
diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu
index c9f5879..d008098 100644
--- a/csrc/kernels/runtime.cu
+++ b/csrc/kernels/runtime.cu
@@ -59,7 +59,8 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
     }
 
     // Normal operations use IBRC, while low-latency operations use IBGDA
-    if (low_latency_mode) {
+    bool internode_use_ibgda = true;
+    if (low_latency_mode or internode_use_ibgda) {
         nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
         CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 831a2e6..0665eb7 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -65,9 +65,10 @@ class Buffer:
 
         # Synchronize NVSHMEM unique IDs
         root_unique_id = None
+        internode_use_ibgda = True
        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
             # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
-            if low_latency_mode:
+            if low_latency_mode or internode_use_ibgda:
                 assert num_qps_per_rank > 0
                 os.environ['NVSHMEM_DISABLE_P2P'] = '1'
                 os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
diff --git a/tests/test_internode.py b/tests/test_internode.py
index 5884a16..456ea64 100644
--- a/tests/test_internode.py
+++ b/tests/test_internode.py
@@ -218,17 +218,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
 # noinspection PyUnboundLocalVariable
 def test_loop(local_rank: int, num_local_ranks: int):
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
+    num_sms = 24
+    qp_num = num_sms // 2
     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
     test_ll_compatibility = False
     if test_ll_compatibility:
         ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
 
     buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
-                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
+                            num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num))
     assert num_local_ranks == 8 and num_ranks > 8
     torch.manual_seed(rank)
 
-    for i in (24, ):
+    for i in (num_sms, ):
         test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
         if local_rank == 0:
             print('', flush=True)
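
Not part of the patch: the subtlest piece of the new ibgda_poll_cq is the 16-bit wqe_counter wraparound handling (the "- (uint16_t)2" wait condition and the 64-bit cons_idx reconstruction). The sketch below is a minimal, host-compilable C++ illustration of that arithmetic, assuming as the code does that the CQ depth ncqes is much smaller than 2^16 and that the awaited 64-bit index idx was already submitted; the helper names still_waiting and reconstruct_cons_idx are hypothetical and exist only for this example.

    // Standalone sketch of the wraparound arithmetic in ibgda_poll_cq (illustrative only).
    #include <cassert>
    #include <cstdint>

    // Mirrors the wait condition:
    //   (uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes
    // wqe_counter is the HW consumer index and the SW index is wqe_counter + 1, so the
    // extra "- 2" makes the "completion just arrived" case wrap to a huge value.
    static bool still_waiting(uint64_t idx, uint16_t wqe_counter, uint32_t ncqes) {
        return (uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes;
    }

    // Mirrors new_cons_idx: splice the 16-bit counter (+1) under the upper bits taken
    // from idx, bumping those upper bits by 0x10000 when the counter has wrapped past idx.
    static uint64_t reconstruct_cons_idx(uint64_t idx, uint16_t wqe_counter) {
        uint16_t next = (uint16_t)(wqe_counter + 1);
        return ((idx & ~0xffffULL) | next) + (((uint16_t)idx > next) ? 0x10000ULL : 0x0);
    }

    int main() {
        const uint32_t ncqes = 128;
        // Completion for idx has landed (wqe_counter + 1 == idx): stop waiting.
        assert(!still_waiting(1000, 999, ncqes));
        // Completion still outstanding within the CQ window: keep waiting.
        assert(still_waiting(1000, 990, ncqes));
        // Counter in the same 64K epoch as idx: upper bits come straight from idx.
        assert(reconstruct_cons_idx(0x12345ULL, 0x2344) == 0x12345ULL);
        // Counter already wrapped into the next 64K epoch: the +0x10000 bump applies.
        assert(reconstruct_cons_idx(0x1fffeULL, 3) == 0x20004ULL);
        return 0;
    }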