mirror of
https://github.com/deepseek-ai/EPLB
synced 2025-04-18 13:14:48 +00:00
Initial commit
This commit is contained in:
commit
f9bc62e841
174
.gitignore
vendored
Normal file
174
.gitignore
vendored
Normal file
@ -0,0 +1,174 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 DeepSeek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
67
README.md
Normal file
67
README.md
Normal file
@ -0,0 +1,67 @@
|
||||
# Expert Parallelism Load Balancer (EPLB)
|
||||
|
||||
When using expert parallelism (EP), different experts are assigned to different GPUs. Because the load of different
|
||||
experts may vary depending on the current workload, it is important to keep the load of different GPUs balanced.
|
||||
As described in the DeepSeek-V3 paper, we adopt **redundant experts** strategy that duplicates heavy-loaded experts.
|
||||
Then, we heuristically pack the duplicated experts to GPUs to ensure load balancing across different GPUs. Moreover,
|
||||
thanks to the **group-limited expert routing** used in DeepSeek-V3, we also attempt to place the experts of the same
|
||||
group to the same node to reduce inter-node data traffic, whenever possible.
|
||||
|
||||
To facilitate reproduction and deployment, we open-source our deployed EP load balancing algorithm in `eplb.py`.
|
||||
The algorithm computes a balanced expert replication and placement plan based on the estimated expert loads. Note
|
||||
that the exact method to predict the loads of experts is out of this repo's scope. A common method is to use
|
||||
moving average of historical statistics.
|
||||
|
||||
## The Algorithm
|
||||
|
||||
The load balancing algorithm comes with two policies used for different cases.
|
||||
|
||||
### Hierarchical Load Balancing
|
||||
|
||||
When the number of server nodes divides the number of expert groups, we use the hierarchical load balancing policy to
|
||||
harness the group-limited expert routing. We first pack the expert groups to nodes evenly, ensuring the loads of
|
||||
different nodes are balanced. Then, we replicate the experts within each node. Finally, we pack the replicated experts
|
||||
to individual GPUs to ensure different GPUs are load-balanced. The hierarchical load balancing policy can be used in
|
||||
prefilling stage with a smaller expert-parallel size.
|
||||
|
||||
### Global Load Balancing
|
||||
|
||||
In other cases, we use the global load balancing policy that replicates the experts globally regardless of expert
|
||||
groups, and pack the replicated experts to individual GPUs. This policy can be adopted in decoding stage with a larger
|
||||
expert-parallel size.
|
||||
|
||||
## Interface and Example
|
||||
|
||||
The main function of the load balancer is `eplb.rebalance_experts`.
|
||||
|
||||
The following code illustrates an example of a two-layer MoE model, and each layer contains 12 experts. We introduce 4 redundant experts per layer, and the total 16 replicas are placed on 2 nodes, and each node contains 4 GPUs.
|
||||
|
||||
``` python
|
||||
import torch
|
||||
import eplb
|
||||
|
||||
weight = torch.tensor([[ 90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86],
|
||||
[ 20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27]])
|
||||
|
||||
num_replicas = 16
|
||||
num_groups = 4
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = eplb.rebalance_experts(weight, num_replicas, num_groups, num_nodes, num_gpus)
|
||||
print(phy2log)
|
||||
|
||||
# Output:
|
||||
# tensor([[ 5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1],
|
||||
# [ 7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1]])
|
||||
```
|
||||
|
||||
The output, generated by the hierarchical load balancing policy, indicates the following
|
||||
expert replication and placement plan.
|
||||
|
||||

|
||||
|
||||
|
||||
## License
|
||||
|
||||
This code repository is released under the MIT License.
|
161
eplb.py
Normal file
161
eplb.py
Normal file
@ -0,0 +1,161 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
def balanced_packing(weight: torch.Tensor, num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
|
||||
are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
indices = weight.float().sort(-1, descending=True).indices.cpu()
|
||||
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device='cpu')
|
||||
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
for group in indices[i]:
|
||||
pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack),
|
||||
key=pack_weights.__getitem__)
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index[i, group] = pack
|
||||
rank_in_pack[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight[i, group]
|
||||
pack_items[pack] += 1
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
|
||||
def replicate_experts(weight: torch.Tensor, num_phy: int) -> torch.Tensor:
|
||||
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the duplica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(weight: torch.Tensor, num_physical_experts: int,
|
||||
num_groups: int, num_nodes: int, num_gpus: int):
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
group_size: number of logical experts per group, used in group-limited routing
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
|
||||
logical_count: [num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1) +
|
||||
torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
|
||||
phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) +
|
||||
torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
def rebalance_experts(weight: torch.Tensor, num_replicas: int, num_groups: int,
|
||||
num_nodes: int, num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of `num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float().cpu()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas,
|
||||
num_groups, num_nodes, num_gpus)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = replicate_experts(weight, num_replicas)
|
||||
maxlogcnt = logcnt.max().item()
|
||||
log2phy: torch.Tensor = torch.full((num_layers, num_logical_experts, maxlogcnt),
|
||||
-1, dtype=torch.int64, device=logcnt.device)
|
||||
log2phy.view(num_layers, -1).scatter_(-1, phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1))
|
||||
return phy2log, log2phy, logcnt
|
||||
|
||||
__all__ = ['rebalance_experts']
|
BIN
example.png
Normal file
BIN
example.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
Loading…
Reference in New Issue
Block a user