mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-05-05 20:44:48 +00:00
Revert ibgda_device.cuh and remove some comments.

parent 5ab80c28f3
commit e2c578485c
@@ -11,10 +11,6 @@
#include "exception.cuh"
#include "utils.cuh"

// #define NVSHMEM_TIMEOUT_DEVICE_POLLING
// #define IBGDA_POLL_TIMEOUT 4000000000LLU
// #define NVSHMEM_IBGDA_DEBUG

namespace deep_ep {

EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth");

@@ -246,353 +242,15 @@ ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey)
    *out_rkey = device_key.key;
}

#ifndef likely
#define likely(x) (__builtin_expect(!!(x), 1))
#endif

#ifndef unlikely
#define unlikely(x) (__builtin_expect(!!(x), 0))
#endif

#ifndef ACCESS_ONCE
#define ACCESS_ONCE(x) (*(volatile typeof(x) *)&(x))
#endif

/**
 * DO NOT use BSWAP(READ_ONCE(x)) as it could create a bug.
 * BSWAP is a pre-processor function. It will be unrolled to many READ_ONCE.
 */
#ifndef READ_ONCE
#define READ_ONCE(x) ACCESS_ONCE(x)
#endif

#ifndef WRITE_ONCE
#define WRITE_ONCE(x, v) (ACCESS_ONCE(x) = (v))
#endif
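
// Illustrative sketch of the caveat above (BSWAP16_SKETCH is a hypothetical
// macro, not part of this file): a macro byte-swap expands its argument more
// than once, so BSWAP(READ_ONCE(x)) would issue several volatile reads that
// may observe different values. Read into a local first, then swap:
//   #define BSWAP16_SKETCH(x) ((uint16_t)((((x) & 0xffu) << 8) | (((x) >> 8) & 0xffu)))
//   uint16_t raw  = READ_ONCE(*ptr);      // exactly one volatile load
//   uint16_t host = BSWAP16_SKETCH(raw);  // swap the local copy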

#ifdef NVSHMEM_IBGDA_DEBUG
struct mlx5_err_cqe_ex {
    uint8_t rsvd0[32];
    __be32 srqn;
    uint8_t rsvd1[16];
    uint8_t hw_err_synd;
    uint8_t hw_synd_type;
    uint8_t vendor_err_synd;
    uint8_t syndrome;
    __be32 s_wqe_opcode_qpn;
    __be16 wqe_counter;
    uint8_t signature;
    uint8_t op_own;
};
typedef struct mlx5_err_cqe_ex ibgda_mlx5_err_cqe_t;
#else
typedef struct mlx5_err_cqe ibgda_mlx5_err_cqe_t;
#endif

__device__ static inline uint16_t BSWAP16(uint16_t x) {
    uint16_t ret;

    uint32_t a = (uint32_t)x;
    uint32_t d;
    asm volatile(
        "{\n\t"
        ".reg .b32 mask;\n\t"
        ".reg .b32 ign;\n\t"
        "mov.b32 mask, 0x4401;\n\t"
        "mov.b32 ign, 0x0;\n\t"
        "prmt.b32 %0, %1, ign, mask;\n\t"
        "}"
        : "=r"(d)
        : "r"(a));
    ret = (uint16_t)d;
    return ret;
}
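
// For reference, a plain-C sketch of what the prmt-based swap above computes:
// the 0x4401 selector places source byte 1 then byte 0 into the low half of
// the result (the upper bytes come from the zeroed 'ign' operand), i.e.
//   uint16_t swapped = (uint16_t)(((x & 0x00ffu) << 8) | ((x >> 8) & 0x00ffu));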

/**
 * DO NOT use BSWAP(ibgda_atomic_read(x)) as it could create a bug.
 * See the comment near READ_ONCE.
 */
__device__ static inline uint8_t ibgda_atomic_read(uint8_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint16_t ret;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b8 %0, [%1];" : "=h"(ret) : "l"(ptr));
    return (uint8_t)ret;
#else
    return READ_ONCE(*ptr);
#endif
}

__device__ static inline uint16_t ibgda_atomic_read(uint16_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint16_t ret;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b16 %0, [%1];" : "=h"(ret) : "l"(ptr));
    return ret;
#else
    return READ_ONCE(*ptr);
#endif
}

__device__ static inline uint32_t ibgda_atomic_read(uint32_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint32_t ret;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
    return ret;
#else
    return READ_ONCE(*ptr);
#endif
}

__device__ static inline uint64_t ibgda_atomic_read(uint64_t *ptr) {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_ATOMIC_READ_SET
    uint64_t ret;
    asm volatile("ld.relaxed.gpu.global.L1::no_allocate.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
    return ret;
#else
    return READ_ONCE(*ptr);
#endif
}

// Prevent code reordering from both compiler and GPU
__device__ static inline void IBGDA_MFENCE() {
#ifdef NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE
    asm volatile("fence.acq_rel.cta;" ::: "memory");
#else
    __threadfence_block();
#endif /* NVSHMEMI_IBGDA_PTX_OPTIMIZATION_MFENCE */
}

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
__device__ static inline uint64_t ibgda_query_globaltimer() {
    uint64_t ret;
    asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(ret)::"memory");
    return ret;
}
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
__device__ static inline int ibgda_check_poll_timeout(nvshmemi_ibgda_device_cq_t *cq, uint64_t now,
                                                      uint64_t start, uint64_t idx, int *error) {
    int status = 0;

    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
    uint8_t opown;
    uint8_t opcode;
    uint16_t wqe_counter;

    if (unlikely(now - start > IBGDA_POLL_TIMEOUT)) {
        *error = -ETIME;

        opown = ibgda_atomic_read(&cqe64->op_own);
        opcode = opown >> 4;

        wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
        wqe_counter = BSWAP16(wqe_counter);

        printf(
            "[%d] ibgda_poll_cq timeout:\n"
            " cons_idx=%#lx, prod_idx=%#lx, cqn=%#x, qpn=%#x, opcode=%#x\n"
            " wqe_counter=%#x, resv_head=%#lx, ready_head=%#lx\n"
            " while waiting for idx=%#lx.\n",
            nvshmemi_device_state_d.mype, ibgda_atomic_read(cq->cons_idx),
            ibgda_atomic_read(cq->prod_idx), cq->cqn, cq->qpn, opcode, wqe_counter,
            ibgda_atomic_read(cq->resv_head), ibgda_atomic_read(cq->ready_head), idx);
        status = -1;
    }
    return status;
}
#endif

__device__ static inline int ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx,
                                           int *error) {
    int status = 0;
    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;

    const uint32_t ncqes = cq->ncqes;

    uint8_t opown;
    uint8_t opcode;
    uint16_t wqe_counter;
    uint16_t new_wqe_counter;

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
    uint64_t start = ibgda_query_globaltimer();
    uint64_t now;
#endif

    uint64_t cons_idx = ibgda_atomic_read(cq->cons_idx);
    uint64_t new_cons_idx;

    assert(likely(cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_DCI ||
                  cq->qp_type == NVSHMEMI_IBGDA_DEVICE_QP_TYPE_RC));

    if (unlikely(cons_idx >= idx)) goto out;

#ifdef NVSHMEM_IBGDA_DEBUG
    // We can skip opcode == MLX5_CQE_INVALID check because we have already
    // initialized the CQ buffer to 0xff. With the QP depth range we enforce,
    // cons_idx cannot progress unless wqe_counter read from the CQ buffer is
    // a valid value.
    do {
        opown = ibgda_atomic_read(&cqe64->op_own);
        opcode = opown >> 4;

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
        // TODO: Integrate timeout handler with the core NVSHMEM
        now = ibgda_query_globaltimer();
        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
        if (status != 0) goto check_opcode;
#endif /* NVSHMEM_TIMEOUT_DEVICE_POLLING */
    } while (unlikely(opcode == MLX5_CQE_INVALID));

    // Prevent reordering of the opcode wait above
    IBGDA_MFENCE();
#endif /* NVSHMEM_IBGDA_DEBUG */

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
    start = ibgda_query_globaltimer();
#endif

    // If idx is a lot greater than cons_idx, we might get incorrect result due
    // to wqe_counter wraparound. We need to check prod_idx to be sure that idx
    // has already been submitted.
    while (unlikely(ibgda_atomic_read(cq->prod_idx) < idx))
        ;
    IBGDA_MFENCE();

    do {
        new_wqe_counter = ibgda_atomic_read(&cqe64->wqe_counter);
        new_wqe_counter = BSWAP16(new_wqe_counter);
#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
        now = ibgda_query_globaltimer();
        status = ibgda_check_poll_timeout(cq, now, start, idx, error);
        if (status != 0) goto check_opcode;

        // Observe progress. Reset the timer.
        if (new_wqe_counter != wqe_counter) start = now;
#endif
        wqe_counter = new_wqe_counter;

        // Another thread may have updated cons_idx.
        cons_idx = ibgda_atomic_read(cq->cons_idx);
        if (likely(cons_idx >= idx)) goto out;
    }
    // NOTE: This while loop is part of do while above.
    // wqe_counter is the HW consumer index. However, we always maintain index
    // + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
    // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
    // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
    // idx, and thus we need to wait. We don't need to wait when idx ==
    // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
    // wraparound.
    while (unlikely(((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes)));

    // new_cons_idx is uint64_t but wqe_counter is uint16_t. Thus, we get the
    // MSB from idx. We also need to take care of wraparound.
    ++wqe_counter;
    new_cons_idx =
        (idx & ~(0xffffULL) | wqe_counter) + (((uint16_t)idx > wqe_counter) ? 0x10000ULL : 0x0);
    atomicMax((unsigned long long int *)cq->cons_idx, (unsigned long long int)new_cons_idx);
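
    // Worked example of the reconstruction above (hypothetical values):
    // idx = 0x2fff0 and wqe_counter (already incremented) = 0xfff2 gives
    // (0x20000 | 0xfff2) + 0 = 0x2fff2, since (uint16_t)idx = 0xfff0 <= 0xfff2.
    // If the counter had instead wrapped to 0x0002, then 0xfff0 > 0x0002 adds
    // the carry: (0x20000 | 0x0002) + 0x10000 = 0x30002.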

#ifdef NVSHMEM_TIMEOUT_DEVICE_POLLING
check_opcode:
#endif

    // NVSHMEM always treats CQE errors as fatal.
    // Even if this error doesn't belong to the CQE in cons_idx,
    // we will just report and terminate the process.
    opown = ibgda_atomic_read(&cqe64->op_own);
    opcode = opown >> 4;

    if (unlikely(opcode == MLX5_CQE_REQ_ERR)) {
        ibgda_mlx5_err_cqe_t *cqe_err = (ibgda_mlx5_err_cqe_t *)cqe64;
        *error = cqe_err->syndrome;
#ifdef NVSHMEM_IBGDA_DEBUG
        __be16 wqe_counter = ibgda_atomic_read(&cqe_err->wqe_counter);
        __be32 s_wqe_opcode_qpn = ibgda_atomic_read(&cqe_err->s_wqe_opcode_qpn);
        printf(
            "[%d] got completion with err:\n"
            " syndrome=%#x, vendor_err_synd=%#x, hw_err_synd=%#x, hw_synd_type=%#x,\n"
            " wqe_counter=%#x, s_wqe_opcode_qpn=%#x,\n"
            " cqn=%#x, cons_idx=%#lx, prod_idx=%#lx, idx=%#lx\n",
            nvshmemi_device_state_d.mype, cqe_err->syndrome, cqe_err->vendor_err_synd,
            cqe_err->hw_err_synd, cqe_err->hw_synd_type, BSWAP16(wqe_counter),
            BSWAP32(s_wqe_opcode_qpn), cq->cqn, cons_idx, ibgda_atomic_read(cq->prod_idx), idx);
#endif /* NVSHMEM_IBGDA_DEBUG */
        status = -1;
    }

out:
    // Prevent reordering of this function and subsequent instructions
    IBGDA_MFENCE();

    return status;
}

__device__ static inline uint64_t ibgda_quiet(nvshmemi_ibgda_device_qp_t *qp) {
    nvshmemi_ibgda_device_state_t *state = ibgda_get_state();
    uint64_t prod_idx = state->use_async_postsend ? ibgda_atomic_read(qp->tx_wq.prod_idx)
                                                  : ibgda_atomic_read(&qp->mvars.tx_wq.ready_head);
    nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;

    int err = 0;
    int status = ibgda_poll_cq(&cq, prod_idx, &err);
    // TODO: Integrate the error handler with the core NVSHMEM
#ifdef NVSHMEM_IBGDA_DEBUG
    if (status) {
        printf("ibgda_poll_cq failed with error=%d.\n", err);
    }
#endif
    assert(likely(status == 0));
    return prod_idx;
}

__device__ static inline void ibgda_wait_for_slot_availability(nvshmemi_ibgda_device_qp_t *qp, uint64_t wqe_idx) {
    int status = 0;
    int err = 0;
    uint16_t nwqes = qp->tx_wq.nwqes;
    nwqes = nwqes / 2;

    // We don't want wqe_idx - nwqes to wraparound.
    if (likely(wqe_idx >= nwqes)) {
        nvshmemi_ibgda_device_cq_t cq = *qp->tx_wq.cq;
        status = ibgda_poll_cq(&cq, wqe_idx - nwqes, &err);
        // TODO: Integrate the error handler with the core NVSHMEM
        if (status) {
            printf("ibgda_poll_cq failed with error=%d.\n", err);
        }
        assert(likely(status == 0));
    }
    IBGDA_MFENCE();
}
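
// A concrete reading of the halving above (illustrative numbers): with
// tx_wq.nwqes = 128 the function polls the CQ until the consumer index
// reaches wqe_idx - 64 (e.g. 36 for wqe_idx = 100), keeping at most about
// half of the ring outstanding.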

template <bool nbi = true>
__device__ static __forceinline__ uint64_t
ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) {
    auto mvars = &qp->mvars;
    uint64_t wqe_idx;
    wqe_idx = atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
    if (!nbi) {
        uint64_t prod_idx = mvars->tx_wq.prod_idx;
        uint64_t cons_idx = mvars->tx_wq.cons_idx;
        uint64_t delta = prod_idx - cons_idx;
        uint64_t cnt = qp->tx_wq.nwqes;
        if (delta > cnt) {
            printf("prod_idx: %lu\tcons_idx: %lu\tcnt: %lu\tdelta: %lu\n", prod_idx, cons_idx, cnt, delta);
            EP_DEVICE_ASSERT(delta <= cnt);
        }

        // If last slot is available, all prior slots are also available.
        ibgda_wait_for_slot_availability(qp, wqe_idx + num_wqes);
    }

    // return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
    return wqe_idx;
    return atomicAdd(reinterpret_cast<unsigned long long*>(&mvars->tx_wq.resv_head), static_cast<unsigned long long>(num_wqes));
}
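
// Usage sketch (illustrative, names as above): a blocking reservation waits
// for ring space before returning, while a non-blocking one only bumps
// resv_head.
//   uint64_t base = ibgda_reserve_wqe_slots<false>(qp, num_wqes);  // waits for slots
//   void *first_wqe = ibgda_get_wqe_ptr(qp, base);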

__device__ static __forceinline__ void*
ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
    uint16_t cnt = qp->tx_wq.nwqes;
    EP_DEVICE_ASSERT(cnt != 0);
    uint16_t idx = wqe_idx & (cnt - 1);
    return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
}
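
// Note: the masking above assumes tx_wq.nwqes is a power of two, so that
// wqe_idx & (cnt - 1) equals wqe_idx % cnt (e.g. cnt = 64 gives mask 0x3f).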

@@ -667,7 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
    st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
}

template <bool nbi = true>
__device__ static __forceinline__ void
nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
    // Get lkey and rkey, store them into lanes

@@ -697,7 +354,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
    auto qp = ibgda_get_rc(dst_pe, qp_id);
    uint64_t base_wqe_idx = 0;
    if (lane_id == 0)
        base_wqe_idx = ibgda_reserve_wqe_slots<nbi>(qp, num_wqes);
        base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes);
    base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0);
    if (lane_id < num_wqes) {
        auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id);

@@ -710,10 +367,6 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
    if (lane_id == 0)
        ibgda_submit_requests<false>(qp, base_wqe_idx, num_wqes, message_idx);
    __syncwarp();

    // if (!nbi) {
    //     ibgda_quiet(qp);
    // }
}

__device__ static __forceinline__ void ibgda_write_amo_add_wqe(

@@ -711,14 +711,10 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
        if (dst_rdma_rank != rdma_rank) {
            auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens;
            EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens);
            // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
            //                            rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token,
            //                            num_bytes_per_rdma_token * num_tokens_to_issue,
            //                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
            const size_t num_bytes_per_msg = (num_bytes_per_rdma_token * num_tokens_to_issue) * sizeof(int8_t);
            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
            const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token);
            nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
                                        translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
            nvshmem_fence();
        } else {

@@ -731,8 +727,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
        if (lane_id == dst_rdma_rank) {
            last_issued_tail += num_tokens_to_issue;
            num_tokens_to_send -= num_tokens_to_issue;
            // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, NVSHMEM_SIGNAL_ADD,
            //                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
            if (dst_rdma_rank != rdma_rank) {
                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue,
                                                translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);

@@ -939,8 +933,6 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv

        // Update remote head
        if (min_head != std::numeric_limits<int>::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) {
            // nvshmemx_signal_op(rdma_channel_head.buffer(rdma_rank), min_head - last_head, NVSHMEM_SIGNAL_ADD,
            //                    translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank));
            if (lane_id != rdma_rank) {
                nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head,
                                                translate_dst_rdma_rank<kLowLatencyMode>(lane_id, nvl_rank), channel_id);

@@ -1578,14 +1570,10 @@ combine(int4* combined_x, float* combined_topk_weights,
        if (sub_warp_id == kNumWarpsPerForwarder - 1) {
            if (dst_rdma_rank != rdma_rank) {
                auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens;
                // nvshmemx_int8_put_nbi_warp(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
                //                            rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token,
                //                            num_chunked_tokens * num_bytes_per_rdma_token,
                //                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                const size_t num_bytes_per_msg = (num_chunked_tokens * num_bytes_per_rdma_token) * sizeof(int8_t);
                const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
                const auto src_ptr = reinterpret_cast<uint64_t>(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token);
                nvshmemi_ibgda_put_nbi_warp<false>(dst_ptr, src_ptr, num_bytes_per_msg,
                nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg,
                                            translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id, lane_id, 3);
                nvshmem_fence();
            } else {

@@ -1595,8 +1583,6 @@ combine(int4* combined_x, float* combined_topk_weights,
            // Write new RDMA tail
            __syncwarp();
            if (lane_id == 0) {
                // nvshmemx_signal_op(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, NVSHMEM_SIGNAL_ADD,
                //                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank));
                if (dst_rdma_rank != rdma_rank) {
                    nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens,
                                                    translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank), channel_id);