diff --git a/README.md b/README.md
index d532379..fafe9d9 100644
--- a/README.md
+++ b/README.md
@@ -162,6 +162,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                                     allocate_on_comm_stream=previous_event is not None)
     # Do MoE dispatch
     # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph
+    # Unless you specify `num_worst_tokens`, but this flag is for intranode only
     # For more advanced usages, please refer to the docs of the `dispatch` function
     recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
         _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
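As a side note to the README change above, here is a minimal, hypothetical usage sketch of the new flag. It assumes a `Buffer` instance `_buffer` and the same inputs as the README's `dispatch_forward` example; the worst-case bound and the padding mask are illustrative and not part of the patch itself:

```python
# Worst-case bound (also used by the new test below): every rank may send every token here.
num_worst_tokens = num_tokens * num_ranks

# With `num_worst_tokens` set, outputs are allocated for the worst case and the
# dispatch skips the CPU-side wait for the GPU counters.
recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = \
    _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                     num_tokens_per_rank=num_tokens_per_rank,
                     is_token_in_rank=is_token_in_rank,
                     num_tokens_per_expert=num_tokens_per_expert,
                     num_worst_tokens=num_worst_tokens)

# Rows past the real number of received tokens are padding: the kernel fills their
# `recv_topk_idx` entries with -1, so the valid rows can be recovered with a mask.
valid_mask = recv_topk_idx.ne(-1).any(dim=-1)
```

Note that `num_recv_tokens_per_expert_list` (the ignored fourth return value) is not populated on this path, since filling it would require exactly the CPU synchronization the flag avoids.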
diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 7ea5fcb..66c1964 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -284,7 +284,8 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
                            const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                            const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
                            int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
-                           int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
+                           int expert_alignment, int num_worst_tokens, const Config& config,
+                           std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream) {
     bool cached_mode = cached_rank_prefix_matrix.has_value();

     // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving.
@@ -412,25 +413,34 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
-        // Synchronize total received tokens and tokens per expert
-        auto start_time = std::chrono::high_resolution_clock::now();
-        while (true) {
-            // Read total count
-            num_recv_tokens = static_cast<int>(*moe_recv_counter);
+        if (num_worst_tokens > 0) {
+            // No CPU sync, just allocate the worst case
+            num_recv_tokens = num_worst_tokens;

-            // Read per-expert count
-            bool ready = (num_recv_tokens >= 0);
-            for (int i = 0; i < num_local_experts and ready; ++i)
-                ready &= moe_recv_expert_counter[i] >= 0;
+            // Must be forward with top-k stuffs
+            EP_HOST_ASSERT(topk_idx.has_value());
+            EP_HOST_ASSERT(topk_weights.has_value());
+        } else {
+            // Synchronize total received tokens and tokens per expert
+            auto start_time = std::chrono::high_resolution_clock::now();
+            while (true) {
+                // Read total count
+                num_recv_tokens = static_cast<int>(*moe_recv_counter);

-            if (ready)
-                break;
+                // Read per-expert count
+                bool ready = (num_recv_tokens >= 0);
+                for (int i = 0; i < num_local_experts and ready; ++i)
+                    ready &= moe_recv_expert_counter[i] >= 0;

-            // Timeout check
-            if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
-                throw std::runtime_error("DeepEP error: CPU recv timeout");
+                if (ready)
+                    break;
+
+                // Timeout check
+                if (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS)
+                    throw std::runtime_error("DeepEP error: CPU recv timeout");
+            }
+            num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
         }
-        num_recv_tokens_per_expert_list = std::vector<int>(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
     }

     // Allocate new tensors
@@ -472,7 +482,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional<torch::Tensor>& x_scales,
     intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr<int>(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr<int>(),
                         send_head.data_ptr<int>(), x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr,
                         is_token_in_rank.data_ptr<bool>(), channel_prefix_matrix.data_ptr<int>(),
-                        num_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
+                        num_tokens, num_worst_tokens, static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales,
                         buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms,
                         config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens);
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 85e723c..7b2b0b3 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -108,7 +108,8 @@ public:
                        const std::optional<torch::Tensor>& topk_idx, const std::optional<torch::Tensor>& topk_weights,
                        const std::optional<torch::Tensor>& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional<torch::Tensor>& num_tokens_per_expert,
                        int cached_num_recv_tokens, const std::optional<torch::Tensor>& cached_rank_prefix_matrix, const std::optional<torch::Tensor>& cached_channel_prefix_matrix,
-                       int expert_alignment, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
+                       int expert_alignment, int num_worst_tokens, const Config& config,
+                       std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);

     std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index df7aece..d10044e 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -45,7 +45,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
 void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, int* send_head,
               const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
-              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream,
              int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens);
diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu
index cd545a0..6f7c701 100644
--- a/csrc/kernels/intranode.cu
+++ b/csrc/kernels/intranode.cu
@@ -25,7 +25,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
     int *per_rank_buffer, *per_expert_buffer;
     if (thread_id < kNumRanks) {
-        per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[thread_id]);
+        per_rank_buffer = static_cast<int*>(buffer_ptrs[thread_id]);
         per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks;
     }
@@ -48,7 +48,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
     // Sum per-rank counts and return to CPU
     // Also pre-compute the prefix sum for data sending
-    auto local_per_rank_buffer = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    auto local_per_rank_buffer = static_cast<int*>(buffer_ptrs[rank]);
     if (thread_id < kNumRanks) {
 #pragma unroll
         for (int i = 1; i < kNumRanks; ++ i)
@@ -141,7 +141,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int,
     // Copy and clean
     auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
-    auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+    auto ptr = static_cast<int*>(buffer_ptrs[rank]);
 #pragma unroll
     for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads)
         ptr[i] = rank_prefix_matrix[i];
@@ -173,7 +173,7 @@ __global__ void __launch_bounds__(kNumThreads, 1)
 dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, int* send_head,
          const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
          const bool* is_token_in_rank, const int* channel_prefix_matrix,
-         int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+         int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
          void** buffer_ptrs, int rank,
          int num_max_send_tokens, int num_recv_buffer_tokens) {
     const auto num_sms = static_cast<int>(gridDim.x), sm_id = static_cast<int>(blockIdx.x);
@@ -196,7 +196,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
     // Calculate pointers by the specific layout
     // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int)
-    auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
+    auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int));
     int target_rank = is_sender ? rank : responsible_rank;
     auto num_channels_total = num_channels * kNumRanks;
     auto channel_rank_offset = responsible_channel * kNumRanks + target_rank;
@@ -286,7 +286,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
             int chunk_token_idx = 0;
             while (chunk_token_idx < num_max_send_tokens and token_idx < token_end_idx) {
-                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send subsequent data
+                // NOTES: for the same token, the warp assigned to save `send_head` may be different from the warp assigned to send the following data
                 if (lane_id == 0 and token_idx % num_send_warps_per_rank == send_warp_id_in_rank)
                     send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] ? cached_channel_tail_idx : -1;
@@ -349,7 +349,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
         EP_DEVICE_ASSERT(recv_thread_id >= 0 and num_recv_warps % kNumRanks == 0);

         // Calculate offset first
-        auto rank_prefix_matrix = reinterpret_cast<int*>(buffer_ptrs[rank]);
+        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
         int rank_offset = responsible_rank > 0 ? rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] : 0;

         // Receive channel offset
@@ -372,7 +372,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
         auto start_time = clock64();
         int cached_channel_head_idx = 0, cached_channel_tail_idx = 0;
         while (num_tokens_to_recv > 0) {
-            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are same
+            // NOTES: unlike the sender, the receiver must ensure that the tail indices hold by different warps are the same
             while (recv_thread_id_in_rank == 0) {
                 cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer());
@@ -450,12 +450,25 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
             if (lane_id == 0)
                 tma_store_wait();
         }
+
+
+    // Clean unused `recv_topk_idx` as -1
+    if (num_worst_tokens > 0) {
+        auto rank_prefix_matrix = static_cast<int*>(buffer_ptrs[rank]);
+        const auto num_recv_tokens = rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank];
+        const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads;
+        const auto clean_end = num_worst_tokens * num_topk;
+        const auto clean_stride = num_sms * kNumThreads;
+        #pragma unroll
+        for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+            recv_topk_idx[i] = -1;
+    }
 }

 void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, int* send_head,
               const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights,
               const bool* is_token_in_rank, const int* channel_prefix_matrix,
-              int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
+              int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales,
              void** buffer_ptrs, int rank, int num_ranks, cudaStream_t stream,
              int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) {
     constexpr int kNumThreads = 768;
@@ -470,7 +483,7 @@ void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx,
         reinterpret_cast<int4*>(recv_x), recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, recv_channel_offset, \
         send_head, reinterpret_cast<const int4*>(x), x_scales, topk_idx, topk_weights, \
         is_token_in_rank, channel_prefix_matrix, \
-        num_tokens, hidden_int4, num_topk, num_experts, num_scales, \
+        num_tokens, num_worst_tokens, hidden_int4, num_topk, num_experts, num_scales, \
         buffer_ptrs, rank, \
         num_max_send_tokens, num_recv_buffer_tokens); \
     } break
@@ -493,7 +506,7 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int
         // Clean
         auto thread_id = static_cast<int>(threadIdx.x), num_threads = static_cast<int>(blockDim.x);
-        auto ptr = reinterpret_cast<int*>(buffer_ptrs[rank]);
+        auto ptr = static_cast<int*>(buffer_ptrs[rank]);
 #pragma unroll
         for (int i = thread_id; i < num_memset_int; i += num_threads)
             ptr[i] = 0;
@@ -590,7 +603,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count");

         // Calculate pointers by the specific layout
-        auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[send_rank_id]));
+        auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[send_rank_id]));
         auto num_channels_total = num_channels * kNumRanks;
         auto channel_rank_offset = responsible_channel * kNumRanks + rank;
@@ -682,7 +695,7 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
         asm volatile("bar.sync 0, %0;" :: "r"(kNumThreads));

         if (thread_id < 32) {
-            int* channel_head_idx_ptr = reinterpret_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
+            int* channel_head_idx_ptr = static_cast<int*>(buffer_ptrs[rank]) + responsible_channel * kNumRanks + lane_id;
             int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks;

             // Queue head updater
@@ -720,13 +733,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights,
             auto channel_rank_offset = responsible_channel * kNumRanks + i;
             auto num_channels_total = num_channels * kNumRanks;
             // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int)
-            auto ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));
+            auto ptr = reinterpret_cast<void*>(static_cast<int8_t*>(buffer_ptrs[rank]) + 2 * num_channels * kNumRanks * sizeof(int));

             // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * hidden_int4 * sizeof(int4)
             channel_x_buffers[i] = Buffer<int4>(ptr, num_channels_total * num_recv_buffer_tokens * hidden_int4, channel_rank_offset * num_recv_buffer_tokens * hidden_int4);

             // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * sizeof(int)
-            ptr = reinterpret_cast<void*>(reinterpret_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));
+            ptr = reinterpret_cast<void*>(static_cast<int8_t*>(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int));

             // `topk_weights_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * num_topk * sizeof(float)
             channel_topk_weights_buffers[i] = Buffer<float>(ptr, num_channels_total * num_recv_buffer_tokens * num_topk, channel_rank_offset * num_recv_buffer_tokens * num_topk);
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index 5fadddd..439d17d 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -249,7 +249,8 @@ class Buffer:
                  handle: Optional[Tuple] = None,
                  num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None,
                  is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None,
-                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1,
+                 topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None,
+                 expert_alignment: int = 1, num_worst_tokens: int = 0,
                  config: Optional[Config] = None,
                  previous_event: Optional[EventOverlap] = None, async_finish: bool = False,
                  allocate_on_comm_stream: bool = False) -> \
@@ -276,6 +277,8 @@ class Buffer:
                 `-1` means no selections.
             topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch.
             expert_alignment: align the number of tokens received by each local expert to this variable.
+            num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it
+                will be CUDA-graph compatible. Please also notice that this flag is for intranode only.
             config: the performance tuning config.
             previous_event: the event to wait before actually executing the kernel.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
@@ -296,6 +299,7 @@ class Buffer:
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
+            assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`'
             return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                                            topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
@@ -308,14 +312,16 @@ class Buffer:
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None,
                 None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
-                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                expert_alignment, num_worst_tokens, config,
+                getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
                 self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
-                                                num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
-                                                expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
+                                                num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
+                                                expert_alignment, num_worst_tokens, config,
+                                                getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
             return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event)
diff --git a/tests/test_intranode.py b/tests/test_intranode.py
index c069c6d..c59dc46 100644
--- a/tests/test_intranode.py
+++ b/tests/test_intranode.py
@@ -100,6 +100,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
         assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list
         if current_x is not x_pure_rand:
             check_data(recv_x, rank_prefix_matrix)
+        recv_topk_weights_clone = None
         if with_topk:
             # Check `topk_idx`
             assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel()
@@ -107,10 +108,26 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
                 assert recv_topk_idx.eq(i).sum().item() == count

             # Check `topk_weights`
+            recv_topk_weights_clone = recv_topk_weights.clone()
             if current_x is not x_pure_rand:
                 recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)]
             check_data(recv_topk_weights, rank_prefix_matrix)

+        # Test `num_worst_tokens != 0`
+        if with_topk:
+            num_worst_tokens = num_tokens * num_ranks
+            dispatch_args.update({'num_worst_tokens': num_worst_tokens})
+            recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, _, _, event = buffer.dispatch(**dispatch_args)
+            event.current_stream_wait() if async_mode else ()
+            recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x
+            assert num_worst_tokens == recv_worst_x.size(0)
+            assert num_worst_tokens == recv_worst_topk_idx.size(0)
+            assert num_worst_tokens == recv_worst_topk_weights.size(0)
+            assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)])
+            assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)])
+            assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)])
+            assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item()
+
         # Test cached dispatch (must without top-k staffs)
         if not with_topk:
             dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode}
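Because the `num_worst_tokens` path removes the CPU-side wait, an intranode dispatch issued this way should in principle be capturable into a CUDA graph, which is the motivation stated in the docstring. The sketch below is a rough, untested illustration using PyTorch's standard capture recipe; it assumes fixed shapes, input tensors that are reused and refilled in place before each replay, and the same `_buffer` and inputs as in the sketch after the README hunk:

```python
import torch

num_worst_tokens = num_tokens * num_ranks  # fixed worst-case output size


def run_dispatch():
    # Outputs are padded to `num_worst_tokens` rows, so their shapes are static.
    return _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                            num_tokens_per_rank=num_tokens_per_rank,
                            is_token_in_rank=is_token_in_rank,
                            num_tokens_per_expert=num_tokens_per_expert,
                            num_worst_tokens=num_worst_tokens)


# Warm up on a side stream, then capture (the usual CUDA-graph recipe in PyTorch).
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    run_dispatch()
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    recv_x, recv_topk_idx, recv_topk_weights, _, handle, event = run_dispatch()

# Later: copy fresh data into `x`, `topk_idx`, etc. in place, then replay.
graph.replay()
```

The trade-off is memory: the output tensors are always sized for `num_worst_tokens` rows, whether or not that many tokens actually arrive.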