Allow passing output tensor in low_latency_combine

Dmytro Dzhulgakov 2025-03-10 22:19:21 +00:00
parent ed7487c15e
commit b3b61ef5ef
4 changed files with 17 additions and 6 deletions
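The change lets a caller pass a preallocated output buffer so that `low_latency_combine` writes its result in place instead of allocating a fresh `combined_x` on every call, e.g. to reuse one tensor across decode steps. A minimal usage sketch, assuming an initialized `deep_ep.Buffer` named `buffer` and a `handle` from a prior dispatch; the variable names mirror the test change below and are illustrative, not new API:

    import torch

    # Preallocate once; must match what the kernel would otherwise allocate:
    # shape [num_combined_tokens, hidden], dtype torch.bfloat16, contiguous, on CUDA.
    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
    combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights,
                                                         handle, out=out)
    # The kernel writes into `out` and returns it, so both names alias one storage.
    assert combined_x.data_ptr() == out.data_ptr()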


@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
-                            bool async, bool return_recv_hook) {
+                            bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
     EP_HOST_ASSERT(low_latency_mode);
 
     // Tensor checks
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     stream_wait(launch_stream, compute_stream);
 
     // Allocate output tensor
-    auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    torch::Tensor combined_x;
+    if (out.has_value()) {
+        EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
+        EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
+        EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
+        combined_x = out.value();
+    } else {
+        combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
+    }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
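The asserts above are the entire contract for `out`: it must be two-dimensional, contiguous, shaped `[num_combined_tokens, hidden]`, and carry the same dtype as `x`; anything else fails `EP_HOST_ASSERT` on the host before the kernel launches. An illustrative Python-side example (names are for exposition only, not from this commit):

    # Accepted: contiguous, correct shape, dtype matching `x`
    ok = torch.empty((num_combined_tokens, hidden), dtype=x.dtype, device='cuda')
    # Rejected: a strided view is not contiguous, so the host assert fires
    bad = torch.empty((num_combined_tokens, hidden * 2), dtype=x.dtype, device='cuda')[:, ::2]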


@@ -143,7 +143,7 @@ public:
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
-                        bool async, bool return_recv_hook);
+                        bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
 };
 
 } // namespace deep_ep


@@ -497,7 +497,8 @@ class Buffer:
 
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
-                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
+                            handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
+                            out: torch.Tensor | None = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -520,6 +521,7 @@ class Buffer:
             return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
                 but **without actually receiving the data**. You must call the returned hook to ensure the data's arrival.
                 If you do not set this flag, the kernel will ensure the data's arrival itself.
+            out: the in-place output tensor; if set, the kernel will write the result into this tensor and return it directly.
 
         Returns:
             combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
@@ -529,6 +531,6 @@
         src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
-                                                                   async_finish, return_recv_hook)
+                                                                   async_finish, return_recv_hook, out)
         tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
         return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook


@@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
             hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
 
     # Check combine correctness
+    out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
     combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
+                                                         async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
     hook() if return_recv_hook else event.current_stream_wait()
     if do_check:
         diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
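The correctness check still reads from `combined_x`, which now aliases the preallocated `out`. A hypothetical extra assertion (not part of this commit) would pin that in-place behavior down explicitly:

    # Hypothetical check: the kernel really wrote into the caller's buffer
    assert combined_x.data_ptr() == out.data_ptr()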