mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-05-05 20:44:48 +00:00
In the Internode Normal Kernel, when using nvshmem ibrc for RDMA data transmission, a single QP is used for data transfer between two GPUs, which limits kernel performance in network card dual-port and RoCE network scenarios.
In our optimized Internode Normal Kernel, we implemented multiple QPs for data transmission between two GPUs, setting a different QP for each channel. Additionally, we modified the transmission method from IBRC to IBGAD. Through these optimizations, the Internode Normal Kernel achieves optimal performance in both H800 and H20 environments, with RDMA transmission performance nearly reaching the physical network performance limit. Using the current default statistical method, in 4-node H800 and H20 environments, RDMA performance can reach 60GB/s+.
This commit is contained in:
parent
a84a24808f
commit
5ab80c28f3
@ -11,6 +11,10 @@
|
||||
#include "exception.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
// #define NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
// #define IBGDA_POLL_TIMEOUT 4000000000LLU
|
||||
// #define NVSHMEM_IBGDA_DEBUG
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
|
||||
@ -242,15 +246,353 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey)
|
||||
*out_rkey = device_key.key;
|
||||
}
|
||||
|
||||
#ifndef likely
|
||||
#define likely(x) (__builtin_expect(!!(x), 1))
|
||||
#endif
|
||||
|
||||
#ifndef unlikely
|
||||
#define unlikely(x) (__builtin_expect(!!(x), 0))
|
||||
#endif
|
||||
|
||||
#ifndef ACCESS_ONCE
|
||||
#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x))
|
||||
#endif
|
||||
|
||||
/**
|
||||
* DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug.
|
||||
* BSWAP is a pre-processor function. It will be unrolled to many READ_ONCE.
|
||||
*/
|
||||
#ifndef READ_ONCE
|
||||
#define READ_ONCE(x) ACCESS_ONCE(x)
|
||||
#endif
|
||||
|
||||
#ifndef WRITE_ONCE
|
||||
#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v))
|
||||
#endif
|
||||
|
||||
#ifdef NVSHMEM_IBGDA_DEBUG
|
||||
struct mlx5_err_cqe_ex {
|
||||
uint8_t rsvd0[32];
|
||||
__be32 srqn;
|
||||
uint8_t rsvd1[16];
|
||||
uint8_t hw_err_synd;
|
||||
uint8_t hw_synd_type;
|
||||
uint8_t vendor_err_synd;
|
||||
uint8_t syndrome;
|
||||
__be32 s_wqe_opcode_qpn;
|
||||
__be16 wqe_counter;
|
||||
uint8_t signature;
|
||||
uint8_t op_own;
|
||||
};
|
||||
typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t;
|
||||
#else
|
||||
typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t;
|
||||
#endif
|
||||
|
||||
__device__ static inline uint16_t BSWAP16(uint16_t x) {
|
||||
uint16_t ret;
|
||||
|
||||
uint32_t a = (uint32_t)x;
|
||||
uint32_t d;
|
||||
asm volatile(
|
||||
"{\n\t"
|
||||
".reg .b32 mask;\n\t"
|
||||
".reg .b32 ign;\n\t"
|
||||
"mov.b32 mask, 0x4401;\n\t"
|
||||
"mov.b32 ign, 0x0;\n\t"
|
||||
"prmt.b32 %0, %1, ign, mask;\n\t"
|
||||
"}"
|
||||
: "=r"(d)
|
||||
: "r"(a));
|
||||
ret = (uint16_t)d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug.
|
||||
* See the comment near READ_ONCE.
|
||||
*/
|
||||
__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) {
|
||||
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return (uint8_t)ret;
|
||||
#else
|
||||
return READ_ONCE(*ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) {
|
||||
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
|
||||
uint16_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
#else
|
||||
return READ_ONCE(*ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) {
|
||||
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
|
||||
uint32_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
#else
|
||||
return READ_ONCE(*ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) {
|
||||
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
|
||||
uint64_t ret;
|
||||
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
|
||||
return ret;
|
||||
#else
|
||||
return READ_ONCE(*ptr);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Prevent code reordering from both compiler and GPU
|
||||
__device__ static inline void IBGDA_MFENCE() {
|
||||
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE
|
||||
asm volatile("fence.acq_rel.cta;" ::: "memory");
|
||||
#else
|
||||
__threadfence_block();
|
||||
#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */
|
||||
}
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
__device__ static inline uint64_t ibgda_query_globaltimer() {
|
||||
uint64_t ret;
|
||||
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory");
|
||||
return ret;
|
||||
}
|
||||
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now,
|
||||
uint64_t start, uint64_t idx, int *error) {
|
||||
int status = 0;
|
||||
|
||||
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
|
||||
uint8_t opown;
|
||||
uint8_t opcode;
|
||||
uint16_t wqe_counter;
|
||||
|
||||
if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) {
|
||||
*error = -ETIME;
|
||||
|
||||
opown = ibgda_atomic_read(&cqe64->op_own);
|
||||
opcode = opown >> 4;
|
||||
|
||||
wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
|
||||
wqe_counter = BSWAP16(wqe_counter);
|
||||
|
||||
printf(
|
||||
"[%d] ibgda_poll_cq timeout:\n"
|
||||
" cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n"
|
||||
" wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n"
|
||||
" while waiting for idx=%#lx.\n",
|
||||
nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx),
|
||||
ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter,
|
||||
ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx);
|
||||
status = -1;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
#endif
|
||||
|
||||
__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx,
|
||||
int *error) {
|
||||
int status = 0;
|
||||
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
|
||||
|
||||
const uint32_t ncqes = cq->ncqes;
|
||||
|
||||
uint8_t opown;
|
||||
uint8_t opcode;
|
||||
uint16_t wqe_counter;
|
||||
uint16_t new_wqe_counter;
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
uint64_t start = ibgda_query_globaltimer();
|
||||
uint64_t now;
|
||||
#endif
|
||||
|
||||
uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx);
|
||||
uint64_t new_cons_idx;
|
||||
|
||||
assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI ||
|
||||
cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC));
|
||||
|
||||
if (unlikely(cons_idx >= idx)) goto out;
|
||||
|
||||
#ifdef NVSHMEM_IBGDA_DEBUG
|
||||
// We can skip opcode == MLX5_CQE_INVALID check because we have already
|
||||
// initialized the CQ buffer to 0xff. With the QP depth range we enforce,
|
||||
// cons_idx cannot progress unless wqe_counter read from the CQ buffer is
|
||||
// a valid value.
|
||||
do {
|
||||
opown = ibgda_atomic_read(&cqe64->op_own);
|
||||
opcode = opown >> 4;
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
// TODO: Integrate timeout handler with the core NVSHMEM
|
||||
now = ibgda_query_globaltimer();
|
||||
status = ibgda_check_poll_timeout(cq, now, start, idx, error);
|
||||
if (status != 0) goto check_opcode;
|
||||
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
|
||||
} while (unlikely(opcode == MLX5_CQE_INVALID));
|
||||
|
||||
// Prevent reordering of the opcode wait above
|
||||
IBGDA_MFENCE();
|
||||
#endif /* NVSHMEM_IBGDA_DEBUG */
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
start = ibgda_query_globaltimer();
|
||||
#endif
|
||||
|
||||
// If idx is a lot greater than cons_idx, we might get incorrect result due
|
||||
// to wqe_counter wraparound. We need to check prod_idx to be sure that idx
|
||||
// has already been submitted.
|
||||
while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx))
|
||||
;
|
||||
IBGDA_MFENCE();
|
||||
|
||||
do {
|
||||
new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
|
||||
new_wqe_counter = BSWAP16(new_wqe_counter);
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
now = ibgda_query_globaltimer();
|
||||
status = ibgda_check_poll_timeout(cq, now, start, idx, error);
|
||||
if (status != 0) goto check_opcode;
|
||||
|
||||
// Observe progress. Reset the timer.
|
||||
if (new_wqe_counter != wqe_counter) start = now;
|
||||
#endif
|
||||
wqe_counter = new_wqe_counter;
|
||||
|
||||
// Another thread may have updated cons_idx.
|
||||
cons_idx = ibgda_atomic_read(cq->cons_idx);
|
||||
if (likely(cons_idx >= idx)) goto out;
|
||||
}
|
||||
// NOTE: This while loop is part of do while above.
|
||||
// wqe_counter is the HW consumer index. However, we always maintain index
|
||||
// + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
|
||||
// 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
|
||||
// sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
|
||||
// idx, and thus we need to wait. We don't need to wait when idx ==
|
||||
// wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
|
||||
// wraparound.
|
||||
while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)));
|
||||
|
||||
// new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the
|
||||
// MSB from idx. We also need to take care of wraparound.
|
||||
++wqe_counter;
|
||||
new_cons_idx =
|
||||
(idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
|
||||
atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx);
|
||||
|
||||
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
|
||||
check_opcode:
|
||||
#endif
|
||||
|
||||
// NVSHMEM always treats CQE errors as fatal.
|
||||
// Even if this error doesn't belong to the CQE in cons_idx,
|
||||
// we will just report and terminate the process.
|
||||
opown = ibgda_atomic_read(&cqe64->op_own);
|
||||
opcode = opown >> 4;
|
||||
|
||||
if (unlikely(opcode == MLX5_CQE_REQ_ERR)) {
|
||||
ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64;
|
||||
*error = cqe_err->syndrome;
|
||||
#ifdef NVSHMEM_IBGDA_DEBUG
|
||||
__be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter);
|
||||
__be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn);
|
||||
printf(
|
||||
"[%d] got completion with err:\n"
|
||||
" syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n"
|
||||
" wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n"
|
||||
" cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n",
|
||||
nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd,
|
||||
cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter),
|
||||
BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx);
|
||||
#endif /* NVSHMEM_IBGDA_DEBUG */
|
||||
status = -1;
|
||||
}
|
||||
|
||||
out:
|
||||
// Prevent reordering of this function and subsequent instructions
|
||||
IBGDA_MFENCE();
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) {
|
||||
nvshmemi_ibgda_device_state_t *state = ibgda_get_state();
|
||||
uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx)
|
||||
: ibgda_atomic_read(&qp->mvars.tx_wq.ready_head);
|
||||
nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
|
||||
|
||||
int err = 0;
|
||||
int status = ibgda_poll_cq(&cq, prod_idx, &err);
|
||||
// TODO: Integrate the error handler with the core NVSHMEM
|
||||
#ifdef NVSHMEM_IBGDA_DEBUG
|
||||
if (status) {
|
||||
printf("ibgda_poll_cq failed with error=%d.\n", err);
|
||||
}
|
||||
#endif
|
||||
assert(likely(status == 0));
|
||||
return prod_idx;
|
||||
}
|
||||
|
||||
__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) {
|
||||
int status = 0;
|
||||
int err = 0;
|
||||
uint16_t nwqes = qp->tx_wq.nwqes;
|
||||
nwqes = nwqes / 2;
|
||||
|
||||
// We don't want wqe_idx - nwqes to wraparound.
|
||||
if (likely(wqe_idx >= nwqes)) {
|
||||
nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
|
||||
status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err);
|
||||
// TODO: Integrate the error handler with the core NVSHMEM
|
||||
if (status) {
|
||||
printf("ibgda_poll_cq failed with error=%d.\n", err);
|
||||
}
|
||||
assert(likely(status == 0));
|
||||
}
|
||||
IBGDA_MFENCE();
|
||||
}
|
||||
|
||||
template <bool nbi = true>
|
||||
__device__ static __forceinline__ uint64_t
|
||||
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
|
||||
auto mvars = &qp->mvars;
|
||||
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
|
||||
uint64_t wqe_idx;
|
||||
wqe_idx = atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
|
||||
if (!nbi) {
|
||||
uint64_t prod_idx = mvars->tx_wq.prod_idx;
|
||||
uint64_t cons_idx = mvars->tx_wq.cons_idx;
|
||||
uint64_t delta = prod_idx - cons_idx;
|
||||
uint64_t cnt = qp->tx_wq.nwqes;
|
||||
if (delta > cnt) {
|
||||
printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta);
|
||||
EP_DEVICE_ASSERT(delta <= cnt);
|
||||
}
|
||||
|
||||
// If last slot is available, all prior slots are also available.
|
||||
ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes);
|
||||
}
|
||||
|
||||
// return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
|
||||
return wqe_idx;
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void*
|
||||
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
|
||||
uint16_t cnt = qp->tx_wq.nwqes;
|
||||
EP_DEVICE_ASSERT(cnt != 0);
|
||||
uint16_t idx = wqe_idx & (cnt - 1);
|
||||
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
|
||||
}
|
||||
@ -325,6 +667,7 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
|
||||
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
|
||||
}
|
||||
|
||||
template <bool nbi = true>
|
||||
__device__ static __forceinline__ void
|
||||
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
|
||||
// Get lkey and rkey, store them into lanes
|
||||
@ -354,7 +697,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
|
||||
auto qp = ibgda_get_rc(dst_pe, qp_id);
|
||||
uint64_t base_wqe_idx = 0;
|
||||
if (lane_id == 0)
|
||||
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
|
||||
base_wqe_idx = ibgda_reserve_wqe_slots<nbi>(qp, num_wqes);
|
||||
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
|
||||
if (lane_id < num_wqes) {
|
||||
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
|
||||
@ -367,6 +710,10 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
|
||||
if (lane_id == 0)
|
||||
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
|
||||
__syncwarp();
|
||||
|
||||
// if (!nbi) {
|
||||
// ibgda_quiet(qp);
|
||||
// }
|
||||
}
|
||||
|
||||
__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "exception.cuh"
|
||||
#include "launch.cuh"
|
||||
#include "utils.cuh"
|
||||
#include "ibgda_device.cuh"
|
||||
|
||||
namespace deep_ep {
|
||||
|
||||
@ -710,10 +711,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
|
||||
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
|
||||
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
num_bytes_per_rdma_token * num_tokens_to_issue,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
// rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
|
||||
// num_bytes_per_rdma_token * num_tokens_to_issue,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
|
||||
nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
|
||||
nvshmem_fence();
|
||||
} else {
|
||||
// Lighter fence for local RDMA rank
|
||||
@ -725,8 +731,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
if (lane_id == dst_rdma_rank) {
|
||||
last_issued_tail += num_tokens_to_issue;
|
||||
num_tokens_to_send -= num_tokens_to_issue;
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
|
||||
} else {
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -926,8 +939,15 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
|
||||
|
||||
// Update remote head
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
if (lane_id != rdma_rank) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id);
|
||||
} else {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
|
||||
}
|
||||
last_head = min_head;
|
||||
}
|
||||
|
||||
@ -1558,10 +1578,15 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
|
||||
nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
num_chunked_tokens * num_bytes_per_rdma_token,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
// rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
|
||||
// num_chunked_tokens * num_bytes_per_rdma_token,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
|
||||
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
|
||||
nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
|
||||
nvshmem_fence();
|
||||
} else {
|
||||
memory_fence();
|
||||
@ -1569,9 +1594,17 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
|
||||
// Write new RDMA tail
|
||||
__syncwarp();
|
||||
if (lane_id == 0)
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (lane_id == 0) {
|
||||
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
|
||||
} else {
|
||||
nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1656,8 +1689,15 @@ combine(int4* combined_x, float* combined_topk_weights,
|
||||
for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i])
|
||||
min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]);
|
||||
if (min_head != std::numeric_limits<int>::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
|
||||
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
if (dst_rdma_rank != rdma_rank) {
|
||||
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
|
||||
} else {
|
||||
nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, NVSHMEM_SIGNAL_ADD,
|
||||
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
|
||||
}
|
||||
last_rdma_head = min_head;
|
||||
}
|
||||
} else {
|
||||
|
@ -59,7 +59,8 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
|
||||
}
|
||||
|
||||
// Normal operations use IBRC, while low-latency operations use IBGDA
|
||||
if (low_latency_mode) {
|
||||
bool internode_use_ibgda = true;
|
||||
if (low_latency_mode or internode_use_ibgda) {
|
||||
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
|
||||
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
|
||||
|
||||
|
@ -65,9 +65,10 @@ class Buffer:
|
||||
|
||||
# Synchronize NVSHMEM unique IDs
|
||||
root_unique_id = None
|
||||
internode_use_ibgda = True
|
||||
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
|
||||
# Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
|
||||
if low_latency_mode:
|
||||
if low_latency_mode or internode_use_ibgda:
|
||||
assert num_qps_per_rank > 0
|
||||
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
|
||||
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
|
||||
|
@ -218,17 +218,19 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
num_nodes = int(os.getenv('WORLD_SIZE', 1))
|
||||
num_sms = 24
|
||||
qp_num = num_sms // 2
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility = False
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
|
||||
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else qp_num))
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24, ):
|
||||
for i in (num_sms, ):
|
||||
test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print('', flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user