Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Low-latency kernels use RDMA atomics to support AR.
This commit is contained in:
commit 2d0cf41dd1 (parent 7128ba3e39)
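Note: at a glance, the change replaces the ordered-write + receive-queue path with a single RDMA atomic. A minimal before/after sketch of the call pattern (function names are from the diff below; `flag_ptr`, `value`, `dst_rank`, `src_rank`, and `qp_id` are illustrative placeholders):

    // Before this commit: write the flag, then poll the recv CQ for completion
    //   sender:   nvshmemi_ibgda_rma_p(flag_ptr, value, dst_rank, qp_id, 0);
    //             nvshmemi_ibgda_prepare_recvs(dst_rank, qp_id);
    //   receiver: nvshmemi_ibgda_poll_recv(src_rank, qp_id);
    //             value = ld_acquire_sys_global(flag_ptr);
    // After this commit: one masked fetch-add; the arriving value itself is the signal
    //   sender:   nvshmemi_ibgda_amo_nonfetch_add(flag_ptr, value, dst_rank, qp_id);
    //   receiver: while ((value = ld_acquire_global(flag_ptr)) == 0);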
@@ -62,6 +62,12 @@ uint16_t HtoBE16(uint16_t x) {
 typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
 
+typedef struct {
+    uint32_t add_data;
+    uint32_t field_boundary;
+    uint64_t reserved;
+} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;
+
 __device__ static __forceinline__
 nvshmemi_ibgda_device_state_t* ibgda_get_state() {
     return &nvshmemi_ibgda_device_state_d;
 }
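Note: the packed struct added above is a 4+4+8 = 16-byte view of the 32-bit masked fetch-add payload, sized to alias `struct mlx5_wqe_atomic_seg` and to be flushed with one vectorized store (the `EP_STATIC_ASSERT(... == sizeof(int4))` checks further down rely on this). A quick compile-time check, as a sketch rather than part of the commit:

    static_assert(sizeof(ibgda_atomic_32_masked_fa_seg_t) == 16, "must match mlx5_wqe_atomic_seg");
    static_assert(sizeof(ibgda_atomic_32_masked_fa_seg_t) == sizeof(int4), "one st_na_relaxed covers it");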
@@ -249,23 +255,6 @@ ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) {
     return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT));
 }
 
-// Wait until wqe `idx - 1` is completed.
-// This is a simplified version of NVSHMEM's `ibgda_poll_cq`; it can only be used for polling recv.
-// Because we post recv and poll recv in the same thread, we don't need to maintain queue status.
-__device__ static __forceinline__ void
-nvshmemi_ibgda_poll_recv(int dst_pe, int qp_id) {
-    auto qp = ibgda_get_rc(dst_pe, qp_id);
-    auto cq = qp->rx_wq.cq;
-
-    const uint32_t ncqes = cq->ncqes;
-    auto *cqe64 = reinterpret_cast<struct mlx5_cqe64*>(cq->cqe);
-    auto old_cons_idx = *cq->cons_idx;
-    *cq->cons_idx = old_cons_idx + 1;
-
-    // Wait until `wqe_counter >= old_cons_idx`
-    while ((static_cast<uint16_t>(old_cons_idx - HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)) - 1) < ncqes));
-}
-
 __device__ static __forceinline__ void
 nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits<uint32_t>::max()) {
     // Get rkey
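Note on the removed wait loop: it exploits 16-bit wraparound so the spin exits exactly when the CQE counter reaches `old_cons_idx`. A worked example, assuming `ncqes = 8`:

    // Spin while (uint16_t)(old_cons_idx - wqe_counter - 1) < ncqes:
    //   wqe_counter == old_cons_idx - 1  ->  diff = 0      -> keep spinning (one behind)
    //   wqe_counter == old_cons_idx - 8  ->  diff = 7      -> keep spinning (ncqes behind)
    //   wqe_counter == old_cons_idx      ->  diff = 0xffff -> exit: the awaited CQE landed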
@@ -336,45 +325,6 @@ ibgda_write_empty_recv_wqe(void *out_wqe) {
     st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<const int4*>(&data_seg));
 }
 
-__device__ static __forceinline__ uint64_t
-nvshmemi_ibgda_allocate_recvs(nvshmemi_ibgda_device_qp* qp) {
-    auto mvars = &qp->mvars;
-
-    // Allocate if not enough
-    constexpr int kMinIBGDARecvs = 32;
-    auto resv_head = mvars->rx_wq.resv_head;
-    auto num_valid_slots = resv_head - mvars->rx_wq.cons_idx;
-    if (num_valid_slots < kMinIBGDARecvs) {
-        resv_head = mvars->rx_wq.cons_idx + qp->rx_wq.nwqes;
-        mvars->rx_wq.resv_head = resv_head;
-
-        // Ensure WQE is written before `dbrec`
-        __be32 dbrec_val;
-        __be32 *dbrec_ptr = qp->rx_wq.dbrec;
-
-        // Compared to sending, for each QP, we only post recv in a single thread,
-        // so we don't need to do synchronization here
-        // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(wqe_idx & 0xffff))`
-        asm("{\n\t"
-            ".reg .b32 dbrec_head_16b;\n\t"
-            ".reg .b32 ign;\n\t"
-            "and.b32 dbrec_head_16b, %1, 0xffff;\n\t"
-            "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t"
-            "}" : "=r"(dbrec_val)
-            : "r"(static_cast<uint32_t>(resv_head)));
-        st_na_release(dbrec_ptr, dbrec_val);
-    }
-
-    // Return old number of slots
-    return num_valid_slots;
-}
-
-__device__ static __forceinline__ void
-nvshmemi_ibgda_prepare_recvs(int dst_rank, int qp_id) {
-    // NOTES: only one thread can run this function
-    EP_DEVICE_ASSERT(nvshmemi_ibgda_allocate_recvs(ibgda_get_rc(dst_rank, qp_id)) > 16);
-}
-
 __device__ static __forceinline__ void
 nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) {
     // Get lkey and rkey, store them into lanes
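Note on the removed doorbell write: `prmt.b32 %0, x, ign, 0x123` selects bytes 3,2,1,0 of `x`, i.e. a 32-bit byte swap, which is exactly the host-to-big-endian conversion the comment describes (`ign` is never read, since all selector nibbles index the first operand). An equivalent sketch using the CUDA `__byte_perm` intrinsic, not part of the commit:

    // Equivalent to HtoBE32(resv_head & 0xffff) on a little-endian GPU
    __be32 dbrec_val = __byte_perm(static_cast<uint32_t>(resv_head) & 0xffff, 0, 0x0123);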
@@ -419,4 +369,62 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
     __syncwarp();
 }
 
+__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
+        nvshmemi_ibgda_device_qp_t *qp, const int &value,
+        uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
+        uint16_t wqe_idx, void **out_wqes) {
+    ibgda_ctrl_seg_t ctrl_seg = {0};
+    struct mlx5_wqe_raddr_seg raddr_seg;
+    struct mlx5_wqe_atomic_seg atomic_seg_1;
+    struct mlx5_wqe_data_seg data_seg;
+
+    auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
+    auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
+    auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
+    auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));
+
+    raddr_seg.raddr = HtoBE64(raddr);
+    raddr_seg.rkey = rkey;
+    raddr_seg.reserved = 0;
+
+    // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
+    ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
+    auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
+    atomic_32_masked_fa_seg->add_data = HtoBE32(value);
+    atomic_32_masked_fa_seg->field_boundary = 0;
+
+    ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
+    ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
+
+    data_seg.byte_count = HtoBE32(sizeof(int));
+    data_seg.lkey = lkey;
+    data_seg.addr = HtoBE64(laddr);
+
+    EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
+    st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg));
+    st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg));
+    st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1));
+    st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
+}
+
+__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
+    nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
+
+    __be32 rkey;
+    uint64_t raddr;
+    ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
+
+    void *wqe_ptrs[1];
+    uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+    wqe_ptrs[0] = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+
+    ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
+                            qp->ibuf.lkey, raddr, rkey, my_wqe_idx, wqe_ptrs);
+
+    ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+}
+
 } // namespace deep_ep
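Note on the WQE added above: in mlx5 WQEs the low byte of `qpn_ds` counts 16-byte octowords, so the `| 4` matches the four segments written here, filling exactly one 64-byte slot (`1 << MLX5_SEND_WQE_SHIFT`). A sketch of the layout produced by `ibgda_write_amo_add_wqe` (offsets inferred from the pointer arithmetic; the data segment presumably uses `qp->ibuf` as scratch for the value the NIC fetches back):

    // [ 0..15]  mlx5_wqe_ctrl_seg    MLX5_OPCODE_ATOMIC_MASKED_FA, opmod = 4-byte extended AMO
    // [16..31]  mlx5_wqe_raddr_seg   remote address + rkey
    // [32..47]  mlx5_wqe_atomic_seg  viewed as ibgda_atomic_32_masked_fa_seg_t (add_data, mask)
    // [48..63]  mlx5_wqe_data_seg    sizeof(int) into qp->ibuf.buf with qp->ibuf.lkey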
@@ -215,9 +215,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // Wait local sends issued and send expert counts
         while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
         if (dst_rank != rank) {
-            nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
-                                 dst_rank, dst_expert_local_idx, 0);
-            nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
+            nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
         } else {
             st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
         }
@@ -262,13 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         int num_recv_tokens, recv_token_begin_idx;
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
         if (sub_warp_id == 1 and lane_id == 0) {
-            if (src_rank != rank) {
-                nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
-                num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
-                EP_DEVICE_ASSERT(num_recv_tokens != 0);
-            } else {
-                while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
-            }
+            while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
             num_recv_tokens = -num_recv_tokens - 1;
             recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
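Note: the `-num_tokens_sent - 1` encoding (unchanged here) is what lets a plain spin replace `nvshmemi_ibgda_poll_recv`: the flag is nonzero even when zero tokens are sent. A worked example:

    // sender:   num_tokens_sent = 0  ->  delivers -0 - 1 = -1 (nonzero, so the spin can exit)
    // receiver: num_recv_tokens = -(-1) - 1 = 0 tokens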
@@ -439,7 +431,7 @@ combine(void* combined_x,
         if (sub_warp_id == 1 and lane_id == 0) {
             while (ld_acquire_global(atomic_clean_flag) == 0);
             if (dst_rank != rank) {
-                nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx, 0);
+                nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
             } else {
                 st_na_release(rdma_recv_flag + global_expert_idx, 1);
             }
@@ -456,16 +448,8 @@ combine(void* combined_x,
     // Wait all ranks to arrive and notify PCIe usage
     if (responsible_expert_idx < num_experts) {
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
-        if (sub_warp_id == 0 and lane_id == 0) {
-            // TODO: refactor QP indices
-            auto src_rank = responsible_expert_idx / num_local_experts;
-            auto src_expert_idx = responsible_expert_idx % num_local_experts;
-            if (src_rank != rank) {
-                nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
-            } else {
-                while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
-            }
-        }
+        if (sub_warp_id == 0 and lane_id == 0)
+            while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
     }
     cg::this_grid().sync();
 
@@ -41,27 +41,6 @@ std::vector<uint8_t> get_unique_id() {
     return result;
 }
 
-__global__ void ibgda_initialize_recv_queue(int rank) {
-    auto thread_idx = static_cast<int>(threadIdx.x);
-    auto num_threads = static_cast<int>(blockDim.x);
-
-    auto dst_rank = static_cast<int>(blockIdx.x);
-    if (dst_rank != rank) {
-        for (int qp_id = thread_idx; qp_id < ibgda_get_state()->num_rc_per_pe; qp_id += num_threads) {
-            auto qp = ibgda_get_rc(dst_rank, qp_id);
-
-            // Clean some necessary variables
-            for (int i = 0; i < qp->rx_wq.nwqes; ++ i)
-                ibgda_write_empty_recv_wqe(ibgda_get_wqe_ptr(qp, i));
-            qp->mvars.rx_wq.resv_head = 0;
-            qp->mvars.rx_wq.cons_idx = 0;
-
-            // Allocate receive slots
-            nvshmemi_ibgda_allocate_recvs(qp);
-        }
-    }
-}
-
 int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode) {
     nvshmemx_uniqueid_t root_unique_id;
     nvshmemx_init_attr_t attr;
@@ -85,10 +64,7 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
     CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
 
     bool ibgda_is_initialized = false;
-    cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice);
-
-    // Initialize recv queues for low-latency mode AR
-    ibgda_initialize_recv_queue<<<num_ranks, 128>>>(rank);
+    CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
 }
 nvshmem_barrier_all();
 return nvshmem_my_pe();