From d52c72d5b2f2fb4c41afbf8eb21366820239913d Mon Sep 17 00:00:00 2001 From: ramichen Date: Mon, 24 Mar 2025 17:02:35 +0800 Subject: [PATCH] Fix missing device for pphy2mlog tensor --- eplb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eplb.py b/eplb.py index 6e254f3..26c3987 100644 --- a/eplb.py +++ b/eplb.py @@ -121,7 +121,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: i pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + - torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2) + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, + device=group_pack_index.device).view(1, -1, 1)).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)