#include <algorithm>
#include <limits>

#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"

namespace deep_ep {

namespace intranode {

// Exchanges token counts between ranks and builds the prefix-sum matrices consumed by `dispatch`.
//
// Launch layout (see the host wrapper): `1 + kNumRanks` blocks.
//  - Block 0:
//      1. thread `t < kNumRanks` writes this rank's per-rank / per-expert counts into `buffer_ptrs[t]`;
//      2. after a device barrier, the local `kNumRanks x kNumRanks` matrix is turned into a column-wise
//         inclusive prefix sum, and the total for this rank is stored to `moe_recv_counter_mapped`;
//      3. per-local-expert totals are rounded up to a multiple of `expert_alignment` and stored to
//         `moe_recv_expert_counter_mapped`;
//      4. the prefix matrix is copied to `rank_prefix_matrix_copy` and `num_memset_int` integers right
//         after the matrix are zeroed.
//  - Block `1 + r`: one warp per channel counts the tokens selected for rank `r`, then thread 0 converts
//    row `r` of `channel_prefix_matrix` into an inclusive prefix sum over channels.
//
// Preconditions visible from the code: `blockDim.x >= kNumRanks`, `blockDim.x >= num_experts / kNumRanks`,
// and `expert_alignment != 0`.
template <int kNumRanks>
__global__ void
notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped,
                const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                int num_tokens, int num_channels, const bool* is_token_in_rank, int* channel_prefix_matrix,
                int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank) {
    auto sm_id = static_cast<int>(blockIdx.x);
    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
    auto lane_id = thread_id % 32, warp_id = thread_id / 32, num_warps = num_threads / 32;

    if (sm_id == 0) {
        // Barrier first
        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
        move_fifo_slots<kNumRanks>(head);
        __syncthreads();

        // Only threads `< kNumRanks` initialize and use these pointers
        int *per_rank_buffer, *per_expert_buffer;
        if (thread_id < kNumRanks) {
            per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
            per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
        }

        // After this loop:
        //  - `per_rank_buffer[rank][i, j]` means the number of tokens from rank i to rank j
        //  - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j
        int num_experts_per_rank = num_experts / kNumRanks;
        if (thread_id < kNumRanks) {
            #pragma unroll
            for (int i = 0; i < kNumRanks; ++ i)
                per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
            #pragma unroll
            for (int i = 0; i < num_experts_per_rank; ++ i)
                per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i];
        }
        __syncthreads();

        // Wait for all ranks to be finished
        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
        move_fifo_slots<kNumRanks>(head);
        __syncthreads();

        // Sum per-rank counts and return to CPU
        // Also pre-compute the prefix sum for data sending
        auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
        if (thread_id < kNumRanks) {
            #pragma unroll
            for (int i = 1; i < kNumRanks; ++ i)
                local_per_rank_buffer[i * kNumRanks + thread_id] += local_per_rank_buffer[(i - 1) * kNumRanks + thread_id];
            if (thread_id == rank)
                *moe_recv_counter_mapped = local_per_rank_buffer[(kNumRanks - 1) * kNumRanks + rank];
        }

        // Sum per-experts counts and return to CPU
        auto local_per_expert_buffer = local_per_rank_buffer + kNumRanks * kNumRanks;
        if (thread_id < num_experts_per_rank) {
            int sum = 0;
            #pragma unroll
            for (int i = 0; i < kNumRanks; ++ i)
                sum += local_per_expert_buffer[i * num_experts_per_rank + thread_id];
            // Round up to a multiple of `expert_alignment`
            sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
            moe_recv_expert_counter_mapped[thread_id] = sum;
        }
        __syncthreads();

        // Copy rank size prefix matrix to another tensor
        #pragma unroll
        for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
            rank_prefix_matrix_copy[i] = local_per_rank_buffer[i];

        // Extra memset for later communication queue
        #pragma unroll
        for (int i = thread_id; i < num_memset_int; i += num_threads)
            local_per_expert_buffer[i] = 0;

        // Barrier
        memory_fence();
        __syncthreads();
        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
    } else {
        int dst_rank = sm_id - 1;
        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

            // Iterate over tokens
            int count = 0;
            for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32)
                count += is_token_in_rank[i * kNumRanks + dst_rank];
            count = warp_reduce_sum(count);
            if (lane_id == 0)
                channel_prefix_matrix[dst_rank * num_channels + channel_id] = count;
        }
        __syncthreads();

        // Pre-compute prefix sum for all channels
        if (thread_id == 0) {
            #pragma unroll
            for (int i = 1; i < num_channels; ++ i)
                channel_prefix_matrix[dst_rank * num_channels + i] += channel_prefix_matrix[dst_rank * num_channels + i - 1];
        }
    }
}

