Allow passing output tensor in low_latency_combine

Author: Dmytro Dzhulgakov
Date:   2025-03-10 22:19:21 +00:00
Commit: b3b61ef5ef (parent ed7487c15e)

4 changed files with 17 additions and 6 deletions

View File

@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook) {
+                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx,
     stream_wait(launch_stream, compute_stream);
 
     // Allocate output tensor
-    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    torch::Tensor combined_x;
+    if (out.has_value()) {
+        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
+        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
+        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
+        combined_x = out.value();
+    } else {
+        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
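
The three host-side assertions spell out the caller contract for the new parameter: a user-supplied `out` must be a contiguous 2-D tensor of shape `[num_combined_tokens, hidden]` whose dtype matches `x`, or the launch aborts before the kernel runs. A minimal sketch of an allocation that satisfies these checks (the concrete sizes are illustrative; in practice `num_combined_tokens` comes from the dispatch handle and `hidden` from the model config):

```python
import torch

# Illustrative sizes only; real values come from the dispatch handle
# and the model config.
num_combined_tokens, hidden = 128, 7168

# Passes all three EP_HOST_ASSERT checks above: 2-D, contiguous,
# shape [num_combined_tokens, hidden], dtype matching x (bf16 here).
out = torch.empty((num_combined_tokens, hidden), dtype=torch.bfloat16, device='cuda')
assert out.dim() == 2 and out.is_contiguous()
```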

View File

@@ -143,7 +143,7 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook);
+                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
 };
 
 } // namespace deep_ep

View File

@@ -497,7 +497,8 @@ class Buffer:
 
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
+                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
+                            out: torch.Tensor | None = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -520,6 +521,7 @@ class Buffer:
             return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
                 but **without actually receiving the data**. You must call the returned hook to ensure the data's arrival.
                 If you do not set this flag, the kernel will ensure the data's arrival.
+            out: the in-place output tensor; if set, the kernel will write the result into this tensor and return it directly.
 
         Returns:
             combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
@@ -529,6 +531,6 @@ class Buffer:
         src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook)
+                                                                   async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
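
Because the new parameter defaults to `None` (and to `std::nullopt` on the C++ side), existing call sites are untouched; latency-sensitive callers can instead reuse one persistent buffer across decode iterations rather than paying a fresh allocation per combine. A hedged usage sketch, assuming `buffer`, `handle`, and the GEMM/top-k tensors already exist from an earlier `low_latency_dispatch` (loop and variable names are illustrative):

```python
# Persistent output buffer reused on every step; all names other than
# low_latency_combine's own arguments are hypothetical.
persistent_out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
for _ in range(num_decode_steps):
    combined_x, event, hook = buffer.low_latency_combine(
        gemm_out, topk_idx, topk_weights, handle,
        async_finish=False, return_recv_hook=False, out=persistent_out)
    # The kernel wrote in place, so combined_x aliases persistent_out
    # and no new memory was allocated this step.
    assert combined_x.data_ptr() == persistent_out.data_ptr()
```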

View File

@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
         hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
     # Check combine correctness
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
     combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
     hook() if return_recv_hook else event.current_stream_wait()
     if do_check:
         diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
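
The test preallocates `out` and routes it through the new keyword, then relies on the existing `calc_diff` comparison for numerical correctness. One additional assertion (not part of this commit) would also pin down the in-place contract:

```python
# Extra check, not in the commit: the returned tensor should be the
# caller-provided buffer, not a fresh allocation.
assert combined_x.data_ptr() == out.data_ptr()
```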