From b3b61ef5efe4d9a13c455bdf770f4a8e9c91df4a Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Mon, 10 Mar 2025 22:19:21 +0000 Subject: [PATCH] Allow passing output tensor in low_latency_combine --- csrc/deep_ep.cpp | 12 ++++++++++-- csrc/deep_ep.hpp | 2 +- deep_ep/buffer.py | 6 ++++-- tests/test_low_latency.py | 3 ++- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index d43f3e5..404c2b4 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1100,7 +1100,7 @@ std::tuple, std::optional out) { EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id stream_wait(launch_stream, compute_stream); // Allocate output tensor - auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + torch::Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); + EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); + combined_x = out.value(); + } else { + combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 3fff3ae..1a5bb79 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -143,7 +143,7 @@ public: low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, int num_max_dispatch_tokens_per_rank, int num_experts, - bool async, bool return_recv_hook); + bool async, bool return_recv_hook, std::optional out = std::nullopt); }; } // namespace deep_ep diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 97a5b90..2d7a0e6 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -497,7 +497,8 @@ class Buffer: # noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \ + handle: tuple, async_finish: bool = False, return_recv_hook: bool = False, + out: torch.Tensor | None = None) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: """ A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. @@ -520,6 +521,7 @@ class Buffer: return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. If you not set this flag, the kernel will ensure the data's arrival. + out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. Returns: combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`. @@ -529,6 +531,6 @@ class Buffer: src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts, - async_finish, return_recv_hook) + async_finish, return_recv_hook, out) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index a375e25..3ed681b 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) # Check combine correctness + out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) + async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out) hook() if return_recv_hook else event.current_stream_wait() if do_check: diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)