// Host launcher for the `notify_dispatch` kernel: `1 + num_ranks` blocks of 128 threads on `stream`.
// Requires `num_experts` to be divisible by `num_ranks`, and both `num_ranks` and the number of experts
// per rank to fit in one block (the kernel maps one thread to each of them).
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                     int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
                     int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                     void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank,
                     cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
    LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
        num_tokens_per_rank, moe_recv_counter_mapped, \
        num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
        num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
        rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
        buffer_ptrs, task_fifo_ptrs, head, rank); \
    break

    constexpr int kNumThreads = 128;
    // Both values are used as divisors (here and inside the kernel)
    EP_HOST_ASSERT(num_ranks > 0 and expert_alignment > 0);
    EP_HOST_ASSERT(num_experts % num_ranks == 0);
    EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);

    SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
    SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}

// Simplified notification for cached handles: between two device barriers, copies the given
// `rank_prefix_matrix` (`kNumRanks * kNumRanks` integers) to the start of the local buffer
// (`buffer_ptrs[rank]`) and zeroes the `num_memset_int` integers that follow it.
// Written for a single block (the host wrapper launches exactly one).
template <int kNumRanks>
__global__ void
cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
                       void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank) {
    // A simplified version for cached handles
    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
    move_fifo_slots<kNumRanks>(head);
    __syncthreads();

    // Copy and clean
    auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
    auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
    #pragma unroll
    for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
        ptr[i] = rank_prefix_matrix[i];
    #pragma unroll
    for (int i = thread_id; i < num_memset_int; i += num_threads)
        ptr[kNumRanks * kNumRanks + i] = 0;
    memory_fence();
    __syncthreads();

    // Barrier after cleaning
    barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
}

