diff --git a/eplb.py b/eplb.py index 00c1e50..89a231a 100644 --- a/eplb.py +++ b/eplb.py @@ -53,7 +53,7 @@ def replicate_experts(weight: torch.Tensor, num_phy: int) -> torch.Tensor: Returns: phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the duplica rank + rank: [X, num_phy], the replica rank logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape @@ -77,7 +77,10 @@ def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: i """ Parameters: weight: [num_moe_layers, num_logical_experts] - group_size: number of logical experts per group, used in group-limited routing + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` Returns: physical_to_logical_map: [num_moe_layers, num_physical_experts]