diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index a51eba8..c5aaed3 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1030,6 +1030,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                             const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool async, bool return_recv_hook) {
     EP_HOST_ASSERT(low_latency_mode);
@@ -1042,6 +1043,11 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank);
     EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64);
     EP_HOST_ASSERT(num_experts % num_ranks == 0);
+    if (cumulative_local_expert_recv_stats.has_value()) {
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous());
+        EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
+    }
 
     auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
     auto num_scales = hidden / 128, num_topk = static_cast<int>(topk_idx.size(1));
@@ -1084,6 +1090,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
                            packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                            packed_recv_count.data_ptr<int>(),
+                           cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                            buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                            buffer.dispatch_rdma_send_buffer,
                            x.data_ptr(), topk_idx.data_ptr<int64_t>(),
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index 3d6aab4..f193bcc 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -142,6 +142,7 @@ public:
 
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
+                         const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool async, bool return_recv_hook);
 
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index b08cbab..1ffa061 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -133,6 +133,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
+              int* cumulative_local_expert_recv_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index fd80ad2..1452c8e 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -41,6 +41,7 @@ __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
+         int* cumulative_local_expert_recv_stats,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -273,6 +274,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
             shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
             recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx);
+            if (cumulative_local_expert_recv_stats != nullptr)
+                atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
         }
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
@@ -310,6 +313,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
+              int* cumulative_local_expert_recv_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -338,6 +342,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
               packed_recv_x, packed_recv_x_scales, \
               packed_recv_src_info, packed_recv_layout_range, \
               packed_recv_count, \
+              cumulative_local_expert_recv_stats, \
               rdma_recv_x, rdma_recv_count, rdma_x, \
               x, topk_idx, \
               atomic_counter_per_expert, atomic_finish_counter_per_expert, \
diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py
index a403589..5fadddd 100644
--- a/deep_ep/buffer.py
+++ b/deep_ep/buffer.py
@@ -473,6 +473,7 @@ class Buffer:
     # noinspection PyTypeChecker
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
+                             cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
                              use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
         """
@@ -481,7 +482,7 @@ class Buffer:
             (specifically, IBGDA must be enabled).
         Even for ranks in the same node, NVLink are fully disabled for simplicity.
         Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2
-            low-latency kernels' result tensor at a single moment.
+            low-latency kernels' result tensors at a single moment.
 
         Arguments:
             x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are
@@ -490,6 +491,9 @@ class Buffer:
                 are supported. `-1` indices (not selecting any expert) are supported.
             num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value.
             num_experts: the number of all experts.
+            cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
+                `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
+                monitoring.
             use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
             async_finish: the current stream will not wait for the communication kernels to be finished if set.
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
@@ -508,19 +512,21 @@ class Buffer:
                 Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are,
                     as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced).
             recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each
-                expert receive. As mentioned before, not all tokens are valid in `recv_x`.
+                expert receives. As mentioned before, not all tokens are valid in `recv_x`.
             handle: the communication handle to be used in the `low_latency_combine` function.
             event: the event after executing the kernel (valid only if `async_finish` is set).
             hook: the receiving hook function (valid only if `return_recv_hook` is set).
         """
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx,
+                                              cumulative_local_expert_recv_stats,
                                               num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, async_finish, return_recv_hook)
         handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
         tensors_to_record = (x, topk_idx,
                              packed_recv_x, packed_recv_x_scales, packed_recv_count,
-                             packed_recv_src_info, packed_recv_layout_range)
+                             packed_recv_src_info, packed_recv_layout_range,
+                             cumulative_local_expert_recv_stats)
         return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \
                EventOverlap(event, tensors_to_record if async_finish else None), hook
 
diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py
index 9805263..0719422 100644
--- a/tests/test_low_latency.py
+++ b/tests/test_low_latency.py
@@ -36,8 +36,10 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         for dispatch_use_fp8 in (False, True):
             num_times += 1
             for i in range((num_times % 2) + 1):
+                cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')
                 packed_recv_x, packed_recv_count, handle, event, hook = \
                     buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8,
+                                                cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                                 async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
                 hook() if return_recv_hook else event.current_stream_wait()
                 packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
@@ -53,6 +55,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
                 # Check expert indices
                 int_mask = (2 ** 32) - 1
                 num_valid_tokens = recv_count.item()
+                assert cumulative_local_expert_recv_stats[i].item() == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}'
                 assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()'
                 assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}'
 
@@ -108,6 +111,7 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
     def test_func(zero_copy: bool, return_recv_hook: bool):
         recv_x, recv_count, handle, event, hook = \
             buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts,
+                                        cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                         async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
         if zero_copy:
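Usage note: the sketch below shows how a serving loop might feed the new `cumulative_local_expert_recv_stats` argument for EP load-balance monitoring. It is a minimal sketch, assuming an already-initialized `deep_ep.Buffer` named `buffer`, plus `x`, `topk_idx`, `num_max_dispatch_tokens_per_rank`, `num_experts` and `num_local_experts` prepared as in tests/test_low_latency.py; `num_monitor_steps` and the report/reset policy are illustrative, not part of the API.

import torch

# One int32 counter per local expert; the dispatch kernel atomically adds the number of
# tokens each local expert receives, so the tensor keeps accumulating across calls.
cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda')

for step in range(num_monitor_steps):  # num_monitor_steps: illustrative monitoring window
    recv_x, recv_count, handle, event, hook = \
        buffer.low_latency_dispatch(x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                    cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                    async_finish=False, return_recv_hook=False)
    # ... run local experts on `recv_x`, then `low_latency_combine`, etc. ...

# Report (and optionally reset) the per-expert token counts for load-balance monitoring.
print('tokens received per local expert:', cumulative_local_expert_recv_stats.tolist())
cumulative_local_expert_recv_stats.zero_()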