Fix missing device for pphy2mlog tensor

This commit is contained in:
ramichen 2025-03-24 17:02:35 +08:00 committed by Shaoyuan CHEN
parent e1100fefe7
commit d52c72d5b2

View File

@@ -121,7 +121,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: i
pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) +
- torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2)
+ torch.arange(0, num_logical_experts, num_logical_experts // num_nodes,
+ device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)