// Host launcher for the `cached_notify_dispatch` kernel: one block of 128 threads on `stream`.
void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
                            void** buffer_ptrs, int** task_fifo_ptrs,
                            int head, int rank, int num_ranks, cudaStream_t stream) {
#define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
    LAUNCH_KERNEL(&cfg, cached_notify_dispatch<ranks>, \
        rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \
    break

    SETUP_LAUNCH_CONFIG(1, 128, stream);
    SWITCH_RANKS(CACHED_NOTIFY_DISPATCH_LAUNCH_CASE);
#undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE
}

// Token dispatch through per-channel ring buffers that live on the receiver side.
//
// Launch layout: an even number of blocks with `kNumThreads` threads each. Blocks `2c` and `2c + 1`
// serve channel `c`; even blocks send, odd blocks receive. Inside a block, threads are split into
// `kNumRanks` groups of `kNumThreads / kNumRanks` threads, one group per peer rank; each group
// synchronizes on its own named barrier (`bar.sync <rank>, <threads per rank>`).
//
// Sender group for rank `r`: publishes the channel's start/end offsets (encoded as `-value - 1` so
// that zero means "not written yet"), then pushes the tokens selected for rank `r` in chunks of at most
// `num_max_send_tokens`, waiting until the queue has room, and releases the new tail index.
// For every visited token it also records the queue slot (or -1) into `send_head`.
//
// Receiver group for rank `r`: waits for the offsets, then drains the queue into `recv_x`,
// `recv_src_idx`, `recv_topk_idx`, `recv_topk_weights` and `recv_x_scales`, and advances the head index.
template <int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
         int* send_head, const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
         const bool* is_token_in_rank, const int* channel_prefix_matrix,
         int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
         void **buffer_ptrs, int rank,
         int num_max_send_tokens, int num_recv_buffer_tokens) {
    const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const bool is_sender = sm_id % 2 == 0;
    EP_DEVICE_ASSERT(num_sms % 2 == 0);

    // Several warps are responsible for a single rank
    const auto num_threads_per_rank = kNumThreads / kNumRanks;
    const auto num_channels = num_sms / 2;
    const auto responsible_rank = (static_cast<int>(thread_id)) / num_threads_per_rank;
    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
    const auto responsible_channel = sm_id / 2;

    int num_experts_per_rank = num_experts / kNumRanks;
    EP_DEVICE_ASSERT(num_experts_per_rank > 0 or num_topk == 0);
    EP_DEVICE_ASSERT(num_topk <= 32);
    EP_DEVICE_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr));
    EP_DEVICE_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr));

    // Calculate pointers by the specific layout
    // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
    // Senders address the peer's buffer, receivers address their own
    auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
    int target_rank = is_sender ? rank : responsible_rank;
    auto num_channels_total = num_channels * kNumRanks;
    auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;

    // Channel buffer metadata
    // Senders are responsible for tails, and receivers are responsible for heads
    // Stored on the receiver side
    // The retired signals are actually boolean flags, but to align with 16 bytes, we make it `int64_t`
    // `start_offset`: kNumChannels * kNumRanks * sizeof(int)
    // `end_offset`: kNumChannels * kNumRanks * sizeof(int)
    // `head_idx`: kNumChannels * kNumRanks * sizeof(int)
    // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
    // NOTES: the declaration order below defines the memory layout, as every `Buffer` is carved from the same `ptr`
    auto channel_start_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
    auto channel_end_offset = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
    auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
    auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);

    // Channel data buffers, stored on the receiver side
    // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
    // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
    // `topk_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(int64_t)
    // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
    // `x_scales_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_scales * sizeof(float)
    auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
    auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
    auto channel_topk_idx_buffers = Buffer<int64_t>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
    auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
    auto channel_x_scales_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales);

    if (is_sender) {
        // Workers for sending
        constexpr int num_send_warps = kNumThreads / 32;
        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
        const auto send_thread_id = thread_id;
        const auto send_lane_id = send_thread_id % 32;
        const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;
        EP_DEVICE_ASSERT(kNumRanks <= 32);
        EP_DEVICE_ASSERT(num_send_warps % kNumRanks == 0);

        // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2
        // NOTES: this is for distinguishing zero tokens
        if (send_lane_id == 0 and send_warp_id_in_rank == 0) {
            int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] : 0;
            st_relaxed_sys_global(channel_start_offset.buffer(), -value - 1);
            value = channel_prefix_matrix[responsible_rank * num_channels + responsible_channel];
            st_relaxed_sys_global(channel_end_offset.buffer(), -value - 1);
        }
        __syncwarp();

        // Get tasks
        int token_start_idx, token_end_idx;
        get_channel_task_range(num_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);

        // Iterate over all tokens and send by chunks
        int cached_channel_tail_idx = 0;
        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
            // Check destination queue emptiness, or wait a buffer to be released (rare cases)
            // NOTES: the head index received by different warps may not be the same
            auto start_time = clock64();
            while (send_lane_id == 0) {
                // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
                int num_used_slots = cached_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                if (num_recv_buffer_tokens - num_used_slots >= num_max_send_tokens)
                    break;

                // Rare cases to loop again
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP timeout for dispatch senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
                    trap();
                }
            }
            __syncwarp();

            int chunk_token_idx = 0;
            while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
                if (send_lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
                    send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;

                // Skip if not selected
                if (not is_token_in_rank[token_idx * kNumRanks + responsible_rank]) {
                    token_idx ++;
                    continue;
                }

                // Get an empty slot
                // NOTES: every warp advances the tail, but only one warp of the group copies a given token
                int dst_slot_idx = (cached_channel_tail_idx ++) % num_recv_buffer_tokens;
                if (cached_channel_tail_idx % num_send_warps_per_rank == send_warp_id_in_rank) {
                    // Copy data
                    auto shifted_channel_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                    auto shifted_x = x + token_idx * hidden_int4;
                    UNROLLED_WARP_COPY(5, send_lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x,
                                       __ldg, st_na_global);

                    // Copy source index
                    if (send_lane_id == 0)
                        channel_src_idx_buffers[dst_slot_idx] = static_cast<int>(token_idx);

                    // Copy `topk_idx` and `topk_weights` with transformed index
                    if (send_lane_id < num_topk) {
                        // Top-k index
                        int recv_expert_begin = responsible_rank * num_experts_per_rank;
                        int recv_expert_end = (responsible_rank + 1) * num_experts_per_rank;
                        auto idx_value = __ldg(topk_idx + token_idx * num_topk + send_lane_id);
                        idx_value = (idx_value >= recv_expert_begin and idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1;
                        channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = idx_value;

                        // Top-k weights
                        auto weight_value = __ldg(topk_weights + token_idx * num_topk + send_lane_id);
                        weight_value = (idx_value >= 0) ? weight_value : 0.0f;
                        channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = weight_value;
                    }

                    // Copy `x_scales`
                    #pragma unroll
                    for (int i = send_lane_id; i < num_scales; i += 32)
                        channel_x_scales_buffers[dst_slot_idx * num_scales + i] = __ldg(x_scales + token_idx * num_scales + i);
                }

                // Move token index
                chunk_token_idx ++, token_idx ++;
            }

            // Move tail index
            // NOTES: here all warps should share the same new tail
            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
            if (send_warp_id_in_rank == 0 and send_lane_id == 0)
                st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx);
        }
    } else {
        // Workers for receiving and copying into buffer
        constexpr int num_recv_warps = kNumThreads / 32;
        constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks;
        const auto recv_thread_id = thread_id;
        const auto recv_lane_id = recv_thread_id % 32;
        const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank;
        const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32;
        EP_DEVICE_ASSERT(kNumRanks <= 32);
        EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);

        // Calculate offset first
        auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
        int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;

        // Receive channel offset
        // Lane 0 spins until the sender has published non-zero (i.e. `-value - 1` encoded) offsets
        int total_offset, num_tokens_to_recv;
        while (recv_lane_id == 0 and (total_offset = ld_volatile_global(channel_start_offset.buffer())) == 0);
        while (recv_lane_id == 0 and (num_tokens_to_recv = ld_volatile_global(channel_end_offset.buffer())) == 0);
        if (recv_lane_id == 0) {
            total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1;
            if (recv_warp_id_in_rank == 0)
                recv_channel_offset[responsible_rank * num_channels + responsible_channel] = total_offset;
            num_tokens_to_recv -= total_offset;
        }
        total_offset = __shfl_sync(0xffffffff, total_offset, 0);
        total_offset += rank_offset;
        num_tokens_to_recv = __shfl_sync(0xffffffff, num_tokens_to_recv, 0);

        // Shared tail indices for different warps
        __shared__ volatile int shared_channel_tail_idx[kNumRanks];

        auto start_time = clock64();
        int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
        while (num_tokens_to_recv > 0) {
            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
            while (recv_thread_id_in_rank == 0) {
                cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());

                // Ready to copy
                if (cached_channel_head_idx != cached_channel_tail_idx) {
                    shared_channel_tail_idx[responsible_rank] = cached_channel_tail_idx;
                    break;
                }

                // Timeout check
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP timeout for dispatch receivers, rank %d, responsible_channel = %d, tokens remained: %d\n", rank, responsible_channel, num_tokens_to_recv);
                    trap();
                }
            }

            // Synchronize queue tail
            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
            cached_channel_tail_idx = shared_channel_tail_idx[responsible_rank];

            // Copy data
            int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx;
            for (int chunk_idx = recv_warp_id_in_rank; chunk_idx < num_recv_tokens; chunk_idx += num_recv_warps_per_rank) {
                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                auto shifted_buffer_x_int4 = channel_x_buffers.buffer() + token_idx_in_buffer * hidden_int4;
                auto shifted_recv_x_int4 = recv_x + static_cast<int64_t>(total_offset + chunk_idx) * hidden_int4;
                UNROLLED_WARP_COPY(5, recv_lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4,
                                   ld_nc_global, st_na_global);
            }

            // Copy `src_idx`
            #pragma unroll 4
            for (int chunk_idx = cached_channel_head_idx + recv_thread_id_in_rank; chunk_idx < cached_channel_tail_idx; chunk_idx += 32 * num_recv_warps_per_rank)
                recv_src_idx[total_offset + chunk_idx - cached_channel_head_idx] = ld_nc_global(channel_src_idx_buffers.buffer() + chunk_idx % num_recv_buffer_tokens);

            // Copy `topk_idx` and `topk_weights`
            #pragma unroll 4
            for (int idx = recv_thread_id_in_rank; idx < num_recv_tokens * num_topk; idx += 32 * num_recv_warps_per_rank) {
                int chunk_idx = idx / num_topk, token_topk_idx = idx % num_topk;
                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                auto recv_idx = static_cast<int64_t>(total_offset + chunk_idx) * num_topk + token_topk_idx;
                auto buffer_idx = token_idx_in_buffer * num_topk + token_topk_idx;
                recv_topk_idx[recv_idx] = ld_nc_global(channel_topk_idx_buffers.buffer() + buffer_idx);
                recv_topk_weights[recv_idx] = ld_nc_global(channel_topk_weights_buffers.buffer() + buffer_idx);
            }

            // Copy `x_scales`
            #pragma unroll 4
            for (int i = recv_thread_id_in_rank; i < num_recv_tokens * num_scales; i += 32 * num_recv_warps_per_rank) {
                int chunk_idx = i / num_scales, scales_idx = i % num_scales;
                int token_idx_in_buffer = (cached_channel_head_idx + chunk_idx) % num_recv_buffer_tokens;
                recv_x_scales[static_cast<int64_t>(total_offset + chunk_idx) * num_scales + scales_idx] =
                        ld_nc_global(channel_x_scales_buffers.buffer() + token_idx_in_buffer * num_scales + scales_idx);
            }

            // Move queue
            // The group barrier makes sure every warp has finished reading before the head is published
            cached_channel_head_idx += num_recv_tokens;
            total_offset += num_recv_tokens;
            asm volatile("bar.sync %0, %1;" :: "r"(responsible_rank), "r"(num_threads_per_rank));
            if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 and recv_lane_id == 0)
                st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx);

            // Exit
            num_tokens_to_recv -= num_recv_tokens;
        }
    }
}

