Fully remove barrier FIFO designs (#200)

* Fully remove FIFO slots

* Fully remove FIFO buffers

* Fix minor style issues

* Fix some typos

* Fix bugs

* Clean up `ibgda_poll_cq`
Author: Chenggang Zhao
Date: 2025-06-10 16:23:20 +08:00
Committed by: GitHub
Parent: a16af40531
Commit: 8da2d7b38d
10 changed files with 121 additions and 181 deletions
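The diff below (one of the 10 changed files) mechanically replaces the FIFO-based device barrier, i.e. `barrier_device` over `task_fifo_ptrs` plus an explicit `move_fifo_slots`/`__syncthreads()` at every call site, with a single `barrier_block` call driven by per-rank signal counters. As a rough sketch of how such a counter-based barrier can work (the body below is an assumption for illustration, not DeepEP's verbatim implementation; only the names `barrier_block`, `barrier_signal_ptrs`, and `rank` come from the diff):

// Hedged sketch of a counter-based inter-rank block barrier. Assumes
// kNumRanks <= blockDim.x and that barrier_signal_ptrs[r] points to rank r's
// array of kNumRanks signal counters, all initially zero.
template <int kNumRanks>
__device__ void barrier_block(int** barrier_signal_ptrs, int rank) {
    const auto thread_id = static_cast<int>(threadIdx.x);

    // Make this block's prior writes visible system-wide before signaling.
    __threadfence_system();
    __syncthreads();

    // Thread t (t < kNumRanks) credits our own slot for rank t and debits
    // our slot in rank t's signal array. Slot (r, t) therefore returns to
    // zero exactly when both rank r and rank t have arrived.
    if (thread_id < kNumRanks) {
        atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, 1);
        atomicSub_system(barrier_signal_ptrs[thread_id] + rank, 1);

        // Spin until the matching debit from each peer has landed.
        volatile int* slot = barrier_signal_ptrs[rank] + thread_id;
        while (*slot > 0);
    }
    __syncthreads();
}

Because every slot returns to zero on exit, such a barrier is immediately reusable: no FIFO head index has to be threaded through the kernels and advanced after each use, which is precisely the bookkeeping the hunks below delete.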


@@ -14,16 +14,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
                 const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                 int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
                 int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
-                void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
+                void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
     auto sm_id = static_cast<int>(blockIdx.x);
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
     auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;
     if (sm_id == 0) {
         // Barrier first
-        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
-        move_fifo_slots<kNumRanks>(head);
-        __syncthreads();
+        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
         int *per_rank_buffer, *per_expert_buffer;
         if (thread_id < kNumRanks) {
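The rewrite above repeats at every synchronization point in the file; side by side, the before/after pattern (taken directly from the deleted and added lines) is:

// Before: three steps, plus a `head` index parameter threaded through the kernel.
barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
move_fifo_slots<kNumRanks>(head);
__syncthreads();

// After: one call, no head bookkeeping. The block-wide sync presumably
// happens inside barrier_block (an assumption, based on the dropped line).
barrier_block<kNumRanks>(barrier_signal_ptrs, rank);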
@@ -46,9 +44,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
         __syncthreads();
         // Wait for all ranks to be finished
-        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
-        move_fifo_slots<kNumRanks>(head);
-        __syncthreads();
+        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
         // Sum per-rank counts and return to CPU
         // Also pre-compute the prefix sum for data sending
@@ -86,7 +82,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
         // Barrier
         memory_fence();
         __syncthreads();
-        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
     } else {
         int dst_rank = sm_id - 1;
         for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
@@ -116,7 +112,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
                      const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                      int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
                      int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
-                     void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
+                     void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
                      cudaStream_t stream, int num_channels) {
 #define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
     LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
@@ -124,7 +120,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
         num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
         num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
         rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
-        buffer_ptrs, task_fifo_ptrs, head, rank); \
+        buffer_ptrs, barrier_signal_ptrs, rank); \
     break
     constexpr int kNumThreads = 128;
@@ -139,11 +135,9 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe
 template<int kNumRanks>
 __global__ void
 cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
-                       void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
+                       void** buffer_ptrs, int** barrier_signal_ptrs, int rank) {
     // A simplified version for cached handles
-    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
-    move_fifo_slots<kNumRanks>(head);
-    __syncthreads();
+    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
     // Copy and clean
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -158,15 +152,15 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
     __syncthreads();
     // Barrier after cleaning
-    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+    barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
 }
 void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
-                            void** buffer_ptrs, int** task_fifo_ptrs,
-                            int head, int rank, int num_ranks, cudaStream_t stream) {
+                            void** buffer_ptrs, int** barrier_signal_ptrs,
+                            int rank, int num_ranks, cudaStream_t stream) {
 #define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
     LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
-        rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
+        rank_prefix_matrix, num_memset_int, buffer_ptrs, barrier_signal_ptrs, rank); \
     break
     SETUP_LAUNCH_CONFIG(1, 128, stream);
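Host-side, only signatures shrink: `head` drops out of the launcher and of `CACHED_NOTIFY_DISPATCH_LAUNCH_CASE`. A hedged sketch of where such a `..._LAUNCH_CASE` macro is typically expanded (the `switch` below is an assumption about the surrounding code, which this diff does not show):

// Hypothetical expansion site: dispatch on the runtime rank count so that
// the kernel's kNumRanks template parameter stays a compile-time constant.
// Each case ends with the macro's trailing `break`, so there is no fall-through.
switch (num_ranks) {
    case 2: CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(2);
    case 4: CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(4);
    case 8: CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(8);
    default: /* unsupported rank count; assume the real code asserts here */ break;
}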
@@ -180,7 +174,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to
          int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
          int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
-         void **buffer_ptrs, int rank,
+         void** buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x), lane_id = get_lane_id();
@@ -491,13 +485,11 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* re
 template<int kNumRanks>
 __global__ void
 cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
-                      int** task_fifo_ptrs, int head, int rank) {
+                      int** barrier_signal_ptrs, int rank) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     if (sm_id == 0) {
         // Barrier before cleaning
-        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
-        move_fifo_slots<kNumRanks>(head);
-        __syncthreads();
+        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
         // Clean
         auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
@@ -509,7 +501,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
         __syncthreads();
         // Barrier after cleaning
-        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
+        barrier_block<kNumRanks>(barrier_signal_ptrs, rank);
     } else {
         const auto channel_id = sm_id - 1;
         const auto thread_id = static_cast<int>(threadIdx.x);
@@ -528,7 +520,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
             int token_idx = token_idx_tail - lane_id, expected_head = 0;
             auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
             for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
-                head = __shfl_sync(0xffffffff, current_head, i);
+                const int head = __shfl_sync(0xffffffff, current_head, i);
                 if (head < 0) {
                     if (lane_id == i)
                         expected_head = -last_head - 1;
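Beyond the barrier rewrite, this hunk carries what is likely the "Fix bugs" item: the old line assigned to the kernel's `head` parameter (visible in the old `cached_notify_combine` signature above), and with that parameter removed the broadcast result becomes a properly scoped local `const int`. A standalone demo of the `__shfl_sync` broadcast idiom used here (a hypothetical kernel, not DeepEP code):

// On iteration i, __shfl_sync with the full mask broadcasts lane i's value
// to every lane in the warp; lane 0 collects each broadcast into `gathered`.
__global__ void warp_broadcast_demo(const int* per_lane, int* gathered) {
    const int lane_id = static_cast<int>(threadIdx.x) % 32;
    const int mine = per_lane[lane_id];
    for (int i = 0; i < 32; ++i) {
        const int head = __shfl_sync(0xffffffff, mine, i);  // lane i's `mine`
        if (lane_id == 0)
            gathered[i] = head;
    }
}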
@@ -544,11 +536,11 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
 void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
                            int num_recv_tokens, int num_memset_int,
-                           int** task_fifo_ptrs, int head, int rank, int num_ranks,
+                           int** barrier_signal_ptrs, int rank, int num_ranks,
                            cudaStream_t stream) {
 #define CACHED_NOTIFY_COMBINE(ranks) \
     LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
-        buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
+        buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, barrier_signal_ptrs, rank); \
     break
     const int num_threads = std::max(128, 32 * num_ranks);
@@ -566,7 +558,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         const dtype_t* x, const float* topk_weights,
         const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
         int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
-        void **buffer_ptrs, int rank,
+        void** buffer_ptrs, int rank,
         int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x);
     const auto thread_id = static_cast<int>(threadIdx.x);