add gpu-level load balance for global policy

close #14
Shaoyuan CHEN 2025-03-21 17:11:10 +08:00
parent 3079d71e1f
commit e1100fefe7


@@ -15,7 +15,6 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor
         pack_index: [X, n], the pack index of each item
         rank_in_pack: [X, n], the rank of the item in the pack
     """
     num_layers, num_groups = weight.shape
     assert num_groups % num_packs == 0
     groups_per_pack = num_groups // num_packs
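For readers of the docstring fragment above, a minimal, hedged usage sketch of balanced_packing: it assumes the patched file is importable as a module named eplb (the module name is not shown on this page) and uses invented item weights.

```python
import torch
from eplb import balanced_packing  # module name is an assumption

# [X=1, n=6] item weights (made up); pack 6 items into 3 packs of 2 items each
weight = torch.tensor([[10., 30., 20., 40., 50., 25.]])
pack_index, rank_in_pack = balanced_packing(weight, num_packs=3)

# pack_index[0, i]   -> which of the 3 packs item i was placed in
# rank_in_pack[0, i] -> the item's position within its pack (0 or 1 here)
print(pack_index)
print(rank_in_pack)
```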
@@ -153,7 +152,7 @@ def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
                                                                   num_groups, num_nodes, num_gpus)
     else:
         # use global load-balance policy
-        phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
+        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
     maxlogcnt = logcnt.max().item()
     log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
                                        -1, dtype=torch.int64, device=logcnt.device)
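A hedged sketch of what the changed line is meant to achieve in the global branch: instead of only replicating hot experts, the call now goes through rebalance_experts_hierarchical with a single virtual group and node, so the resulting physical replicas are also packed evenly across GPUs. The sketch calls the module-internal helpers directly purely to compare the two placements; it assumes the file is importable as eplb, that physical slot i maps to GPU i // (num_replicas // num_gpus), and that the expert loads are invented for illustration.

```python
import torch
from eplb import replicate_experts, rebalance_experts_hierarchical  # module name assumed

weight = torch.rand(1, 16) * 100           # [layers, logical experts], made-up loads
num_replicas, num_gpus = 32, 8

# old global policy: replicate purely by load, with no notion of GPU boundaries
old_phy2log, _, old_logcnt = replicate_experts(weight, num_replicas)
# new global policy: one virtual node and group, so GPU-level packing still runs
new_phy2log, _, new_logcnt = rebalance_experts_hierarchical(
    weight, num_replicas, 1, 1, num_gpus)

def per_gpu_load(phy2log, logcnt):
    # assumption: physical slot i lives on GPU i // (num_replicas // num_gpus)
    phy_load = weight.gather(-1, phy2log) / logcnt.gather(-1, phy2log)
    return phy_load.view(weight.shape[0], num_gpus, -1).sum(-1)

print(per_gpu_load(old_phy2log, old_logcnt))  # per-GPU load can be noticeably uneven
print(per_gpu_load(new_phy2log, new_logcnt))  # expected to be flatter after this change
```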