Mirror of https://github.com/deepseek-ai/DeepEP, synced 2025-06-26 18:28:11 +00:00
Use IBGDA only (#177)
Commit 9fe9021f29 (parent aae9fa9a6d)
@@ -413,8 +413,7 @@ __device__ static __forceinline__ void ibgda_write_amo_add_wqe(
 __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) {
     if (is_local_copy) {
-        // Fallback to NVSHMEM legacy API
-        nvshmemx_signal_op(static_cast<uint64_t*>(rptr), value, NVSHMEM_SIGNAL_ADD, pe);
+        atomicAdd(static_cast<unsigned long long*>(rptr), value);
     } else {
         nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
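Note (not part of the commit): in the local-copy branch above, the NVSHMEM signaling call is replaced by a plain CUDA atomic on the locally mapped address, since a same-GPU update needs no IBGDA work request. A minimal sketch of that pattern, with a hypothetical function name:

    #include <cstdint>

    __device__ void local_signal_add(uint64_t *local_signal, int value) {
        // Same-GPU update: a native 64-bit atomic on the mapped address is enough.
        atomicAdd(reinterpret_cast<unsigned long long*>(local_signal),
                  static_cast<unsigned long long>(value));
    }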
@@ -446,4 +445,51 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
     return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
 }
 
+// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
+// Note that this implementation does not guarantee thread safety,
+// so we must ensure that no other threads are concurrently using the same QP.
+__device__ static __forceinline__ int
+ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
+    int status = 0;
+    struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
+
+    const uint32_t ncqes = cq->ncqes;
+
+    uint16_t wqe_counter;
+    uint16_t new_wqe_counter;
+
+    memory_fence_cta();
+
+    do {
+        new_wqe_counter = ld_na_relaxed(&cqe64->wqe_counter);
+        new_wqe_counter = HtoBE16(new_wqe_counter);
+        wqe_counter = new_wqe_counter;
+    }
+    // NOTE: This while loop is part of do while above.
+    // wqe_counter is the HW consumer index. However, we always maintain index
+    // + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
+    // 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
+    // sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
+    // idx, and thus we need to wait. We don't need to wait when idx ==
+    // wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
+    // wraparound.
+    // Example:
+    // if idx = 10, we wait until wqe_counter = 9, idx - wqe_counter - 2 = 65535 > ncqes.
+    while (((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes));
+
+    *cq->cons_idx = idx;
+
+    // Prevent reordering of this function and subsequent instructions
+    memory_fence_cta();
+
+    return status;
+}
+
+// Wait until wqe `idx - 1` is completed.
+__device__ static __forceinline__ void
+nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
+    auto qp = ibgda_get_rc(dst_pe, qp_id);
+    uint64_t prod_idx = ld_na_relaxed(qp->tx_wq.prod_idx);
+    ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
+}
+
 } // namespace deep_ep
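Note (not part of the commit): the `- (uint16_t)2` trick in `ibgda_poll_cq` is easiest to verify with plain integer arithmetic. The standalone sketch below reproduces the wait condition and the wraparound cases described in the comment; the helper name and test values are illustrative:

    #include <cassert>
    #include <cstdint>

    // The CQ poller keeps waiting while (uint16_t)(idx - wqe_counter - 2) < ncqes.
    static bool must_wait(uint16_t idx, uint16_t wqe_counter, uint32_t ncqes) {
        return (uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes;
    }

    int main() {
        const uint32_t ncqes = 256;
        assert(must_wait(10, 8, ncqes));    // wqe_counter + 1 == 9  < idx == 10: keep waiting (diff == 0)
        assert(!must_wait(10, 9, ncqes));   // wqe_counter + 1 == 10 == idx: done (diff wraps to 65535)
        assert(must_wait(1, 65533, ncqes)); // idx wrapped past 0 but is still ahead: keep waiting (diff == 2)
        return 0;
    }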
@@ -193,8 +193,8 @@ __forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank,
 }
 
 template <bool kLowLatencyMode>
-__forceinline__ __device__ void nvshmem_barrier_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
-    kLowLatencyMode ? void(nvshmem_barrier(rdma_team)) : nvshmem_barrier_all();
+__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(const nvshmem_team_t& rdma_team) {
+    kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all();
 }
 
 template <bool kLowLatencyMode, int kNumRDMARanks>
@@ -223,7 +223,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
         EP_DEVICE_ASSERT(num_warps > 1);
         EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
         if (thread_id == 32)
-            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
+            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
         barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
         move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
         __syncthreads();
@@ -252,14 +252,25 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
         // Issue send
         // TODO: more light fence or barrier or signaling
         // TODO: overlap EP barrier and NVL cleaning
-        if (thread_id < kNumRDMARanks) {
-            nvshmem_int_put_nbi(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), rdma_recv_num_tokens_mixed.send_buffer(thread_id),
-                                NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
-                                translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank));
+        for (int i = 0; i < kNumRDMARanks; ++i) {
+            if (i != rdma_rank) {
+                nvshmemi_ibgda_put_nbi_warp<true>(reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)),
+                                                  reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)),
+                                                  (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),
+                                                  translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank), 0, lane_id, 0);
+            } else {
+                UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
+                                   rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
+                                   rdma_recv_num_tokens_mixed.send_buffer(i),
+                                   ld_volatile_global, st_na_global);
+            }
         }
+        if (thread_id < kNumRDMARanks and thread_id != rdma_rank)
+            nvshmemi_ibgda_quiet(translate_dst_rdma_rank<kLowLatencyMode>(thread_id, nvl_rank), 0);
+
         __syncthreads();
         if (thread_id == 0)
-            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
+            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
         __syncthreads();
 
         // NVL buffers
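Note (not part of the commit): `UNROLLED_WARP_COPY`, used for the local rank above, is a DeepEP macro whose expansion is not shown in this diff. A simplified sketch of the warp-cooperative copy pattern it implements (illustrative only; the real macro also unrolls the loop and uses the load/store helpers passed to it):

    __device__ void warp_copy_int(int *dst, const int *src, int num_ints, int lane_id) {
        // Each of the 32 lanes copies a strided subset, so the whole warp moves
        // `num_ints` integers cooperatively.
        for (int j = lane_id; j < num_ints; j += 32)
            dst[j] = src[j];
    }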
@@ -345,7 +356,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
         // Finally barrier
         __syncthreads();
         if (thread_id == 32)
-            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
+            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
         barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
         move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
     } else {
@@ -701,7 +712,14 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
 
         // Iterate all RDMA ranks
+        int last_issued_tail = 0;
         auto start_time = clock64();
         while (__any_sync(0xffffffff, num_tokens_to_send > 0)) {
+            // Timeout check
+            if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) {
+                printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail %d, num_tokens_to_send %d\n",
+                       channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send);
+                trap();
+            }
             for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) {
                 // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels
                 int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks;
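Note (not part of the commit): the timeout check added above follows a common device-side watchdog pattern: record a start timestamp with `clock64()`, compare the elapsed cycle count against a threshold, and trap so a stuck peer surfaces as an error rather than a silent hang. A minimal sketch with hypothetical names (the kernel above uses DeepEP's own `trap()` helper and `NUM_TIMEOUT_CYCLES` constant):

    #include <cstdio>

    __device__ void poll_with_timeout(volatile const int *flag, long long timeout_cycles) {
        // Spin until `flag` becomes non-zero, but abort after `timeout_cycles`
        // device clock cycles instead of hanging forever.
        const long long start = clock64();
        while (*flag == 0) {
            if (clock64() - start > timeout_cycles) {
                printf("poll_with_timeout: no progress after %lld cycles\n", timeout_cycles);
                __trap();  // CUDA intrinsic for a device-side abort
            }
        }
    }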
@@ -1103,7 +1121,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
     if (sm_id == 0) {
         // Barrier for RDMA
         if (thread_id == 0)
-            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
+            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
         __syncthreads();
 
         // Clean
@@ -1111,12 +1129,11 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
         #pragma unroll
         for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
             rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
-        nvshmem_fence();
         __syncthreads();
 
         // Barrier again
         if (thread_id == 0)
-            nvshmem_barrier_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
+            nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
     } else if (sm_id == 1) {
         // Barrier for NVL
         barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
@@ -58,12 +58,6 @@ int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks
         EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID);
     }
 
-    // TODO: we still use `nvshmem_barrier` under IBRC mode, which should be switch to IBGDA mode later
-    nvshmemi_device_host_state_t* dev_state_ptr = nullptr;
-    CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&dev_state_ptr), nvshmemi_device_state_d));
-
-    bool ibgda_is_initialized = false;
-    CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, &ibgda_is_initialized, sizeof(bool), cudaMemcpyHostToDevice));
     nvshmem_barrier_all();
     return nvshmem_my_pe();
 }