Use IBGDA only (#177)

Shangyan Zhou 2025-05-28 16:40:14 +08:00 committed by GitHub
parent aae9fa9a6d
commit 9fe9021f29
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 20 deletions

View File

@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
if (is_local_copy) {
// Fallback to NVSHMEM legacy API
nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
atomicAdd(static_cast<unsigned long long*>(rptr), value);
} else {
nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
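For illustration only (not part of this commit): with the legacy `nvshmemx_signal_op` fallback gone, the local-copy path is just an `atomicAdd` into the local symmetric heap, which has the same effect as an `NVSHMEM_SIGNAL_ADD` on a local signal word. A hypothetical receiver-side sketch; the helper name and the volatile spin are assumptions, not DeepEP code:

// Hypothetical consumer of the counter that nvshmemi_ibgda_amo_nonfetch_add
// increments (via atomicAdd locally, or via an IBGDA atomic WQE remotely).
// `signal` points into the local symmetric heap; `expected` is the number of
// contributions to wait for.
__device__ __forceinline__ void wait_signal_example(uint64_t* signal, uint64_t expected) {
    // Spin on a volatile load so the value is re-read from memory each iteration
    while (*reinterpret_cast<volatile uint64_t*>(signal) < expected);
}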
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
}
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
__device__ static __forceinline__ int
ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
int status = 0;
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
const uint32_t ncqes = cq->ncqes;
uint16_t wqe_counter;
uint16_t new_wqe_counter;
memory_fence_cta();
do {
new_wqe_counter = ld_na_relaxed(&cqe64->wqe_counter);
new_wqe_counter = HtoBE16(new_wqe_counter);
wqe_counter = new_wqe_counter;
}
// NOTE: the `while` below is the condition of the `do` loop above.
// wqe_counter is the HW consumer index, while the SW always maintains
// consumer index + 1; to compare with idx we therefore need wqe_counter + 1.
// We must keep waiting while wqe_counter + 1 is still behind idx, i.e. while
// idx >= wqe_counter + 2. Since both values are uint16_t and may wrap around,
// this is tested as (uint16_t)(idx - wqe_counter - 2) < ncqes, which only
// holds while idx is ahead by at least 2 (within the CQ window of ncqes
// entries); the boundary case idx == wqe_counter + 1 wraps around to a huge
// value and exits the loop.
// Example:
// if idx = 10, we wait until wqe_counter = 9; then idx - wqe_counter - 2
// wraps to 65535 >= ncqes and the loop exits.
while (((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes));
*cq->cons_idx = idx;
// Prevent reordering of this function and subsequent instructions
memory_fence_cta();
return status;
}
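For illustration only (not part of the commit), a minimal host-side sketch of the uint16_t wraparound test above; `still_waiting` mirrors the loop condition, and the constants are made up for the demo:

#include <cstdint>
#include <cstdio>

// True while ibgda_poll_cq would keep spinning, i.e. while the completion for
// WQE `idx - 1` has not yet been reported by the HW consumer index.
static bool still_waiting(uint16_t idx, uint16_t wqe_counter, uint32_t ncqes) {
    return (uint16_t)(idx - wqe_counter - (uint16_t)2) < ncqes;
}

int main() {
    const uint32_t ncqes = 128;
    std::printf("%d\n", still_waiting(10, 7, ncqes));    // 1: 10 - 7 - 2 = 1 < 128, keep waiting
    std::printf("%d\n", still_waiting(10, 9, ncqes));    // 0: wraps to 65535 >= 128, done
    std::printf("%d\n", still_waiting(2, 65535, ncqes)); // 1: counter itself wrapped, still behind
    std::printf("%d\n", still_waiting(2, 1, ncqes));     // 0: wqe_counter + 1 == idx, done
    return 0;
}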
// Wait until the WQE at index `prod_idx - 1` (i.e. every WQE posted so far on this QP) has completed.
__device__ static __forceinline__ void
nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t prod_idx = ld_na_relaxed(qp->tx_wq.prod_idx);
ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
}
} // namespace deep_ep
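Purely illustrative, not part of the commit: a sketch of how the two helpers above are meant to compose on the sender side, following the call shapes used in `notify_dispatch` below. The function name and buffer arguments are hypothetical; note that only one thread should poll a given QP's CQ, per the thread-safety note on `ibgda_poll_cq`.

// Hypothetical warp-wide sender: post an IBGDA put, then wait for its completion.
__device__ __forceinline__ void put_and_wait_example(uint64_t dst_ptr, uint64_t src_ptr,
                                                     size_t num_bytes, int dst_pe,
                                                     int qp_id, int lane_id) {
    // All 32 lanes cooperate to post the non-blocking RDMA write
    nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes, dst_pe, qp_id, lane_id, 0);
    // A single lane drains the CQ up to the current producer index
    if (lane_id == 0)
        nvshmemi_ibgda_quiet(dst_pe, qp_id);
    __syncwarp();
}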

View File

@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
}
template <bool kLowLatencyMode>
__forceinline__ __device__ void nvshmem_barrier_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
kLowLatencyMode ? void(nvshmem_barrier(rdma_team)) : nvshmem_barrier_all();
__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all();
}
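A note on the rename above (editorial, hedged): `nvshmem_barrier` also quiets outstanding NVSHMEM operations, while `nvshmem_sync` only synchronizes PE arrival; since remote-write completion is now enforced explicitly via `nvshmemi_ibgda_quiet`, the lighter sync should suffice. A minimal sketch of the call pattern the kernels below use, with the wrapper name taken from this diff and the example function name being hypothetical:

// One elected thread performs the team-level sync, then the whole block re-converges.
template <bool kLowLatencyMode>
__device__ __forceinline__ void team_sync_example(const nvshmem_team_t& rdma_team, int thread_id) {
    if (thread_id == 0)
        nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
    __syncthreads();
}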
template <bool kLowLatencyMode, int kNumRDMARanks>
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Issue send
// TODO: use a lighter fence, barrier, or signaling scheme
// TODO: overlap EP barrier and NVL cleaning
if (thread_id < kNumRDMARanks) {
nvshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id),
NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
for (int i = 0; i < kNumRDMARanks; ++i) {
if (i != rdma_rank) {
nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)),
reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)),
(NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),
translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank), 0, lane_id, 0);
} else {
UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(i),
ld_volatile_global, st_na_global);
}
}
if (thread_id < kNumRDMARanks and thread_id != rdma_rank)
nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0);
__syncthreads();
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// NVL buffers
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
// Finally barrier
__syncthreads();
if (thread_id == 32)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
} else {
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Iterate all RDMA ranks
int last_issued_tail = 0;
auto start_time = clock64();
while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
// Timeout check
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n",
channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
trap();
}
for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
// To mitigate incast congestion, stagger the starting destination rank across source ranks and channels
int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
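Purely illustrative (not part of the commit): a tiny host-side demo of the destination staggering above, showing that different (rdma_rank, channel_id) pairs start their sweep at different destination ranks, so senders do not all target the same receiver first:

#include <cstdio>

int main() {
    const int kNumRDMARanks = 4;
    for (int rdma_rank = 0; rdma_rank < 2; ++ rdma_rank)
        for (int channel_id = 0; channel_id < 2; ++ channel_id) {
            std::printf("rank %d, channel %d:", rdma_rank, channel_id);
            for (int i = 0; i < kNumRDMARanks; ++ i)
                std::printf(" %d", (i + channel_id + rdma_rank) % kNumRDMARanks);
            std::printf("\n");  // e.g. rank 0, channel 1 visits 1 2 3 0
        }
    return 0;
}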
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
if (sm_id == 0) {
// Barrier for RDMA
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
__syncthreads();
// Clean
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
#pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
nvshmem_fence();
__syncthreads();
// Barrier again
if (thread_id == 0)
nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);

View File

@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
}
// TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switched to IBGDA mode later
nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
bool ibgda_is_initialized = false;
CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
nvshmem_barrier_all();
return nvshmem_my_pe();
}
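One hedged operational note (not from this commit): with the IBRC fallback removed, the kernels assume NVSHMEM's IBGDA transport is actually enabled. The environment knob `NVSHMEM_IB_ENABLE_IBGDA` is the usual way to request it, but verify the exact variable against your NVSHMEM version's documentation. A hypothetical host-side sketch:

#include <cstdlib>

// Must run before NVSHMEM is bootstrapped in init() above.
static void request_ibgda() {
    setenv("NVSHMEM_IB_ENABLE_IBGDA", "1", /*overwrite=*/0);  // assumption: standard NVSHMEM env var
}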