mirror of
https://github.com/deepseek-ai/DeepEP
synced 2025-06-26 18:28:11 +00:00
Allow passing output tensor in low_latency_combine
This commit is contained in:
@@ -497,7 +497,8 @@ class Buffer:
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False) -> \
|
||||
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
|
||||
out: torch.Tensor | None = None) -> \
|
||||
Tuple[torch.Tensor, EventOverlap, Callable]:
|
||||
"""
|
||||
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
|
||||
@@ -520,6 +521,7 @@ class Buffer:
|
||||
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues,
|
||||
but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
|
||||
If you not set this flag, the kernel will ensure the data's arrival.
|
||||
out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
|
||||
|
||||
Returns:
|
||||
combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`.
|
||||
@@ -529,6 +531,6 @@ class Buffer:
|
||||
src_info, layout_range, num_max_dispatch_tokens_per_rank, num_experts = handle
|
||||
combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
|
||||
num_max_dispatch_tokens_per_rank, num_experts,
|
||||
async_finish, return_recv_hook)
|
||||
async_finish, return_recv_hook, out)
|
||||
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
|
||||
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
|
||||
|
||||
Reference in New Issue
Block a user