// Host launcher for the `dispatch` kernel: `num_sms` blocks (must be even) of 512 threads on `stream`.
// `recv_x` and `x` are reinterpreted as `int4` arrays, so `hidden_int4` counts 16-byte elements.
void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset,
              int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
              const bool* is_token_in_rank, const int* channel_prefix_matrix,
              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              void** buffer_ptrs, int rank, int num_ranks,
              cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
    constexpr int kNumThreads = 512;

#define DISPATCH_LAUNCH_CASE(ranks) \
    LAUNCH_KERNEL(&cfg, dispatch<ranks, kNumThreads>, \
        reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
        send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
        is_token_in_rank, channel_prefix_matrix, \
        num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
        buffer_ptrs, rank, \
        num_max_send_tokens, num_recv_buffer_tokens); \
    break

    // Even-numbered blocks for sending, odd-numbered blocks for receiving.
    EP_HOST_ASSERT(num_sms % 2 == 0);
    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    SWITCH_RANKS(DISPATCH_LAUNCH_CASE);
#undef DISPATCH_LAUNCH_CASE
}

// Prepares a cached handle for `combine`.
//
// Launch layout (see the host wrapper): `1 + num_channels` blocks.
//  - Block 0 zeroes the first `num_memset_int` integers of the local buffer between two device barriers.
//  - Block `1 + c` rewrites the negative entries of `send_head` for channel `c`. Warp `w` (`w < kNumRanks`)
//    handles rank `w` and walks the channel's tokens from the last one backwards, 32 at a time; every
//    negative entry becomes `-last_head - 1`, where `last_head` is the closest non-negative head seen
//    after it in token order (`1 << 25` when there is none).
template <int kNumRanks>
__global__ void
cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int,
                      int** task_fifo_ptrs, int head, int rank) {
    const auto sm_id = static_cast<int>(blockIdx.x);
    if (sm_id == 0) {
        // Barrier before cleaning
        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
        move_fifo_slots<kNumRanks>(head);
        __syncthreads();

        // Clean
        auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
        auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
        #pragma unroll
        for (int i = thread_id; i < num_memset_int; i += num_threads)
            ptr[i] = 0;
        memory_fence();
        __syncthreads();

        // Barrier after cleaning
        barrier_device<kNumRanks>(task_fifo_ptrs, head, rank);
    } else {
        const auto channel_id = sm_id - 1;
        const auto thread_id = static_cast<int>(threadIdx.x);
        const auto rank_id = thread_id / 32;
        const auto lane_id = thread_id % 32;
        // Whole warps leave together, so the full-mask shuffles below stay valid
        if (rank_id >= kNumRanks)
            return;

        int token_start_idx, token_end_idx;
        get_channel_task_range(num_recv_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

        // NOTES: `1 << 25` is a heuristic large number
        int last_head = 1 << 25;
        #pragma unroll
        for (int token_idx_tail = token_end_idx - 1; token_idx_tail >= token_start_idx; token_idx_tail -= 32) {
            // Lane `l` owns token `token_idx_tail - l`, i.e. lane 0 holds the latest token of the group
            int token_idx = token_idx_tail - lane_id, expected_head = 0;
            auto current_head = (token_idx >= token_start_idx) ? __ldg(send_head + token_idx * kNumRanks + rank_id) : -1;
            for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++ i) {
                // Use a local instead of clobbering the `head` kernel parameter
                const int lane_head = __shfl_sync(0xffffffff, current_head, i);
                if (lane_head < 0) {
                    if (lane_id == i)
                        expected_head = -last_head - 1;
                } else {
                    last_head = lane_head;
                }
            }
            if (current_head < 0 and token_idx >= token_start_idx)
                send_head[token_idx * kNumRanks + rank_id] = expected_head;
        }
    }
}

