Fully remove barrier FIFO designs (#200)

* Fully remove FIFO slots

* Fully remove FIFO buffers

* Minor style fixes

* Fix some typos

* Fix bugs

* Cleanup `ibgda_poll_cq`
Chenggang Zhao 2025-06-10 16:23:20 +08:00 committed by GitHub
parent a16af40531
commit 8da2d7b38d
10 changed files with 121 additions and 181 deletions
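
For orientation before the per-file hunks: this commit replaces the task-FIFO barrier (a ring of slots indexed by a `head` counter) with a flat array of per-peer barrier signals, implemented by `barrier_block` near the end of this diff. Below is a minimal host-side model of that protocol, under illustrative names (`Signals`, `kTag` standing in for `FINISHED_SUM_TAG`); it is a sketch of the idea, not code from the repository. Each rank adds the tag to its own signal row and subtracts it from its slot in every peer's row; a rank may leave the barrier once every entry of its own row is back to <= 0, and afterwards all slots are zero again, so no head/slot bookkeeping is needed for reuse.

#include <array>
#include <cassert>

constexpr int kNumRanks = 4;
constexpr int kTag = 1024;  // stands in for FINISHED_SUM_TAG

// signals[r][p] is rank r's signal slot for peer p
using Signals = std::array<std::array<int, kNumRanks>, kNumRanks>;

void arrive(Signals& s, int rank) {
    for (int peer = 0; peer < kNumRanks; ++ peer) {
        s[rank][peer] += kTag;   // announce this rank's arrival to everyone
        s[peer][rank] -= kTag;   // acknowledge this rank in every peer's row
    }
}

bool can_leave(const Signals& s, int rank) {
    for (int peer = 0; peer < kNumRanks; ++ peer)
        if (s[rank][peer] > 0)   // some peer has not acknowledged yet
            return false;
    return true;
}

int main() {
    Signals s{};  // all zero, as after the `cudaMemsetAsync` in the constructor
    for (int r = 0; r < kNumRanks; ++ r)
        arrive(s, r);
    for (int r = 0; r < kNumRanks; ++ r)
        assert(can_leave(s, r));
    // Every slot is back to zero, so the same signals can serve the next barrier
    for (const auto& row : s)
        for (int v : row)
            assert(v == 0);
    return 0;
}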

View File

@ -18,10 +18,10 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes),
low_latency_mode(low_latency_mode),
comm_stream(at::cuda::getStreamFromPool(true)) {
// Task fifo memory
int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
// Common checks
EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits<int>::max() or num_rdma_bytes == 0));
@ -41,18 +41,17 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handle
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
buffer_ptrs_gpu = reinterpret_cast<void**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes);
buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);
// Set task fifo
EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
task_fifo_ptrs_gpu = reinterpret_cast<int**>(reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);
// Set barrier signals
barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
barrier_signal_ptrs_gpu = reinterpret_cast<int**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes);
// No need to synchronize, will do a full device sync during `sync`
CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream));
CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
}
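
For readers new to this layout: the single `cudaMalloc` above now packs four regions into one allocation. A small standalone sketch of the offset arithmetic follows, using an illustrative `NvlLayout` struct that does not exist in the codebase:

#include <cstdint>

constexpr int64_t kNumMaxNvlPeers = 8;  // NUM_MAX_NVL_PEERS

// [ user NVL buffer | barrier signals | buffer pointer table | signal pointer table ]
struct NvlLayout {
    int64_t buffer_offset;                   // buffer_ptrs[nvl_rank] points here
    int64_t barrier_signal_offset;           // barrier_signal_ptrs[nvl_rank]
    int64_t buffer_ptrs_gpu_offset;          // per-peer buffer pointer table
    int64_t barrier_signal_ptrs_gpu_offset;  // per-peer signal pointer table
    int64_t total_bytes;                     // size passed to cudaMalloc
};

NvlLayout make_nvl_layout(int64_t num_nvl_bytes) {
    NvlLayout l;
    l.buffer_offset = 0;
    l.barrier_signal_offset = num_nvl_bytes;
    l.buffer_ptrs_gpu_offset = l.barrier_signal_offset + kNumMaxNvlPeers * static_cast<int64_t>(sizeof(int));
    l.barrier_signal_ptrs_gpu_offset = l.buffer_ptrs_gpu_offset + kNumMaxNvlPeers * static_cast<int64_t>(sizeof(void*));
    l.total_bytes = l.barrier_signal_ptrs_gpu_offset + kNumMaxNvlPeers * static_cast<int64_t>(sizeof(int*));
    return l;
}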
// Create 32 MiB workspace
@ -91,8 +90,7 @@ Buffer::~Buffer() noexcept(false) {
if (num_nvl_bytes > 0) {
// Barrier
intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
move_fifo_slots();
intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
CUDA_CHECK(cudaDeviceSynchronize());
// Close remote IPC
@ -121,10 +119,6 @@ Buffer::~Buffer() noexcept(false) {
CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
}
void Buffer::move_fifo_slots(int num_slots) {
head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
}
bool Buffer::is_available() const {
return available;
}
@ -162,7 +156,7 @@ pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const {
torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype);
auto element_bytes = static_cast<int64_t>(elementSize(casted_dtype));
auto base_ptr = reinterpret_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
auto base_ptr = static_cast<uint8_t*>(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset;
auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes;
return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA));
}
@ -183,15 +177,15 @@ void Buffer::sync(const std::vector<int> &device_ids,
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
}
}
// Copy all buffer and task pointers to GPU
// Copy all buffer and barrier signal pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaDeviceSynchronize());
}
@ -395,9 +389,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
// Copy rank prefix matrix and clean flags
intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(), num_memset_int,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks,
buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
move_fifo_slots(2);
} else {
rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
@ -416,9 +409,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
num_tokens, is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
rank_prefix_matrix.data_ptr<int>(),
num_memset_int, expert_alignment,
buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank,
buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank,
comm_stream, num_channels);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@ -565,12 +557,9 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes);
intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr<int>(),
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
task_fifo_ptrs_gpu, head, rank, num_ranks,
barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
// NOTES: this function uses two FIFO slots (barrier before and after)
move_fifo_slots(2);
// Combine data
auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options());
EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail
@ -746,10 +735,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
nullptr, nullptr, nullptr,
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, true, low_latency_mode);
move_fifo_slots(2);
} else {
rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA));
recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA));
@ -769,10 +757,9 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional<torch::Te
gbl_channel_prefix_matrix.data_ptr<int>(), recv_gbl_rank_prefix_sum.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, low_latency_mode);
move_fifo_slots(3);
// Synchronize total received tokens and tokens per expert
auto start_time = std::chrono::high_resolution_clock::now();
@ -958,10 +945,9 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
rdma_channel_prefix_matrix.data_ptr<int>(), rdma_rank_prefix_sum.data_ptr<int>(), combined_nvl_head.data_ptr<int>(),
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens,
buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens,
task_fifo_ptrs_gpu, head, rank, comm_stream,
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
move_fifo_slots(2);
// Launch data combine
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());

View File

@ -51,10 +51,9 @@ private:
// After IPC/NVSHMEM synchronization, this flag will be true
bool available = false;
// Task fifo
int head = 0;
int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** task_fifo_ptrs_gpu = nullptr;
// Barrier signals
int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
int** barrier_signal_ptrs_gpu = nullptr;
// Workspace
void* workspace = nullptr;
@ -75,9 +74,6 @@ private:
volatile int* low_latency_usage_flag = nullptr;
int* low_latency_usage_flag_mapped = nullptr;
private:
void move_fifo_slots(int num_slots = 1);
public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);

View File

@ -7,7 +7,7 @@ namespace deep_ep {
// Intranode runtime
namespace intranode {
void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
} // namespace intranode
@ -35,11 +35,11 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int num_sms);
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream);
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
@ -51,7 +51,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
int num_max_send_tokens, int num_recv_buffer_tokens);
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream);
void combine(cudaDataType_t type,
void* recv_x, float* recv_topk_weights,
@ -84,7 +84,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode);
@ -106,7 +106,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int** barrier_signal_ptrs, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode);

View File

@ -2,7 +2,6 @@
#define NUM_MAX_NVL_PEERS 8
#define NUM_MAX_RDMA_PEERS 20
#define NUM_MAX_FIFO_SLOTS 32768
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
#define NUM_BUFFER_ALIGNMENT_BYTES 128

View File

@ -164,7 +164,7 @@ void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx
__device__ static __forceinline__ void
ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr,
__be32 rkey, uint16_t wqe_idx, void **out_wqes, uint32_t imm) {
__be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_inl_data_seg inl_seg;
@ -277,7 +277,7 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
__device__ static __forceinline__ void
ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey,
uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx,
void **out_wqes) {
void** out_wqes) {
ibgda_ctrl_seg_t ctrl_seg;
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_data_seg data_seg;
@ -373,7 +373,7 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
nvshmemi_ibgda_device_qp_t *qp, const int &value,
uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
uint16_t wqe_idx, void **out_wqes) {
uint16_t wqe_idx, void** out_wqes) {
ibgda_ctrl_seg_t ctrl_seg = {0};
struct mlx5_wqe_raddr_seg raddr_seg;
struct mlx5_wqe_atomic_seg atomic_seg_1;
@ -448,40 +448,27 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
__device__ static __forceinline__ int
__device__ static __forceinline__ void
ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
int status = 0;
struct mlx5_cqe64 *cqe64 = (struct mlx5_cqe64 *)cq->cqe;
const auto cqe64 = static_cast<mlx5_cqe64*>(cq->cqe);
const uint32_t ncqes = cq->ncqes;
memory_fence_cta();
// NOTES: this while loop is part of do-while below.
// `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`.
// To be able to compare with the index, we need to use `wqe_counter + 1`.
// Because `wqe_counter` is `uint16_t`, it may overflow. Still, we know for
// sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1` is less than
// `idx`, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1`.
// That's why we use `- 2` here to make this case overflow.
uint16_t wqe_counter;
uint16_t new_wqe_counter;
memory_fence_cta();
do {
new_wqe_counter = ld_na_relaxed(&cqe64->wqe_counter);
new_wqe_counter = HtoBE16(new_wqe_counter);
wqe_counter = new_wqe_counter;
}
// NOTE: This while loop is part of do while above.
// wqe_counter is the HW consumer index. However, we always maintain index
// + 1 in SW. To be able to compare with idx, we need to use wqe_counter +
// 1. Because wqe_counter is uint16_t, it may wraparound. Still we know for
// sure that if idx - wqe_counter - 1 < ncqes, wqe_counter + 1 is less than
// idx, and thus we need to wait. We don't need to wait when idx ==
// wqe_counter + 1. That's why we use - (uint16_t)2 here to make this case
// wraparound.
// Example:
// if idx = 10, we wait until wqe_counter = 9, idx - wqe_counter - 2 = 65535 > ncqes.
while (((uint16_t)((uint16_t)idx - wqe_counter - (uint16_t)2) < ncqes));
wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter));
} while ((static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes));
*cq->cons_idx = idx;
// Prevent reordering of this function and subsequent instructions
memory_fence_cta();
return status;
// Prevent reordering of this function and later instructions
memory_fence_cta();
}
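
The wait condition above leans on unsigned 16-bit wraparound; a standalone check of the same arithmetic (hypothetical `must_wait` helper, not repo code) makes the cases concrete:

#include <cassert>
#include <cstdint>

// Mirrors the do-while condition above: keep waiting while the WQE at `idx - 1`
// has not yet been consumed, i.e. while `wqe_counter + 1` is still behind `idx`
bool must_wait(uint64_t idx, uint16_t wqe_counter, uint32_t ncqes) {
    return static_cast<uint16_t>(static_cast<uint16_t>(idx) - wqe_counter - static_cast<uint16_t>(2)) < ncqes;
}

int main() {
    const uint32_t ncqes = 128;
    assert(must_wait(10, 8, ncqes));    // 10 - 8 - 2 = 0 < ncqes: still waiting
    assert(!must_wait(10, 9, ncqes));   // 10 - 9 - 2 wraps to 65535: idx == wqe_counter + 1, done
    assert(must_wait(2, 65535, ncqes)); // 2 - 65535 - 2 wraps to 1 < ncqes: counter lags across wraparound
    return 0;
}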
// Wait until wqe `idx - 1` is completed.

View File

@ -208,7 +208,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
@ -219,17 +219,15 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
if (sm_id == 0) {
// Communication with others
// Global barrier: the first warp do intra-node sync, the second warp do internode sync
// Global barrier: the first warp does intra-node sync, the second warp does internode sync
EP_DEVICE_ASSERT(num_warps > 1);
EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads);
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Send numbers of tokens per rank/expert to RDMA ranks
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
auto rdma_recv_num_tokens_mixed = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks);
// Clean up for later data dispatch
@ -285,7 +283,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
auto nvl_recv_num_tokens_per_expert = AsymBuffer<int>(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS);
// Clean up for later data dispatch
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes +
nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int));
#pragma unroll
@ -328,11 +326,9 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
}
memory_fence();
__syncthreads();
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Reduce number of tokens per rank/expert
// Reduce the number of tokens per rank/expert
EP_DEVICE_ASSERT(num_nvl_experts <= num_threads);
if (thread_id == 0) {
int sum = 0;
@ -359,8 +355,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
__syncthreads();
if (thread_id == 32)
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else {
// Calculate meta data
int dst_rdma_rank = sm_id - 1;
@ -423,7 +418,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank,
int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool low_latency_mode) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \
@ -439,7 +434,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \
rdma_buffer_ptr, \
buffer_ptrs, task_fifo_ptrs, head, rank, \
buffer_ptrs, barrier_signal_ptrs, rank, \
cpu_rdma_team); } break
constexpr int kNumThreads = 512;
@ -698,7 +693,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
// Release sequential lock
lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0;
} else if (warp_role == WarpRole::kRDMASenderCoordinator) {
// NOTES: in case of splitting the issued put at the end of the buffer
// NOTES: in case of splitting, the issued put at the end of the buffer
EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0);
// Synchronize shared memory
@ -861,7 +856,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) {
auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens;
void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token;
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes));
auto src_meta = ld_nc_global(reinterpret_cast<SourceMeta*>(static_cast<int8_t*>(shifted) + hidden_bytes));
lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0;
bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank);
if (lane_id == src_rdma_rank) {
@ -881,27 +876,27 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
nvl_channel_x.buffer() + dst_slot_idx * hidden_int4,
reinterpret_cast<int4*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<int4*>(shifted) + hidden_int4;
shifted = static_cast<int4*>(shifted) + hidden_int4;
// Copy source meta
if (lane_id == 0)
st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta);
shifted = reinterpret_cast<SourceMeta*>(shifted) + 1;
shifted = static_cast<SourceMeta*>(shifted) + 1;
// Copy `x_scales`
UNROLLED_WARP_COPY(1, lane_id, num_scales,
nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales,
reinterpret_cast<float*>(shifted),
ld_nc_global, st_na_global);
shifted = reinterpret_cast<float*>(shifted) + num_scales;
shifted = static_cast<float*>(shifted) + num_scales;
// Copy `topk_idx` and `topk_weights`
// NOTES: do not use `shifted` after this `if`, because only several lanes are shifted
if (lane_id < num_topk) {
// Read
auto idx_value = ld_nc_global(reinterpret_cast<int*>(shifted) + lane_id);
shifted = reinterpret_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(reinterpret_cast<float*>(shifted) + lane_id);
auto idx_value = ld_nc_global(static_cast<int*>(shifted) + lane_id);
shifted = static_cast<int*>(shifted) + num_topk;
auto weight_value = ld_nc_global(static_cast<float*>(shifted) + lane_id);
// Transform and write
idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1;
@ -1107,7 +1102,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
int* combined_rdma_head, int num_combined_tokens, int num_channels,
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank, int num_ranks,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks,
bool is_cached_dispatch, const nvshmem_team_t rdma_team) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x);
@ -1127,7 +1122,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__syncthreads();
// Clean
auto rdma_buffer_ptr_int = reinterpret_cast<int*>(rdma_buffer_ptr);
auto rdma_buffer_ptr_int = static_cast<int*>(rdma_buffer_ptr);
#pragma unroll
for (int i = thread_id; i < rdma_num_int_clean; i += num_threads)
rdma_buffer_ptr_int[rdma_clean_offset + i] = 0;
@ -1138,12 +1133,10 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
nvshmem_sync_with_same_gpu_idx<kLowLatencyMode>(rdma_team);
} else if (sm_id == 1) {
// Barrier for NVL
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
__syncthreads();
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
// Clean
auto nvl_buffer_ptr_int = reinterpret_cast<int*>(buffer_ptrs[nvl_rank]);
auto nvl_buffer_ptr_int = static_cast<int*>(buffer_ptrs[nvl_rank]);
#pragma unroll
for (int i = thread_id; i < nvl_num_int_clean; i += num_threads)
nvl_buffer_ptr_int[nvl_clean_offset + i] = 0;
@ -1151,8 +1144,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
__syncthreads();
// Barrier again
barrier_device<NUM_MAX_NVL_PEERS>(task_fifo_ptrs, head, nvl_rank);
move_fifo_slots<NUM_MAX_NVL_PEERS>(head);
barrier_block<NUM_MAX_NVL_PEERS>(barrier_signal_ptrs, nvl_rank);
} else if (sm_id == 2) {
if (is_cached_dispatch)
return;
@ -1213,7 +1205,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head,
void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens,
int** task_fifo_ptrs, int head, int rank, cudaStream_t stream,
int** barrier_signal_ptrs, int rank, cudaStream_t stream,
int64_t num_rdma_bytes, int64_t num_nvl_bytes,
bool is_cached_dispatch, bool low_latency_mode) {
const int num_threads = std::max(128, 32 * num_channels);
@ -1237,7 +1229,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to
combined_rdma_head, num_combined_tokens, num_channels,
rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head,
rdma_buffer_ptr,
buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks,
buffer_ptrs, barrier_signal_ptrs, rank, num_ranks,
is_cached_dispatch, cpu_rdma_team);
}
@ -1397,7 +1389,7 @@ combine(int4* combined_x, float* combined_topk_weights,
if (__all_sync(0xffffffff, token_start_idx >= token_end_idx))
break;
// Decide next RDMA buffer to send
// Decide the next RDMA buffer to send
bool is_lane_ready = false;
auto start_time = clock64();
while (true) {
@ -1575,8 +1567,8 @@ combine(int4* combined_x, float* combined_topk_weights,
combine_token<NUM_MAX_NVL_PEERS, dtype_t, NUM_MAX_NVL_PEERS>(expected_head >= 0,
expected_head, lane_id,
hidden_int4, num_topk,
reinterpret_cast<int4*>(shifted),
reinterpret_cast<float*>(reinterpret_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
static_cast<int4*>(shifted),
reinterpret_cast<float*>(static_cast<int8_t*>(shifted) + hidden_bytes + sizeof(SourceMeta)),
num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn);
// Update head

View File

@ -22,7 +22,7 @@ __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
for (int i = thread_id; i < num_clean_int_1; i += kNumThreads)
clean_1[i] = 0;
// Barrier after cleaning (make sure low-latency mode work fine)
// Barrier after cleaning (make sure the low-latency mode works fine)
nvshmemx_barrier_all_block();
}

View File

@ -14,16 +14,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
if (sm_id == 0) {
// Barrier first
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
int *per_rank_buffer, *per_expert_buffer;
if (thread_id < kNumRanks) {
@ -46,9 +44,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
__syncthreads();
// Wait for all ranks to be finished
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
// Sum per-rank counts and return to CPU
// Also pre-compute the prefix sum for data sending
@ -86,7 +82,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
// Barrier
memory_fence();
__syncthreads();
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
int dst_rank = sm_id - 1;
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
@ -116,7 +112,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
@ -124,7 +120,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
buffer_ptrs, task_fifo_ptrs, head, rank); \
buffer_ptrs, barrier_signal_ptrs, rank); \
break
constexpr int kNumThreads = 128;
@ -139,11 +135,9 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
template<int kNumRanks>
__global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
// A simplified version for cached handles
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
// Copy and clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@ -158,15 +152,15 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
void** buffer_ptrs, int** task_fifo_ptrs,
int head, int rank, int num_ranks, cudaStream_t stream) {
void** buffer_ptrs, int** barrier_signal_ptrs,
int rank, int num_ranks, cudaStream_t stream) {
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \
break
SETUP_LAUNCH_CONFIG(1, 128, stream);
@ -180,7 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
const bool* is_token_in_rank, const int* channel_prefix_matrix,
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
void **buffer_ptrs, int rank,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
@ -491,13 +485,11 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
template<int kNumRanks>
__global__ void
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank) {
int** barrier_signal_ptrs, int rank) {
const auto sm_id = static_cast<int>(blockIdx.x);
if (sm_id == 0) {
// Barrier before cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
// Clean
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@ -509,7 +501,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
__syncthreads();
// Barrier after cleaning
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
} else {
const auto channel_id = sm_id - 1;
const auto thread_id = static_cast<int>(threadIdx.x);
@ -528,7 +520,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
int token_idx = token_idx_tail - lane_id, expected_head = 0;
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
head = __shfl_sync(0xffffffff, current_head, i);
const int head = __shfl_sync(0xffffffff, current_head, i);
if (head < 0) {
if (lane_id == i)
expected_head = -last_head - 1;
@ -544,11 +536,11 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
int num_recv_tokens, int num_memset_int,
int** task_fifo_ptrs, int head, int rank, int num_ranks,
int** barrier_signal_ptrs, int rank, int num_ranks,
cudaStream_t stream) {
#define CACHED_NOTIFY_COMBINE(ranks) \
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, barrier_signal_ptrs, rank); \
break
const int num_threads = std::max(128, 32 * num_ranks);
@ -566,7 +558,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
const dtype_t* x, const float* topk_weights,
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
void **buffer_ptrs, int rank,
void** buffer_ptrs, int rank,
int num_max_send_tokens, int num_recv_buffer_tokens) {
const auto num_sms = static_cast<int>(gridDim.x);
const auto thread_id = static_cast<int>(threadIdx.x);

View File

@ -12,13 +12,13 @@ namespace deep_ep {
namespace intranode {
template<int kNumRanks>
__global__ void barrier(int** task_fifo_ptrs, int head, int rank) {
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
__global__ void barrier(int** barrier_signal_ptrs, int rank) {
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
}
void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) {
#define BARRIER_LAUNCH_CASE(ranks) \
LAUNCH_KERNEL(&cfg, barrier<ranks>, task_fifo_ptrs, head, rank); \
LAUNCH_KERNEL(&cfg, barrier<ranks>, barrier_signal_ptrs, rank); \
break
SETUP_LAUNCH_CONFIG(1, 32, stream);

View File

@ -396,44 +396,32 @@ __forceinline__ __device__ int get_lane_id() {
return lane_id;
}
template <int kNumRanks>
__forceinline__ __device__ void move_fifo_slots(int &head) {
head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
}
template <int kNumRanks>
__device__ __forceinline__ bool not_finished(int *task, int expected) {
auto result = false;
auto lane_id = threadIdx.x % 32;
if (lane_id < kNumRanks)
result = ld_volatile_global(task + lane_id) != expected;
return __any_sync(0xffffffff, result);
}
template <int kNumRanks>
__forceinline__ __device__ void
timeout_check(int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
barrier_block(int** barrier_signal_ptrs, int rank) {
auto thread_id = static_cast<int>(threadIdx.x);
// Add to the local rank's signals, subtract from other ranks' signals
if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
}
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
// Check timeout
auto start_time = clock64();
while (not_finished<kNumRanks>(task_fifo_ptrs[rank] + head, expected)) {
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
while (true) {
auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0;
if (__all_sync(0xffffffff, value <= 0))
break;
if (clock64() - start_time > NUM_TIMEOUT_CYCLES and get_lane_id() == 0) {
printf("DeepEP timeout check failed: rank = %d, thread = %d)\n", rank, thread_id);
trap();
}
}
}
template <int kNumRanks>
__forceinline__ __device__ void
barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) {
auto thread_id = static_cast<int>(threadIdx.x);
EP_DEVICE_ASSERT(kNumRanks <= 32);
if (thread_id < kNumRanks) {
atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
memory_fence();
atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
}
timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
__syncthreads();
}
} // namespace deep_ep