Mirror of https://github.com/deepseek-ai/EPLB
Synced 2025-06-26 18:15:49 +00:00

Initial commit
Commit f9bc62e841

.gitignore (vendored, new file, 174 lines)
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 DeepSeek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file, 67 lines)
@@ -0,0 +1,67 @@
# Expert Parallelism Load Balancer (EPLB)

When using expert parallelism (EP), different experts are assigned to different GPUs. Because the load of different experts may vary with the current workload, it is important to keep the load of different GPUs balanced. As described in the DeepSeek-V3 paper, we adopt a **redundant experts** strategy that duplicates heavily loaded experts. We then heuristically pack the duplicated experts onto GPUs to ensure load balancing across different GPUs. Moreover, thanks to the **group-limited expert routing** used in DeepSeek-V3, we also attempt to place experts of the same group on the same node to reduce inter-node data traffic, whenever possible.

To facilitate reproduction and deployment, we open-source our deployed EP load balancing algorithm in `eplb.py`. The algorithm computes a balanced expert replication and placement plan based on the estimated expert loads. Note that the exact method for predicting expert loads is outside the scope of this repo. A common method is to use a moving average of historical statistics.
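
For instance, a minimal sketch of such a moving-average estimator (the exponential form; `update_expert_load`, `step_counts`, and `decay` are illustrative names, not part of this repo) could look like this:

``` python
import torch

def update_expert_load(ema: torch.Tensor, step_counts: torch.Tensor, decay: float = 0.9) -> torch.Tensor:
    """Blend the latest per-expert token counts into a running load estimate.

    ema, step_counts: [num_moe_layers, num_logical_experts]
    """
    return decay * ema + (1.0 - decay) * step_counts.float()
```

The resulting estimate can then be passed as the `weight` argument of `eplb.rebalance_experts` described below.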

## The Algorithm

The load balancing algorithm comes with two policies, used for different cases.

### Hierarchical Load Balancing

When the number of server nodes divides the number of expert groups, we use the hierarchical load balancing policy to harness the group-limited expert routing. We first pack the expert groups onto nodes evenly, ensuring the loads of different nodes are balanced. Then, we replicate the experts within each node. Finally, we pack the replicated experts onto individual GPUs to ensure that different GPUs are load-balanced. The hierarchical load balancing policy can be used in the prefilling stage with a smaller expert-parallel size.
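
As a small sketch of the first step only (an illustration, not part of the repo's documented interface), the `balanced_packing` helper in `eplb.py` can be applied directly to per-group loads. The group loads below are the per-group sums of the example weights used in the Interface and Example section (4 groups of 3 experts), and the resulting node assignment is consistent with the plan shown there:

``` python
import torch
import eplb

# Per-group token loads for 2 layers and 4 expert groups (sums of 3 experts each).
group_load = torch.tensor([[262, 330, 116, 325],
                           [231, 280, 516, 129]])

# Pack 4 groups onto 2 nodes so that per-node loads are as even as possible.
pack_index, rank_in_pack = eplb.balanced_packing(group_load, num_packs=2)
# pack_index[l, g] is the node that group g of layer l is packed onto;
# e.g. in layer 0, groups 1 and 2 share one node while groups 0 and 3 share the other.
```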

### Global Load Balancing

In other cases, we use the global load balancing policy, which replicates the experts globally regardless of expert groups and packs the replicated experts onto individual GPUs. This policy can be adopted in the decoding stage with a larger expert-parallel size.
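
As an illustration of when this policy is selected (a hypothetical configuration with illustrative numbers, not taken from this repo): `eplb.rebalance_experts` switches to the global policy whenever the number of expert groups is not divisible by the number of nodes.

``` python
import torch
import eplb

# 12 logical experts replicated into 24 physical slots on 3 nodes with 4 GPUs each.
# Since 4 groups are not divisible by 3 nodes, the global policy is used.
weight = torch.randint(1, 200, (2, 12))
phy2log, log2phy, logcnt = eplb.rebalance_experts(weight, num_replicas=24,
                                                  num_groups=4, num_nodes=3, num_gpus=12)
```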

## Interface and Example

The main function of the load balancer is `eplb.rebalance_experts`.

The following code illustrates an example of a two-layer MoE model where each layer contains 12 experts. We introduce 4 redundant experts per layer, so the 16 replicas of each layer are placed on 2 nodes, each containing 4 GPUs.

``` python
import torch
import eplb

weight = torch.tensor([[ 90, 132,  40,  61, 104, 165,  39,   4,  73,  56, 183,  86],
                       [ 20, 107, 104,  64,  19, 197, 187, 157, 172,  86,  16,  27]])

num_replicas = 16
num_groups = 4
num_nodes = 2
num_gpus = 8

phy2log, log2phy, logcnt = eplb.rebalance_experts(weight, num_replicas, num_groups, num_nodes, num_gpus)
print(phy2log)

# Output:
# tensor([[ 5,  6,  5,  7,  8,  4,  3,  4, 10,  9, 10,  2,  0,  1, 11,  1],
#         [ 7, 10,  6,  8,  6, 11,  8,  9,  2,  4,  5,  1,  5,  0,  3,  1]])
```

The output, generated by the hierarchical load balancing policy, indicates the following expert replication and placement plan.

![](example.png)
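
Besides `phy2log`, the call also returns `log2phy` and `logcnt` (see the docstring of `rebalance_experts`). A brief sketch of how to read them for the example above, with the comments derived from the printed `phy2log`:

``` python
# logcnt[l, e]: number of physical replicas of logical expert e in layer l.
# log2phy[l, e]: the physical slots hosting expert e, padded with -1 up to the
# maximum replica count.
print(logcnt[0])       # experts 1, 4, 5 and 10 have 2 replicas each in layer 0
print(log2phy[0, 1])   # physical positions of expert 1's replicas (13 and 15 above)
```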

## License

This code repository is released under the MIT License.

eplb.py (new file, 161 lines)
@@ -0,0 +1,161 @@
from typing import Tuple

import torch


def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each pack contains exactly n/m objects and the weights of all packs
    are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu')
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for i in range(num_layers):
        pack_weights = [0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[i]:
            # Greedily place the heaviest remaining item into the lightest pack
            # that still has capacity.
            pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack),
                       key=pack_weights.__getitem__)
            assert pack_items[pack] < groups_per_pack
            pack_index[i, group] = pack
            rank_in_pack[i, group] = pack_items[pack]
            pack_weights[pack] += weight[i, group]
            pack_items[pack] += 1
    return pack_index, rank_in_pack
|
||||||
|
|
||||||
|
|
||||||
|
def replicate_experts(weight: torch.Tensor, num_phy: int) -> torch.Tensor:
|
||||||
|
|
||||||
|
"""
|
||||||
|
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
weight: [X, num_log]
|
||||||
|
num_phy: total number of experts after replication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||||
|
rank: [X, num_phy], the duplica rank
|
||||||
|
logcnt: [X, num_log], number of replicas for each logical expert
|
||||||
|
"""
|
||||||
|
n, num_log = weight.shape
|
||||||
|
num_redundant = num_phy - num_log
|
||||||
|
assert num_redundant >= 0
|
||||||
|
device = weight.device
|
||||||
|
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||||
|
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||||
|
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||||
|
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||||
|
for i in range(num_log, num_phy):
|
||||||
|
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||||
|
phy2log[:, i] = redundant_indices
|
||||||
|
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||||
|
logcnt[arangen, redundant_indices] += 1
|
||||||
|
return phy2log, rank, logcnt
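
# Illustrative note (not part of the original file): replicate_experts greedily
# gives each extra physical slot to the logical expert with the highest
# per-replica load, i.e. argmax(weight / logcnt). For example, with
# weight = [[9., 3.]] and num_phy = 4, both redundant slots go to expert 0
# (its per-replica load drops 9 -> 4.5 -> 3), yielding
# phy2log = [[0, 1, 0, 0]] and logcnt = [[3, 1]].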


def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int,
                                   num_groups: int, num_nodes: int, num_gpus: int):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: total number of physical experts after replication
        num_groups: number of expert groups (each of `num_logical_experts // num_groups` experts), used in group-limited routing
        num_nodes: number of server nodes
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        inv = torch.empty_like(perm)
        inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) +
                torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
    phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(-1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) +
                 torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt

def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
                      num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas,
                                                                  num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
                                       -1, dtype=torch.int64, device=logcnt.device)
    log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank,
                                          torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1))
    return phy2log, log2phy, logcnt


__all__ = ['rebalance_experts']

example.png (new binary file, 36 KiB)
Binary file not shown.