// Host launcher for the `cached_notify_combine` kernel: `1 + num_channels` blocks of
// `max(128, 32 * num_ranks)` threads on `stream`, so that every rank gets a dedicated warp.
void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels,
                           int num_recv_tokens, int num_memset_int,
                           int** task_fifo_ptrs, int head, int rank, int num_ranks,
                           cudaStream_t stream) {
#define CACHED_NOTIFY_COMBINE(ranks) \
    LAUNCH_KERNEL(&cfg, cached_notify_combine<ranks>, \
        buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \
    break

    const int num_threads = std::max(128, 32 * num_ranks);
    EP_HOST_ASSERT(num_ranks <= num_threads);
    EP_HOST_ASSERT(num_threads <= 1024);
    EP_HOST_ASSERT(1 + num_channels <= num_channels * 2);
    SETUP_LAUNCH_CONFIG(1 + num_channels, num_threads, stream);
    SWITCH_RANKS(CACHED_NOTIFY_COMBINE);
#undef CACHED_NOTIFY_COMBINE
}

// Combine (reduction) through per-channel ring buffers that live on the receiver side.
//
// Launch layout: an even number of blocks with `kNumThreads` threads each. Blocks `2c` and `2c + 1`
// serve channel `c`; even blocks send, odd blocks receive.
//
// Sender blocks: threads are split into `kNumRanks` groups; the group for rank `r` pushes its slice of
// `x` (located via `rank_prefix_matrix` / `channel_prefix_matrix`) together with `src_idx` and
// `topk_weights` into rank `r`'s queue, in chunks of at most `num_max_send_tokens`.
//
// Receiver blocks: warp 0 polls the queue tails into shared memory and publishes the minimum head over
// all non-retired reducer warps. Every other warp reduces whole tokens: for each token it waits until
// the slots named by `send_head` are available, sums the contributions in `float`, and writes the
// result to `recv_x` / `recv_topk_weights`.
template <typename dtype_t, int kNumRanks, int kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
combine(dtype_t* recv_x, float* recv_topk_weights,
        const dtype_t* x, const float* topk_weights,
        const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
        int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
        void **buffer_ptrs, int rank,
        int num_max_send_tokens, int num_recv_buffer_tokens) {
    const auto num_sms = static_cast<int>(gridDim.x);
    const auto thread_id = static_cast<int>(threadIdx.x);
    const auto sm_id = static_cast<int>(blockIdx.x);
    const auto num_channels = num_sms / 2;
    const bool is_sender = sm_id % 2 == 0;
    const int responsible_channel = sm_id / 2;
    EP_DEVICE_ASSERT(num_topk <= 32);

    constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
    int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4);
    auto x_int4 = reinterpret_cast<const int4*>(x);
    auto recv_int4 = reinterpret_cast<int4*>(recv_x);

    if (is_sender) {
        // Workers for sending
        // Several warps are responsible for a single rank
        constexpr int num_send_warps = kNumThreads / 32;
        constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks;
        const auto num_threads_per_rank = num_send_warps_per_rank * 32;
        const auto send_thread_id = thread_id;
        const auto send_lane_id = send_thread_id % 32;
        const auto send_rank_id = thread_id / num_threads_per_rank;
        const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32;

        // Calculate pointers by the specific layout
        auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
        auto num_channels_total = num_channels * kNumRanks;
        auto channel_rank_offset = responsible_channel * kNumRanks + rank;

        // Channel meta data
        // `head_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
        // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
        // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
        // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
        // NOTES: the declaration order below defines the memory layout, as every `Buffer` is carved from the same `ptr`
        auto channel_head_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_tail_idx = Buffer<int>(ptr, num_channels_total, channel_rank_offset);
        auto channel_x_buffers = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);
        auto channel_src_idx_buffers = Buffer<int>(ptr, num_channels_total * num_recv_buffer_tokens, channel_rank_offset * num_recv_buffer_tokens);
        auto channel_topk_weights_buffers = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);

        // Get tasks
        // NOTES: `channel_offset` is already shifted
        int rank_offset = send_rank_id > 0 ? rank_prefix_matrix[(send_rank_id - 1) * kNumRanks + rank] : 0;
        int num_rank_tokens = rank_prefix_matrix[send_rank_id * kNumRanks + rank] - rank_offset;
        int channel_offset = channel_prefix_matrix[send_rank_id * num_channels + responsible_channel];
        int num_channel_tokens = (responsible_channel == num_channels - 1 ? num_rank_tokens : channel_prefix_matrix[send_rank_id * num_channels + responsible_channel + 1]) - channel_offset;
        int token_start_idx = rank_offset + channel_offset, token_end_idx = rank_offset + channel_offset + num_channel_tokens;

        // Iterate over all tokens and send by chunks
        int current_channel_tail_idx = 0;
        for (int64_t token_idx = token_start_idx; token_idx < token_end_idx; ) {
            // Check destination queue emptiness, or wait a buffer to be released (rare cases)
            auto start_time = clock64();
            int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast<int>(token_idx));
            while (send_lane_id == 0) {
                // NOTES: we only consider the worst case, because counting the real numbers are time-consuming
                int num_used_slots = current_channel_tail_idx - ld_volatile_global(channel_head_idx.buffer());
                if (num_recv_buffer_tokens - num_used_slots >= num_round_tokens)
                    break;

                // Rare cases to loop again
                if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                    printf("DeepEP timeout for combine senders, rank %d, responsible_channel = %d\n", rank, responsible_channel);
                    trap();
                }
            }
            __syncwarp();

            // Send by chunk
            #pragma unroll
            for (int i = send_warp_id_in_rank; i < num_round_tokens; i += num_send_warps_per_rank) {
                // Get an empty slot
                int dst_slot_idx = (current_channel_tail_idx + i) % num_recv_buffer_tokens;

                // Copy data
                auto shifted_x_buffers = channel_x_buffers.buffer() + dst_slot_idx * hidden_int4;
                auto shifted_x = x_int4 + (token_idx + i) * hidden_int4;
                UNROLLED_WARP_COPY(4, send_lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global);

                // Send source index
                if (send_lane_id == 0)
                    channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i);

                // Send `topk_weights`
                if (num_topk > 0 and send_lane_id < num_topk)
                    channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id);
            }
            token_idx += num_round_tokens;
            current_channel_tail_idx += num_round_tokens;

            // Move tail index
            asm volatile("bar.sync %0, %1;" :: "r"(send_rank_id), "r"(num_threads_per_rank));
            if (send_lane_id == 0 and send_warp_id_in_rank == 0)
                st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx);
        }
    } else {
        // Workers for receiving
        // One warp for moving the queue head, others for reduction
        constexpr int num_recv_warps = kNumThreads / 32;
        const auto recv_warp_id = thread_id / 32;
        const auto recv_lane_id = thread_id % 32;
        EP_DEVICE_ASSERT(kNumRanks <= 32 and kNumThreads > 32);
        EP_DEVICE_ASSERT(thread_id >= 0 and kNumThreads % 32 == 0);

        // Shared head, tail and retired flags for receiver warps
        __shared__ volatile int warp_channel_head_idx[num_recv_warps][kNumRanks];
        __shared__ volatile int channel_tail_idx[kNumRanks];
        __shared__ volatile bool warp_retired[num_recv_warps];
        if (thread_id < num_recv_warps)
            warp_retired[thread_id] = false;
        if (recv_lane_id < kNumRanks)
            warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0;
        if (thread_id < kNumRanks)
            channel_tail_idx[thread_id] = 0;
        asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));

        if (thread_id < 32) {
            // Lane `l < kNumRanks` tracks the queue filled by rank `l`
            int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + recv_lane_id;
            int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;

            // Queue head updater
            int last_head = 0;
            while (recv_lane_id < kNumRanks) {
                // Check retired
                bool retired = true;
                #pragma unroll
                for (int i = 1; i < num_recv_warps; ++ i)
                    retired = retired and warp_retired[i];
                if (retired)
                    break;

                // Update queue tail
                channel_tail_idx[recv_lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr);

                // Update minimum head
                int min_head = std::numeric_limits<int>::max();
                #pragma unroll
                for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
                    min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
                if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
                    st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
            }
        } else {
            // Receivers
            // Channel metadata
            // All lanes will use data buffer, but only rank lane will use `head/tail/src_idx`
            Buffer<int4> channel_x_buffers[kNumRanks];
            Buffer<float> channel_topk_weights_buffers[kNumRanks];

            // Calculate pointers by the specific layout
            #pragma unroll
            for (int i = 0; i < kNumRanks; ++ i) {
                auto channel_rank_offset = responsible_channel * kNumRanks + i;
                auto num_channels_total = num_channels * kNumRanks;
                // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
                auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));

                // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
                channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);

                // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
                // Skipped, as the receiver does not read the source indices
                ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));

                // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
                channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
            }

            // The same tokens as the dispatch process
            int token_start_idx, token_end_idx;
            get_channel_task_range(num_recv_tokens, num_channels, responsible_channel, token_start_idx, token_end_idx);

            // Iterate over all tokens and combine
            // NOTES: warp 0 is the head updater, so reducer warps are numbered from 1
            for (int64_t token_idx = token_start_idx + recv_warp_id - 1; token_idx < token_end_idx; token_idx += num_recv_warps - 1) {
                // Read expected head
                int expected_head = -1;
                if (recv_lane_id < kNumRanks)
                    expected_head = ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id);

                auto start_time = clock64();
                // NOTES: `expected_head >= 0` must be tested first, it is always false for lanes `>= kNumRanks`,
                // which would otherwise index `channel_tail_idx` out of bounds
                while (expected_head >= 0 and channel_tail_idx[recv_lane_id] <= expected_head) {
                    // Timeout check
                    if (clock64() - start_time > NUM_TIMEOUT_CYCLES) {
                        printf("DeepEP timeout for combine receivers, rank %d, responsible_channel = %d, expect = %d\n", rank, responsible_channel, expected_head);
                        trap();
                    }
                }
                __syncwarp();

                // Broadcast current heads
                int num_topk_ranks = 0, topk_ranks[kNumRanks], slot_indices[kNumRanks];
                #pragma unroll
                for (int i = 0; i < kNumRanks; ++ i) {
                    auto expected_head_i = __shfl_sync(0xffffffff, expected_head, i);
                    if (expected_head_i >= 0) {
                        slot_indices[num_topk_ranks] = expected_head_i % num_recv_buffer_tokens;
                        topk_ranks[num_topk_ranks ++] = i;
                    }
                }

                // Reduce data
                #pragma unroll
                for (int i = recv_lane_id; i < hidden_int4; i += 32) {
                    // Read buffers
                    int4 recv_value_int4[kNumRanks];
                    #pragma unroll
                    for (int j = 0; j < num_topk_ranks; ++ j)
                        recv_value_int4[j] = ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i);

                    // Reduce all-to-all results
                    float values[kDtypePerInt4] = {0};
                    #pragma unroll
                    for (int j = 0; j < num_topk_ranks; ++ j) {
                        auto recv_value_dtypes = reinterpret_cast<const dtype_t*>(&recv_value_int4[j]);
                        #pragma unroll
                        for (int k = 0; k < kDtypePerInt4; ++ k)
                            values[k] += static_cast<float>(recv_value_dtypes[k]);
                    }

                    // Cast back to `dtype_t` and write
                    int4 out_int4;
                    auto out_dtypes = reinterpret_cast<dtype_t*>(&out_int4);
                    #pragma unroll
                    for (int j = 0; j < kDtypePerInt4; ++ j)
                        out_dtypes[j] = static_cast<dtype_t>(values[j]);
                    recv_int4[token_idx * hidden_int4 + i] = out_int4;
                }

                // Reduce `topk_weights`
                if (recv_lane_id < num_topk) {
                    float value = 0;
                    #pragma unroll
                    for (int i = 0; i < num_topk_ranks; ++ i)
                        value += ld_nc_global(channel_topk_weights_buffers[topk_ranks[i]].buffer() + slot_indices[i] * num_topk + recv_lane_id);
                    recv_topk_weights[token_idx * num_topk + recv_lane_id] = value;
                }

                // Update head
                // Negative entries carry `-head - 1`, so both forms decode to the next head this warp may release
                if (recv_lane_id < kNumRanks)
                    warp_channel_head_idx[recv_warp_id][recv_lane_id] = (expected_head < 0) ? -expected_head - 1 : expected_head + 1;
            }

            // Retired
            __syncwarp();
            if (recv_lane_id == 0)
                warp_retired[recv_warp_id] = true;
        }
    }
}

