Mirror of https://github.com/deepseek-ai/DualPipe (synced 2025-06-26 18:16:46 +00:00)
Add Moore Threads MUSA backend
commit a96f82988e (parent ebe10fcefe)
.gitignore (vendored): 2 lines added
@@ -2,3 +2,5 @@ build
 *.egg-info/
 __pycache__/
 dist/
+output/*
+*.log
LICENSE: 31 lines added
@@ -1,3 +1,34 @@
+The MT-DualPipe from Moore Threads is licensed under the MIT License listed below.
+Copyright (c) 2025 Moore Threads Technology Co., Ltd ("Moore Threads"). All rights reserved.
+Terms of the MIT License
+-------------------------------------------------------------------------
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+-------------------------------------------------------------------------
+The following copyright statements and licenses apply to various open source software/model
+packages (or portions thereof) that are distributed with this MT-DualPipe. MT-DualPipe that
+includes this file does not necessarily use all the open source software packages referred
+to below and may also only use portions of a given package. Some open source software
+packages referred to below may have been modified by Moore Threads Technology Co., Ltd.
+-------------------------------------------------------------------------
+DualPipe
 MIT License
 
 Copyright (c) 2025 DeepSeek
README.md: 4 lines added
@@ -1,3 +1,7 @@
+# MT-DualPipe
+
+This repository is forked from the open source project DualPipe [deepseek-ai/DualPipe](https://github.com/deepseek-ai/DualPipe). It enables dual pipeline training on Moore Threads GPUs using the PyTorch MUSA backend [torch_musa](https://github.com/MooreThreads/torch_musa).
+
 # DualPipe
 
 DualPipe is an innovative bidirectional pipeline parallelism algorithm introduced in the [DeepSeek-V3 Technical Report](https://arxiv.org/pdf/2412.19437). It achieves full overlap of forward and backward computation-communication phases, also reducing pipeline bubbles. For detailed information on computation-communication overlap, please refer to the [profile data](https://github.com/deepseek-ai/profile-data).
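For readers new to torch_musa: importing it registers a 'musa' device type and exposes a torch.musa module that mirrors torch.cuda. A minimal usage sketch (assuming torch_musa is installed; not part of this commit):

    import torch
    import torch_musa  # registers the 'musa' device type

    x = torch.randn(2, 3, device='musa')   # allocate on a Moore Threads GPU
    y = (x @ x.transpose(0, 1)).cpu()      # compute on MUSA, copy back to host
    print(torch.musa.device_count())       # number of visible MUSA devices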
example.py: 43 lines changed
@@ -2,6 +2,12 @@ from typing import List, Optional, Callable, Tuple
 import os
 
 import torch
+try:
+    import torch_musa
+    from musa_patch import patch
+    patch()
+except:
+    pass
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
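The guarded import keeps example.py runnable on machines without torch_musa: on NVIDIA hosts the import fails and the script proceeds with plain CUDA. A slightly stricter variant (a sketch, not the commit's code) would catch only ImportError, so that genuine bugs inside musa_patch are not silently swallowed:

    try:
        import torch_musa              # MUSA backend for PyTorch
        from musa_patch import patch
        patch()                        # reroute torch.cuda.* to torch.musa.*
    except ImportError:
        pass                           # no MUSA backend; fall back to CUDA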
@@ -106,12 +112,23 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
     return cos_diff
 
 
-def main(rank, pp_size):
+def main(pp_size=8):
+    local_rank = int(os.environ['LOCAL_RANK'])
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    master_addr = os.environ.get('MASTER_ADDR')
+    master_port = os.environ.get('MASTER_PORT')
+
     is_first_rank = rank == 0
     is_last_rank = rank == pp_size - 1
-    dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
-    torch.cuda.set_device(rank)
-    torch.set_default_device(f"cuda:{rank}")
+    # dist.init_process_group(backend='mccl', init_method="env://", world_size=pp_size, rank=rank)
+    dist.init_process_group(backend='mccl',
+                            init_method='tcp://' + master_addr + ':' + master_port,
+                            rank=rank,
+                            world_size=world_size)
+
+    torch.cuda.set_device(local_rank)
+    # torch.set_default_device(f"cuda:{local_rank}")
     torch.manual_seed(233)
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
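torchrun exports MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE for every worker, so the hand-built tcp:// URL resolves to the same rendezvous as the env:// form kept in the commented-out line. A minimal sketch of the env-based equivalent (assuming the usual torchrun environment):

    import os
    import torch.distributed as dist

    # env:// reads MASTER_ADDR and MASTER_PORT from the environment directly.
    dist.init_process_group(backend='mccl', init_method='env://',
                            rank=int(os.environ['RANK']),
                            world_size=int(os.environ['WORLD_SIZE']))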
@@ -125,18 +142,18 @@ def main(rank, pp_size):
     set_p2p_tensor_dtype(torch.float32)
 
     # Create a model and partition it for each process
-    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)])
+    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)]).cuda()
 
     # Full inputs
-    full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
-    full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
+    full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size).cuda()
+    full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size).cuda()
 
     # Reference step
     loss_ref, output_ref = ref_step(full_x, full_l, full_modules, num_chunks)
 
     # DualPipe
     local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size - 1 - rank])
-    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
+    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size)).cuda()
     local_modules[0].load_state_dict(local_full_modules[0].state_dict())
     local_modules[1].load_state_dict(local_full_modules[1].state_dict())
     dualpipe_model = DualPipe(local_modules)
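With torch.set_default_device commented out, placement is now explicit: modules and tensors are moved with .cuda(), which musa_patch rebinds to Tensor.musa. Presumably the MUSA backend does not honor the default-device mechanism, which is why the commit switches to explicit placement. A small sketch of the effect (assuming the patch has been applied):

    import torch
    x = torch.randn(4, 4).cuda()   # patched: lands on 'musa:<device>'
    print(x.device)                # 'cuda:<device>' on an unpatched NVIDIA setup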
@@ -166,14 +183,14 @@ def main(rank, pp_size):
 
     # Check grads
     for (p0, p1) in zip(local_modules[0].parameters(), local_modules[1].parameters()):
-        p0all = torch.empty(pp_size, *p0.shape)
-        p1all = torch.empty(pp_size, *p1.shape)
+        p0all = torch.empty(pp_size, *p0.shape).cuda()
+        p1all = torch.empty(pp_size, *p1.shape).cuda()
         dist.all_gather_into_tensor(p0all, p0.grad)
         dist.all_gather_into_tensor(p1all, p1.grad)
         p0.grad += p1all[pp_size - 1 - rank]
         p1.grad += p0all[pp_size - 1 - rank]
     for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
-        assert cal_diff(p.grad, p_ref.grad) < 1e-13
+        assert cal_diff(p.grad, p_ref.grad) < 1e-7
     dualpipe_model.zero_grad()
 
     # Inference step
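The gradient-check tolerance loosens from 1e-13 to 1e-7. Upstream cal_diff upcasts to float64 before comparing, but musa_patch rebinds Tensor.double to Tensor.float, so on MUSA the comparison effectively runs in fp32 and a 1e-13 bound would be unattainable. A rough sketch of the comparison being relaxed (the exact upstream formula is an assumption here, reconstructed from the cos_diff name in the hunk header):

    import torch

    def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
        # cosine-style relative difference; .double() degrades to fp32 on MUSA
        x, y = x.double(), y.double()
        cos_diff = 1 - 2 * (x * y).sum().item() / (x * x + y * y).sum().item()
        return cos_diff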
@@ -191,12 +208,10 @@ def main(rank, pp_size):
     assert loss is None
     assert outputs is None
 
 
-def test_dualpipe(ngpus):
-    torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True)
 
 
 if __name__ == "__main__":
-    num_gpus = torch.cuda.device_count() // 2 * 2
-    for ngpus in range(num_gpus, 0, -2):
-        test_dualpipe(ngpus)
+    main(8)
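The spawn-based self-test loop is gone: worker processes are now created by torchrun (see start.sh below), and each one simply calls main. Since pp_size is hard-coded to 8 while rank and world size come from the launcher, a defensive guard inside main (hypothetical, not in the commit) would make the coupling explicit:

    # world_size comes from WORLD_SIZE; DualPipe needs an even number of ranks
    # and this example assumes one pipeline rank per process.
    assert world_size == pp_size and pp_size % 2 == 0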
musa_patch.py: 45 lines added (new file)
@@ -0,0 +1,45 @@
+import torch
+import torch_musa
+
+def patch():
+    print('musa patch is working')
+    def _pass_pvtx(*args, **kwargs):
+        return
+    torch.cuda.nvtx.range_push = _pass_pvtx
+    torch.cuda.nvtx.range_pop = _pass_pvtx
+    torch.cuda.nvtx.range = _pass_pvtx
+    torch.cuda.is_available = torch.musa.is_available
+    torch.cuda.current_device = lambda: f'musa:{torch.musa.current_device()}'
+    torch.cuda.device_count = torch.musa.device_count
+    torch.cuda.set_device = torch.musa.set_device
+    torch.cuda.DoubleTensor = torch.musa.DoubleTensor
+    torch.cuda.FloatTensor = torch.musa.FloatTensor
+    torch.cuda.LongTensor = torch.musa.LongTensor
+    torch.cuda.HalfTensor = torch.musa.HalfTensor
+    torch.cuda.BFloat16Tensor = torch.musa.BFloat16Tensor
+    torch.cuda.IntTensor = torch.musa.IntTensor
+    torch.cuda.synchronize = torch.musa.synchronize
+    torch.cuda.get_rng_state = torch.musa.get_rng_state
+    torch.cuda.set_rng_state = torch.musa.set_rng_state
+    torch.cuda.synchronize = torch.musa.synchronize
+    torch.cuda.empty_cache = torch.musa.empty_cache
+    torch.Tensor.cuda = torch.Tensor.musa
+    torch.cuda.manual_seed = torch.musa.manual_seed
+    torch.cuda.Event = torch.musa.Event
+    torch.cuda.Stream = torch.musa.Stream
+    torch.cuda.get_device_properties = torch.musa.get_device_properties
+    # Memory
+    torch.cuda.memory_allocated = torch.musa.memory_allocated
+    torch.cuda.max_memory_allocated = torch.musa.max_memory_allocated
+    torch.cuda.memory_reserved = torch.musa.memory_reserved
+    torch.cuda.max_memory_reserved = torch.musa.max_memory_reserved
+
+    original_empty = torch.empty
+    def patched_empty(*args, **kwargs):
+        if 'device' in kwargs and kwargs['device'] == 'cuda':
+            kwargs['device'] = 'musa'
+        result = original_empty(*args, **kwargs)
+        return result
+    torch.empty = patched_empty
+
+    torch.Tensor.double = torch.Tensor.float
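musa_patch is a monkey-patching shim: it rebinds the torch.cuda namespace and a few tensor methods to their torch.musa counterparts so CUDA-centric code runs unmodified. The NVTX stubs turn profiling markers into no-ops, patched_empty rewrites device='cuda' keyword arguments to 'musa', and the final Tensor.double-to-float rebind suggests the backend lacks float64 support. Note that patched_empty only intercepts the literal string 'cuda' passed as a keyword; a torch.device('cuda') object or a positional device argument would slip through. A quick hypothetical sanity check (not part of the commit):

    import torch
    from musa_patch import patch

    patch()
    print(torch.cuda.is_available())       # now answers for the MUSA backend
    x = torch.empty(2, 2, device='cuda')   # patched_empty rewrites 'cuda' -> 'musa'
    print(x.device)                        # expected: musa:0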
start.sh: 32 lines added (new file)
@@ -0,0 +1,32 @@
+export OMP_NUM_THREADS=4
+export MUSA_LAUNCH_BLOCKING=1
+export MCCL_ALGOS=1
+export OMP_NUM_THREADS=4
+export MUSA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
+export MUSA_KERNEL_TIMEOUT=3200000
+export MCCL_PROTOS=2
+export MCCL_CHECK_POINTERS=0
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export MCCL_IB_GID_INDEX=3
+export MUSA_BLOCK_SCHEDULE_MODE=1
+# export MCCL_BUFFSIZE=20480000
+WORK_HOME="$PWD"
+CURRENT_TIME=$(date "+%Y-%m-%d_%H:%M:%S")
+echo $CURRENT_TIME
+
+DISTRIBUTED_ARGS=(
+    --nproc_per_node 8
+    --nnodes 1
+    --node_rank 0
+    --master_addr 127.0.0.1
+    --master_port 12345
+    --log_dir $WORK_HOME/output/$CURRENT_TIME
+    --redirects 3
+)
+
+cmd="PYTHONPATH=$PYTHONPATH:./dualpipe torchrun \
+    ${DISTRIBUTED_ARGS[@]}
+    example.py
+    "
+echo $cmd
+eval $cmd
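start.sh pins eight MUSA devices, tunes MCCL (Moore Threads' collective communication library), and launches example.py under torchrun, with per-rank logs written under output/<timestamp> (torchrun's --redirects 3 captures both stdout and stderr). A tiny hypothetical probe, runnable the same way in place of example.py, to confirm the environment before a full run:

    # check_env.py (hypothetical): verify the torchrun/MUSA setup.
    import os
    import torch
    import torch_musa

    print(f"rank={os.environ['RANK']} "
          f"local_rank={os.environ['LOCAL_RANK']} "
          f"world_size={os.environ['WORLD_SIZE']} "
          f"devices={torch.musa.device_count()}")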