From 48c86cb16b6c1789553dc03a92805906073e0eec Mon Sep 17 00:00:00 2001 From: A-transformer Date: Thu, 27 Feb 2025 11:33:14 +0400 Subject: [PATCH] function return match replicate_experts correctly states it returns three tensors --- eplb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eplb.py b/eplb.py index 00c1e50..3957da7 100644 --- a/eplb.py +++ b/eplb.py @@ -42,7 +42,7 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor return pack_index, rank_in_pack -def replicate_experts(weight: torch.Tensor, num_phy: int) -> torch.Tensor: +def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.