diff --git a/eplb.py b/eplb.py index 6e254f3..26c3987 100644 --- a/eplb.py +++ b/eplb.py @@ -121,7 +121,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: i pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + - torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, + device=group_pack_index.device).view(1, -1, 1)).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)