function return match

replicate_experts correctly states it returns three tensors
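
For context, the corrected annotation matters to any caller that unpacks the result into three values. The sketch below is a pure-Python, single-row stand-in and not the repository's implementation: the name `replicate_experts_sketch`, the result names `phy2log`, `rank` and `logcnt`, and the greedy strategy are all assumptions made for illustration.

```python
from typing import List, Tuple


def replicate_experts_sketch(
    weight: List[float], num_phy: int
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical stand-in: returns three sequences, mirroring the 3-tuple annotation."""
    num_log = len(weight)
    assert num_phy >= num_log, "need at least one replica per logical expert"

    # Start with exactly one physical replica per logical expert.
    phy2log = list(range(num_log))  # logical expert served by each replica
    rank = [0] * num_log            # replica index within its logical expert
    logcnt = [1] * num_log          # replica count per logical expert

    # Assumed greedy step: give each extra replica to the expert whose
    # per-replica load (weight / replica count) is currently largest.
    for _ in range(num_phy - num_log):
        busiest = max(range(num_log), key=lambda i: weight[i] / logcnt[i])
        phy2log.append(busiest)
        rank.append(logcnt[busiest])
        logcnt[busiest] += 1

    return phy2log, rank, logcnt


# Unpacking into three names is exactly what the old `-> torch.Tensor` hint hid.
phy2log, rank, logcnt = replicate_experts_sketch([90.0, 30.0, 10.0], num_phy=5)
```

With the old single-tensor annotation, a static type checker would likely flag the three-way unpacking above even though it works at runtime.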
A-transformer 2025-02-27 11:33:14 +04:00 committed by GitHub
parent f9bc62e841
commit 48c86cb16b
GPG Key ID: B5690EEEBB952194

@@ -42,7 +42,7 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor
     return pack_index, rank_in_pack
-def replicate_experts(weight: torch.Tensor, num_phy: int) -> torch.Tensor:
+def replicate_experts(weight: torch.Tensor, num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.