mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Allow passing output tensor in low_latency_combine
This commit is contained in:
@@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook) {
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
stream_wait(launch_stream, compute_stream);
|
||||
|
||||
// Allocate output tensor
|
||||
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
torch::Tensor combined_x;
|
||||
if (out.has_value()) {
|
||||
EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
|
||||
EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
|
||||
EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
|
||||
combined_x = out.value();
|
||||
} else {
|
||||
combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
}
|
||||
|
||||
// Kernel launch
|
||||
auto next_clean_meta = next_buffer.clean_meta();
|
||||
|
||||
Reference in New Issue
Block a user