// Host launcher for the `combine` kernel: `num_sms` blocks (must be even) of 768 threads on `stream`.
// The element type is selected at runtime from `type`, and `recv_x` / `x` are cast accordingly.
void combine(cudaDataType_t type,
             void* recv_x, float* recv_topk_weights,
             const void* x, const float* topk_weights,
             const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix,
             int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk,
             void** buffer_ptrs, int rank, int num_ranks,
             cudaStream_t stream, int num_sms,
             int num_max_send_tokens, int num_recv_buffer_tokens) {
    constexpr int kNumThreads = 768;

#define COMBINE_LAUNCH_CASE(dtype, ranks) \
    LAUNCH_KERNEL(&cfg, (combine<dtype, ranks, kNumThreads>), \
        reinterpret_cast<dtype*>(recv_x), recv_topk_weights, \
        reinterpret_cast<const dtype*>(x), topk_weights, \
        src_idx, rank_prefix_matrix, channel_prefix_matrix, \
        send_head, num_tokens, num_recv_tokens, hidden, num_topk, \
        buffer_ptrs, rank, \
        num_max_send_tokens, num_recv_buffer_tokens); \
    break
#define COMBINE_DTYPE_LAUNCH_CASE(dtype) SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); break

    // Even-numbered blocks for sending, odd-numbered blocks for receiving
    EP_HOST_ASSERT(num_sms % 2 == 0);
    EP_HOST_ASSERT(kNumThreads >= num_ranks * 32);
    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    SWITCH_TYPES(COMBINE_DTYPE_LAUNCH_CASE);
#undef COMBINE_DTYPE_LAUNCH_CASE
#undef COMBINE_LAUNCH_CASE
}

} // namespace intranode

} // namespace deep_ep