mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Fully remove barrier FIFO designs (#200)
* Fully remove FIFO slots
* Fully remove FIFO buffers
* Minor fix styles
* Fix some typos
* Bugs fixed
* Cleanup `ibgda_poll_cq`
This commit is contained in:
@@ -14,16 +14,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
|
||||
auto sm_id = static_cast<int>(blockIdx.x);
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
|
||||
|
||||
if (sm_id == 0) {
|
||||
// Barrier first
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
|
||||
int *per_rank_buffer, *per_expert_buffer;
|
||||
if (thread_id < kNumRanks) {
|
||||
@@ -46,9 +44,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
__syncthreads();
|
||||
|
||||
// Wait for all ranks to be finished
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
|
||||
// Sum per-rank counts and return to CPU
|
||||
// Also pre-compute the prefix sum for data sending
|
||||
@@ -86,7 +82,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
|
||||
// Barrier
|
||||
memory_fence();
|
||||
__syncthreads();
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
} else {
|
||||
int dst_rank = sm_id - 1;
|
||||
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
|
||||
@@ -116,7 +112,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
|
||||
const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
|
||||
int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
|
||||
int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
|
||||
void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
|
||||
cudaStream_t stream, int num_channels) {
|
||||
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
|
||||
@@ -124,7 +120,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
|
||||
num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
|
||||
num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
|
||||
rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
|
||||
buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
buffer_ptrs, barrier_signal_ptrs, rank); \
|
||||
break
|
||||
|
||||
constexpr int kNumThreads = 128;
|
||||
@@ -139,11 +135,9 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
|
||||
// A simplified version for cached handles
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
|
||||
// Copy and clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
@@ -158,15 +152,15 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
}
|
||||
|
||||
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
|
||||
void** buffer_ptrs, int** task_fifo_ptrs,
|
||||
int head, int rank, int num_ranks, cudaStream_t stream) {
|
||||
void** buffer_ptrs, int** barrier_signal_ptrs,
|
||||
int rank, int num_ranks, cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
|
||||
rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
|
||||
rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \
|
||||
break
|
||||
|
||||
SETUP_LAUNCH_CONFIG(1, 128, stream);
|
||||
@@ -180,7 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
|
||||
int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
|
||||
const bool* is_token_in_rank, const int* channel_prefix_matrix,
|
||||
int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
|
||||
void **buffer_ptrs, int rank,
|
||||
void** buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
|
||||
@@ -491,13 +485,11 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
|
||||
template<int kNumRanks>
|
||||
__global__ void
|
||||
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank) {
|
||||
int** barrier_signal_ptrs, int rank) {
|
||||
const auto sm_id = static_cast<int>(blockIdx.x);
|
||||
if (sm_id == 0) {
|
||||
// Barrier before cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
move_fifo_slots<kNumRanks>(head);
|
||||
__syncthreads();
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
|
||||
// Clean
|
||||
auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
|
||||
@@ -509,7 +501,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
|
||||
__syncthreads();
|
||||
|
||||
// Barrier after cleaning
|
||||
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
|
||||
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
|
||||
} else {
|
||||
const auto channel_id = sm_id - 1;
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
@@ -528,7 +520,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
|
||||
int token_idx = token_idx_tail - lane_id, expected_head = 0;
|
||||
auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
|
||||
for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
|
||||
head = __shfl_sync(0xffffffff, current_head, i);
|
||||
const int head = __shfl_sync(0xffffffff, current_head, i);
|
||||
if (head < 0) {
|
||||
if (lane_id == i)
|
||||
expected_head = -last_head - 1;
|
||||
@@ -544,11 +536,11 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
|
||||
|
||||
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
|
||||
int num_recv_tokens, int num_memset_int,
|
||||
int** task_fifo_ptrs, int head, int rank, int num_ranks,
|
||||
int** barrier_signal_ptrs, int rank, int num_ranks,
|
||||
cudaStream_t stream) {
|
||||
#define CACHED_NOTIFY_COMBINE(ranks) \
|
||||
LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
|
||||
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
|
||||
buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, barrier_signal_ptrs, rank); \
|
||||
break
|
||||
|
||||
const int num_threads = std::max(128, 32 * num_ranks);
|
||||
@@ -566,7 +558,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
|
||||
const dtype_t* x, const float* topk_weights,
|
||||
const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
|
||||
int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
|
||||
void **buffer_ptrs, int rank,
|
||||
void** buffer_ptrs, int rank,
|
||||
int num_max_send_tokens, int num_recv_buffer_tokens) {
|
||||
const auto num_sms = static_cast<int>(gridDim.x);
|
||||
const auto thread_id = static_cast<int>(threadIdx.x);
|
||||
|
||||
Reference in New Issue
Block a user