mirror of
https://github.com/deepseek-ai/EPLB
synced 2025-05-03 19:51:22 +00:00
parent
3079d71e1f
commit
e1100fefe7
3
eplb.py
3
eplb.py
@ -15,7 +15,6 @@ def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor
|
|||||||
pack_index: [X, n], the pack index of each item
|
pack_index: [X, n], the pack index of each item
|
||||||
rank_in_pack: [X, n], the rank of the item in the pack
|
rank_in_pack: [X, n], the rank of the item in the pack
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_layers, num_groups = weight.shape
|
num_layers, num_groups = weight.shape
|
||||||
assert num_groups % num_packs == 0
|
assert num_groups % num_packs == 0
|
||||||
groups_per_pack = num_groups // num_packs
|
groups_per_pack = num_groups // num_packs
|
||||||
@ -153,7 +152,7 @@ def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
|
|||||||
num_groups, num_nodes, num_gpus)
|
num_groups, num_nodes, num_gpus)
|
||||||
else:
|
else:
|
||||||
# use global load-balance policy
|
# use global load-balance policy
|
||||||
phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
|
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
|
||||||
maxlogcnt = logcnt.max().item()
|
maxlogcnt = logcnt.max().item()
|
||||||
log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
|
log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
|
||||||
-1, dtype=torch.int64, device=logcnt.device)
|
-1, dtype=torch.int64, device=logcnt.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user