Revert ibgda_device.cuh and remove some comments.

Shangyan Zhou 2025-04-21 17:44:32 +08:00
parent 5ab80c28f3
commit e2c578485c
2 changed files with 4 additions and 365 deletions

Changed file: ibgda_device.cuh

@@ -11,10 +11,6 @@
#include "exception.cuh"
#include "utils.cuh"
// #define NVSHMEM_TIMEOUT_DEVICE_POLLING
// #define IBGDA_POLL_TIMEOUT 4000000000LLU
// #define NVSHMEM_IBGDA_DEBUG
namespace deep_ep {
EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");
@@ -246,353 +242,15 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey)
*out_rkey = device_key.key;
}
#ifndef likely
#define likely(x) (__builtin_expect(!!(x), 1))
#endif
#ifndef unlikely
#define unlikely(x) (__builtin_expect(!!(x), 0))
#endif
#ifndef ACCESS_ONCE
#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x))
#endif
/**
* DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug.
* BSWAP is a preprocessor macro, so its argument would be expanded into multiple READ_ONCE reads.
*/
#ifndef READ_ONCE
#define READ_ONCE(x) ACCESS_ONCE(x)
#endif
#ifndef WRITE_ONCE
#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v))
#endif
#ifdef NVSHMEM_IBGDA_DEBUG
struct mlx5_err_cqe_ex {
uint8_t rsvd0[32];
__be32 srqn;
uint8_t rsvd1[16];
uint8_t hw_err_synd;
uint8_t hw_synd_type;
uint8_t vendor_err_synd;
uint8_t syndrome;
__be32 s_wqe_opcode_qpn;
__be16 wqe_counter;
uint8_t signature;
uint8_t op_own;
};
typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t;
#else
typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t;
#endif
__device__ static inline uint16_t BSWAP16(uint16_t x) {
uint16_t ret;
uint32_t a = (uint32_t)x;
uint32_t d;
asm volatile(
"{\n\t"
".reg .b32 mask;\n\t"
".reg .b32 ign;\n\t"
"mov.b32 mask, 0x4401;\n\t"
"mov.b32 ign, 0x0;\n\t"
"prmt.b32 %0, %1, ign, mask;\n\t"
"}"
: "=r"(d)
: "r"(a));
ret = (uint16_t)d;
return ret;
}
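For reference, the prmt.b32 selector 0x4401 simply swaps the two low-order bytes of its first source; a minimal stand-alone sketch of the same 16-bit byte swap (bswap16_ref is a hypothetical name, and CUDA's __byte_perm(x, 0, 0x4401) intrinsic performs the equivalent permutation on the device):
#include <cassert>
#include <cstdint>
// Reference implementation of the byte swap performed by BSWAP16 above.
static inline uint16_t bswap16_ref(uint16_t x) {
    return static_cast<uint16_t>((x >> 8) | (x << 8));
}
int main() {
    assert(bswap16_ref(0x1234) == 0x3412);
    return 0;
}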
/**
* DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug.
* See the comment near READ_ONCE.
*/
__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
return (uint8_t)ret;
#else
return READ_ONCE(*ptr);
#endif
}
__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
uint16_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
return ret;
#else
return READ_ONCE(*ptr);
#endif
}
__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
uint32_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
#else
return READ_ONCE(*ptr);
#endif
}
__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
uint64_t ret;
asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
#else
return READ_ONCE(*ptr);
#endif
}
// Prevent code reordering by both the compiler and the GPU
__device__ static inline void IBGDA_MFENCE() {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE
asm volatile("fence.acq_rel.cta;" ::: "memory");
#else
__threadfence_block();
#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */
}
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
__device__ static inline uint64_t ibgda_query_globaltimer() {
uint64_t ret;
asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory");
return ret;
}
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now,
uint64_t start, uint64_t idx, int *error) {
int status = 0;
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
uint8_t opown;
uint8_t opcode;
uint16_t wqe_counter;
if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) {
*error = -ETIME;
opown = ibgda_atomic_read(&cqe64->op_own);
opcode = opown >> 4;
wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
wqe_counter = BSWAP16(wqe_counter);
printf(
"[%d] ibgda_poll_cq timeout:\n"
" cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n"
" wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n"
" while waiting for idx=%#lx.\n",
nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx),
ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter,
ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx);
status = -1;
}
return status;
}
#endif
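The timeout machinery above amounts to sampling the nanosecond %globaltimer register once before polling and re-checking the elapsed time on every iteration; a minimal device-side sketch of that pattern (poll_flag_with_timeout, flag, and budget_ns are hypothetical names used only for illustration):
#include <cstdint>
// Spin on a flag, but give up after budget_ns nanoseconds, mirroring the
// start/now comparison in ibgda_check_poll_timeout above.
__device__ static inline bool poll_flag_with_timeout(volatile uint32_t *flag, uint64_t budget_ns) {
    uint64_t start, now;
    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(start) :: "memory");
    while (*flag == 0) {
        asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(now) :: "memory");
        if (now - start > budget_ns)
            return false;  // timed out, the analogue of the -ETIME path
    }
    return true;
}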
__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx,
int *error) {
int status = 0;
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
const uint32_t ncqes = cq->ncqes;
uint8_t opown;
uint8_t opcode;
uint16_t wqe_counter;
uint16_t new_wqe_counter;
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
uint64_t start = ibgda_query_globaltimer();
uint64_t now;
#endif
uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx);
uint64_t new_cons_idx;
assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI ||
cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC));
if (unlikely(cons_idx >= idx)) goto out;
#ifdef NVSHMEM_IBGDA_DEBUG
// We can skip the opcode == MLX5_CQE_INVALID check because we have already
// initialized the CQ buffer to 0xff. With the QP depth range we enforce,
// cons_idx cannot progress unless the wqe_counter read from the CQ buffer is
// a valid value.
do {
opown = ibgda_atomic_read(&cqe64->op_own);
opcode = opown >> 4;
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
// TODO: Integrate timeout handler with the core NVSHMEM
now = ibgda_query_globaltimer();
status = ibgda_check_poll_timeout(cq, now, start, idx, error);
if (status != 0) goto check_opcode;
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
} while (unlikely(opcode == MLX5_CQE_INVALID));
// Prevent reordering of the opcode wait above
IBGDA_MFENCE();
#endif /* NVSHMEM_IBGDA_DEBUG */
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
start = ibgda_query_globaltimer();
#endif
// If idx is a lot greater than cons_idx, we might get an incorrect result
// due to wqe_counter wraparound. We need to check prod_idx to be sure that
// idx has already been submitted.
while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx))
;
IBGDA_MFENCE();
do {
new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
new_wqe_counter = BSWAP16(new_wqe_counter);
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
now = ibgda_query_globaltimer();
status = ibgda_check_poll_timeout(cq, now, start, idx, error);
if (status != 0) goto check_opcode;
// Observe progress. Reset the timer.
if (new_wqe_counter != wqe_counter) start = now;
#endif
wqe_counter = new_wqe_counter;
// Another thread may have updated cons_idx.
cons_idx = ibgda_atomic_read(cq->cons_idx);
if (likely(cons_idx >= idx)) goto out;
}
// NOTE: This while condition closes the do-while loop above.
// wqe_counter is the HW consumer index, but SW always maintains index + 1,
// so we must compare idx against wqe_counter + 1. Because wqe_counter is
// uint16_t, it may wrap around. Still, we know for sure that if
// idx - wqe_counter - 1 < ncqes, then wqe_counter + 1 is less than idx and
// we need to keep waiting. We do not need to wait when idx == wqe_counter + 1;
// subtracting (uint16_t)2 makes exactly that case wrap around.
while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)));
// new_cons_idx is uint64_t but wqe_counter is uint16_t, so we take the upper
// bits from idx and add a 0x10000 correction if the 16-bit counter wrapped.
++wqe_counter;
new_cons_idx =
(idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx);
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
check_opcode:
#endif
// NVSHMEM always treats CQE errors as fatal. Even if this error does not
// belong to the CQE at cons_idx, we just report it and terminate the
// process.
opown = ibgda_atomic_read(&cqe64->op_own);
opcode = opown >> 4;
if (unlikely(opcode == MLX5_CQE_REQ_ERR)) {
ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64;
*error = cqe_err->syndrome;
#ifdef NVSHMEM_IBGDA_DEBUG
__be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter);
__be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn);
printf(
"[%d] got completion with err:\n"
" syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n"
" wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n"
" cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n",
nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd,
cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter),
BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx);
#endif /* NVSHMEM_IBGDA_DEBUG */
status = -1;
}
out:
// Prevent reordering of this function and subsequent instructions
IBGDA_MFENCE();
return status;
}
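To make the index arithmetic above concrete, here is a minimal host-side sketch of how the 64-bit consumer index is rebuilt from the 16-bit wqe_counter, including the carry applied when the counter wraps (rebuild_cons_idx and the example values are hypothetical):
#include <cassert>
#include <cstdint>
// Mirror of the new_cons_idx computation in ibgda_poll_cq: upper bits come
// from idx, the low 16 bits come from wqe_counter + 1, and 0x10000 is added
// when the 16-bit counter has wrapped past idx's low half.
static uint64_t rebuild_cons_idx(uint64_t idx, uint16_t wqe_counter) {
    ++wqe_counter;  // SW tracks the HW index + 1
    return ((idx & ~0xffffULL) | wqe_counter) +
           (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
}
int main() {
    assert(rebuild_cons_idx(0x1fffeULL, 0xfffd) == 0x1fffeULL);  // no wrap
    assert(rebuild_cons_idx(0x1fffeULL, 0xffff) == 0x20000ULL);  // counter wrapped
    return 0;
}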
__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) {
nvshmemi_ibgda_device_state_t *state = ibgda_get_state();
uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx)
: ibgda_atomic_read(&qp->mvars.tx_wq.ready_head);
nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
int err = 0;
int status = ibgda_poll_cq(&cq, prod_idx, &err);
// TODO: Integrate the error handler with the core NVSHMEM
#ifdef NVSHMEM_IBGDA_DEBUG
if (status) {
printf("ibgda_poll_cq failed with error=%d.\n", err);
}
#endif
assert(likely(status == 0));
return prod_idx;
}
__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) {
int status = 0;
int err = 0;
uint16_t nwqes = qp->tx_wq.nwqes;
nwqes = nwqes / 2;
// We don't want wqe_idx - nwqes to wrap around.
if (likely(wqe_idx >= nwqes)) {
nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err);
// TODO: Integrate the error handler with the core NVSHMEM
if (status) {
printf("ibgda_poll_cq failed with error=%d.\n", err);
}
assert(likely(status == 0));
}
IBGDA_MFENCE();
}
template <bool nbi = true>
__device__ static __forceinline__ uint64_t
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
auto mvars = &qp->mvars;
uint64_t wqe_idx;
wqe_idx = atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
if (!nbi) {
uint64_t prod_idx = mvars->tx_wq.prod_idx;
uint64_t cons_idx = mvars->tx_wq.cons_idx;
uint64_t delta = prod_idx - cons_idx;
uint64_t cnt = qp->tx_wq.nwqes;
if (delta > cnt) {
printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta);
EP_DEVICE_ASSERT(delta <= cnt);
}
// If the last slot is available, all prior slots are also available.
ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes);
}
// return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
return wqe_idx;
return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
}
__device__ static __forceinline__ void*
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
uint16_t cnt = qp->tx_wq.nwqes;
EP_DEVICE_ASSERT(cnt != 0);
uint16_t idx = wqe_idx & (cnt - 1);
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
}
@@ -667,7 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}
template <bool nbi = true>
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
// Get lkey and rkey, store them into lanes
@@ -697,7 +354,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t base_wqe_idx = 0;
if (lane_id == 0)
base_wqe_idx = ibgda_reserve_wqe_slots<nbi>(qp, num_wqes);
base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
if (lane_id < num_wqes) {
auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);
@@ -710,10 +367,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
if (lane_id == 0)
ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
__syncwarp();
// if (!nbi) {
// ibgda_quiet(qp);
// }
}
__device__ static __forceinline__ void ibgda_write_amo_add_wqe(

Second changed file (internode dispatch/combine kernels)

@@ -711,14 +711,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (dst_rdma_rank != rdma_rank) {
auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
// rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
// num_bytes_per_rdma_token * num_tokens_to_issue,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
nvshmem_fence();
} else {
@@ -731,8 +727,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
if (lane_id == dst_rdma_rank) {
last_issued_tail += num_tokens_to_issue;
num_tokens_to_send -= num_tokens_to_issue;
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
@@ -939,8 +933,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Update remote head
if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
// nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
if (lane_id != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id);
@@ -1578,14 +1570,10 @@ combine(int4* combined_x, float* combined_topk_weights,
if (sub_warp_id == kNumWarpsPerForwarder - 1) {
if (dst_rdma_rank != rdma_rank) {
auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
// nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
// rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
// num_chunked_tokens * num_bytes_per_rdma_token,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
nvshmem_fence();
} else {
@@ -1595,8 +1583,6 @@ combine(int4* combined_x, float* combined_topk_weights,
// Write new RDMA tail
__syncwarp();
if (lane_id == 0) {
// nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
// translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
if (dst_rdma_rank != rdma_rank) {
nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);
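In both the dispatch and combine hunks, the put's source and destination addresses follow the same pattern: a slot index taken modulo the chunked ring capacity, a byte offset of slot * num_bytes_per_rdma_token, and a message size of tokens * num_bytes_per_rdma_token. A minimal host-side sketch of that arithmetic (the struct, function, and numbers are hypothetical):
#include <cassert>
#include <cstddef>
#include <cstdint>
// Slot, byte offset, and message size for one chunk, mirroring the address
// math passed to nvshmemi_ibgda_put_nbi_warp above.
struct ChunkAddressing {
    size_t slot_idx;       // position inside the chunked ring buffer
    size_t byte_offset;    // offset added to send_buffer()/recv_buffer()
    size_t num_bytes_msg;  // size handed to the put
};
static ChunkAddressing chunk_addressing(uint64_t tail, size_t ring_tokens,
                                        size_t bytes_per_token, size_t num_tokens) {
    ChunkAddressing a;
    a.slot_idx = tail % ring_tokens;
    a.byte_offset = a.slot_idx * bytes_per_token;
    a.num_bytes_msg = num_tokens * bytes_per_token;
    return a;
}
int main() {
    // e.g. a ring of 128 tokens, 1024 bytes per token, 8 tokens issued at tail 260
    ChunkAddressing a = chunk_addressing(260, 128, 1024, 8);
    assert(a.slot_idx == 4 && a.byte_offset == 4096 && a.num_bytes_msg == 8192);
    return 0;
}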