mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Allow passing output tensor in low_latency_combine
This commit is contained in:
parent
ed7487c15e
commit
b3b61ef5ef
@ -1100,7 +1100,7 @@ std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::functio
|
||||
Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook) {
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out) {
|
||||
EP_HOST_ASSERT(low_latency_mode);
|
||||
|
||||
// Tensor checks
|
||||
@ -1138,7 +1138,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
|
||||
stream_wait(launch_stream, compute_stream);
|
||||
|
||||
// Allocate output tensor
|
||||
auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
torch::Tensor combined_x;
|
||||
if (out.has_value()) {
|
||||
EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous());
|
||||
EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden);
|
||||
EP_HOST_ASSERT(out->scalar_type() == x.scalar_type());
|
||||
combined_x = out.value();
|
||||
} else {
|
||||
combined_x = torch::empty({num_combined_tokens, hidden}, x.options());
|
||||
}
|
||||
|
||||
// Kernel launch
|
||||
auto next_clean_meta = next_buffer.clean_meta();
|
||||
|
@ -143,7 +143,7 @@ public:
|
||||
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& src_info, const torch::Tensor& layout_range,
|
||||
int num_max_dispatch_tokens_per_rank, int num_experts,
|
||||
bool async, bool return_recv_hook);
|
||||
bool async, bool return_recv_hook, std::optional<torch::Tensor> out = std::nullopt);
|
||||
};
|
||||
|
||||
} // namespace deep_ep
|
||||
|
@ -497,7 +497,8 @@ class Buffer:
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
|
||||
out: torch.Tensor | None = None) -> \
|
||||
Tuple[torch.Tensor, EventOverlap, Callable]:
|
||||
"""
|
||||
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
|
||||
@ -520,6 +521,7 @@ class Buffer:
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
|
||||
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
|
||||
If you not set this flag, the kernel will ensure the data's arrival.
|
||||
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
|
||||
|
||||
Returns:
|
||||
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
|
||||
@ -529,6 +531,6 @@ class Buffer:
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
async_finish, return_recv_hook)
|
||||
async_finish, return_recv_hook, out)
|
||||
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
|
||||
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
|
||||
|
@ -73,8 +73,9 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
|
||||
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
|
||||
async_finish=not return_recv_hook, return_recv_hook=return_recv_hook, out=out)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x)
|
||||
|
Loading…
Reference in New Issue
Block a user