This commit is contained in:
Dmytro Dzhulgakov 2025-03-13 00:42:08 +00:00
parent b3b61ef5ef
commit 50ac280ae7

View File

@@ -498,7 +498,7 @@ class Buffer:
# noinspection PyTypeChecker
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, async_finish: bool = False, return_recv_hook: bool = False,
-out: torch.Tensor | None = None) -> \
+out: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.