Allow passing output tensor in low_latency_combine

This commit is contained in:
Dmytro Dzhulgakov
2025-03-10 22:19:21 +00:00
parent ed7487c15e
commit b3b61ef5ef
4 changed files with 17 additions and 6 deletions

View File

@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool async, bool return_recv_hook) {
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
EP_HOST_ASSERT(low_latency_mode);
// Tensor checks
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
stream_wait(launch_stream, compute_stream);
// Allocate output tensor
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
torch::Tensor combined_x;
if (out.has_value()) {
EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
combined_x = out.value();
} else {
combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
}
// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();