diff --git a/eplb.py b/eplb.py
index d8cdbf0..6e254f3 100644
--- a/eplb.py
+++ b/eplb.py
@@ -15,7 +15,6 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor
         pack_index: [X, n], the pack index of each item
         rank_in_pack: [X, n], the rank of the item in the pack
     """
-    num_layers, num_groups = weight.shape
     assert num_groups % num_packs == 0
     groups_per_pack = num_groups // num_packs
 
@@ -153,7 +152,7 @@ def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
                                                       num_groups, num_nodes, num_gpus)
     else:
         # use global load-balance policy
-        phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
+        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
     maxlogcnt = logcnt.max().item()
     log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
                                        -1, dtype=torch.int64, device=logcnt.device)