Add Moore Threads MUSA backend

Zhi Chen 2025-02-27 17:59:21 +08:00
parent ebe10fcefe
commit a96f82988e
7 changed files with 144 additions and 14 deletions

.gitignore (2 changes)

@@ -2,3 +2,5 @@ build
 *.egg-info/
 __pycache__/
 dist/
+output/*
+*.log

LICENSE (31 changes)

@@ -1,3 +1,34 @@
+The MT-DualPipe from Moore Threads is licensed under the MIT License listed below.
+Copyright (c) 2025 Moore Threads Technology Co., Ltd. ("Moore Threads"). All rights reserved.
+
+Terms of the MIT License
+-------------------------------------------------------------------------
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+-------------------------------------------------------------------------
+The following copyright statements and licenses apply to various open source
+software/model packages (or portions thereof) that are distributed with this
+MT-DualPipe. The MT-DualPipe that includes this file does not necessarily use
+all of the open source software packages referred to below and may use only
+portions of a given package. Some of the open source software packages
+referred to below may have been modified by Moore Threads Technology Co., Ltd.
+-------------------------------------------------------------------------
+DualPipe
 MIT License
 Copyright (c) 2025 DeepSeek

README.md

@@ -1,3 +1,7 @@
+# MT-DualPipe
+
+This repository is forked from the open source project DualPipe [deepseek-ai/DualPipe](https://github.com/deepseek-ai/DualPipe). It enables dual pipeline training on Moore Threads GPUs using the PyTorch MUSA backend [torch_musa](https://github.com/MooreThreads/torch_musa).
+
 # DualPipe
 DualPipe is an innovative bidirectional pipeline parallelism algorithm introduced in the [DeepSeek-V3 Technical Report](https://arxiv.org/pdf/2412.19437). It achieves full overlap of forward and backward computation-communication phases and also reduces pipeline bubbles. For detailed information on computation-communication overlap, please refer to the [profile data](https://github.com/deepseek-ai/profile-data).
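For context, torch_musa is a side-effect import: loading it registers the "musa" device type with PyTorch, after which ordinary tensor code can target Moore Threads GPUs. A minimal sketch (not part of this commit), assuming torch_musa is installed:

    import torch
    import torch_musa                      # registers the "musa" device type

    x = torch.randn(2, 2, device="musa")   # allocate on a Moore Threads GPU
    y = (x @ x).cpu()                      # compute on MUSA, copy back to host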

example.py

@@ -2,6 +2,12 @@ from typing import List, Optional, Callable, Tuple
 import os
 import torch
+try:
+    # Prefer the Moore Threads MUSA backend when torch_musa is present.
+    import torch_musa
+    from musa_patch import patch
+    patch()
+except ImportError:
+    # torch_musa not installed; fall back to the stock CUDA backend.
+    pass
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
@@ -106,12 +112,23 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
     return cos_diff
 
-def main(rank, pp_size):
+def main(pp_size=8):
+    local_rank = int(os.environ['LOCAL_RANK'])
+    rank = int(os.environ['RANK'])
+    world_size = int(os.environ['WORLD_SIZE'])
+    master_addr = os.environ.get('MASTER_ADDR')
+    master_port = os.environ.get('MASTER_PORT')
     is_first_rank = rank == 0
     is_last_rank = rank == pp_size - 1
-    dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
-    torch.cuda.set_device(rank)
-    torch.set_default_device(f"cuda:{rank}")
+    # dist.init_process_group(backend='mccl', init_method="env://", world_size=pp_size, rank=rank)
+    dist.init_process_group(backend='mccl',
+                            init_method='tcp://' + master_addr + ':' + master_port,
+                            rank=rank,
+                            world_size=world_size)
+    torch.cuda.set_device(local_rank)
+    # torch.set_default_device(f"cuda:{local_rank}")
     torch.manual_seed(233)
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@@ -125,18 +142,18 @@ def main(rank, pp_size):
     set_p2p_tensor_dtype(torch.float32)
 
     # Create a model and partition it for each process
-    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)])
+    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)]).cuda()
 
     # Full inputs
-    full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
-    full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
+    full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size).cuda()
+    full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size).cuda()
 
     # Reference step
     loss_ref, output_ref = ref_step(full_x, full_l, full_modules, num_chunks)
 
     # DualPipe
     local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size - 1 - rank])
-    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
+    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size)).cuda()
     local_modules[0].load_state_dict(local_full_modules[0].state_dict())
     local_modules[1].load_state_dict(local_full_modules[1].state_dict())
     dualpipe_model = DualPipe(local_modules)
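DualPipe drives two stages per rank, its own stage plus the mirror stage from the other end of the pipeline, which is why local_full_modules pairs full_modules[rank] with full_modules[pp_size - 1 - rank]. A worked mapping for a hypothetical pp_size = 4:

    # rank 0 -> stages (0, 3)    rank 2 -> stages (2, 1)
    # rank 1 -> stages (1, 2)    rank 3 -> stages (3, 0)
    # Stage s is replicated on ranks s and pp_size - 1 - s, so the gradient
    # check in the next hunk all-gathers grads and adds each replica's mirror
    # (p0.grad += p1all[pp_size - 1 - rank]) before comparing to the reference.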
@@ -166,14 +183,14 @@ def main(rank, pp_size):
 
     # Check grads
     for (p0, p1) in zip(local_modules[0].parameters(), local_modules[1].parameters()):
-        p0all = torch.empty(pp_size, *p0.shape)
-        p1all = torch.empty(pp_size, *p1.shape)
+        p0all = torch.empty(pp_size, *p0.shape).cuda()
+        p1all = torch.empty(pp_size, *p1.shape).cuda()
         dist.all_gather_into_tensor(p0all, p0.grad)
         dist.all_gather_into_tensor(p1all, p1.grad)
         p0.grad += p1all[pp_size - 1 - rank]
         p1.grad += p0all[pp_size - 1 - rank]
     for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
-        assert cal_diff(p.grad, p_ref.grad) < 1e-13
+        assert cal_diff(p.grad, p_ref.grad) < 1e-7
     dualpipe_model.zero_grad()
 
     # Inference step
@@ -191,12 +208,10 @@ def main(rank, pp_size):
     assert loss is None
     assert outputs is None
 
-def test_dualpipe(ngpus):
-    torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True)
-
 
 if __name__ == "__main__":
-    num_gpus = torch.cuda.device_count() // 2 * 2
-    for ngpus in range(num_gpus, 0, -2):
-        test_dualpipe(ngpus)
+    main(8)

hostfile (new file, 1 line)

@@ -0,0 +1 @@
+127.0.0.1

musa_patch.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+import torch
+import torch_musa
+
+
+def patch():
+    print('musa patch is working')
+
+    # torch_musa has no NVTX profiling ranges; stub them out.
+    def _pass_nvtx(*args, **kwargs):
+        return
+    torch.cuda.nvtx.range_push = _pass_nvtx
+    torch.cuda.nvtx.range_pop = _pass_nvtx
+    torch.cuda.nvtx.range = _pass_nvtx
+
+    # Rebind the torch.cuda API onto its torch.musa equivalents.
+    torch.cuda.is_available = torch.musa.is_available
+    # Note: unlike stock torch.cuda.current_device, this returns a device
+    # string ("musa:N") rather than an integer index.
+    torch.cuda.current_device = lambda: f'musa:{torch.musa.current_device()}'
+    torch.cuda.device_count = torch.musa.device_count
+    torch.cuda.set_device = torch.musa.set_device
+    torch.cuda.DoubleTensor = torch.musa.DoubleTensor
+    torch.cuda.FloatTensor = torch.musa.FloatTensor
+    torch.cuda.LongTensor = torch.musa.LongTensor
+    torch.cuda.HalfTensor = torch.musa.HalfTensor
+    torch.cuda.BFloat16Tensor = torch.musa.BFloat16Tensor
+    torch.cuda.IntTensor = torch.musa.IntTensor
+    torch.cuda.synchronize = torch.musa.synchronize
+    torch.cuda.get_rng_state = torch.musa.get_rng_state
+    torch.cuda.set_rng_state = torch.musa.set_rng_state
+    torch.cuda.empty_cache = torch.musa.empty_cache
+    torch.Tensor.cuda = torch.Tensor.musa
+    torch.cuda.manual_seed = torch.musa.manual_seed
+    torch.cuda.Event = torch.musa.Event
+    torch.cuda.Stream = torch.musa.Stream
+    torch.cuda.get_device_properties = torch.musa.get_device_properties
+
+    # Memory
+    torch.cuda.memory_allocated = torch.musa.memory_allocated
+    torch.cuda.max_memory_allocated = torch.musa.max_memory_allocated
+    torch.cuda.memory_reserved = torch.musa.memory_reserved
+    torch.cuda.max_memory_reserved = torch.musa.max_memory_reserved
+
+    # Route torch.empty(..., device='cuda') to the MUSA allocator.
+    original_empty = torch.empty
+
+    def patched_empty(*args, **kwargs):
+        if 'device' in kwargs and kwargs['device'] == 'cuda':
+            kwargs['device'] = 'musa'
+        return original_empty(*args, **kwargs)
+    torch.empty = patched_empty
+
+    # Alias .double() to .float(): fp64 support on MUSA is limited, so
+    # double-precision casts fall back to float32.
+    torch.Tensor.double = torch.Tensor.float
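A minimal usage sketch of this shim (not part of the commit): after patch(), the CUDA-flavoured calls left untouched in the DualPipe code land on MUSA devices.

    import torch
    import torch_musa
    from musa_patch import patch

    patch()                         # rebinds torch.cuda.* onto torch.musa.*
    x = torch.randn(4, 4).cuda()    # Tensor.cuda is now Tensor.musa
    torch.cuda.synchronize()        # actually torch.musa.synchronize()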

start.sh (new file, 32 lines)

@@ -0,0 +1,32 @@
+export OMP_NUM_THREADS=4
+export MUSA_LAUNCH_BLOCKING=1
+export MCCL_ALGOS=1
+export MUSA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
+export MUSA_KERNEL_TIMEOUT=3200000
+export MCCL_PROTOS=2
+export MCCL_CHECK_POINTERS=0
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export MCCL_IB_GID_INDEX=3
+export MUSA_BLOCK_SCHEDULE_MODE=1
+# export MCCL_BUFFSIZE=20480000
+
+WORK_HOME="$PWD"
+CURRENT_TIME=$(date "+%Y-%m-%d_%H:%M:%S")
+echo "$CURRENT_TIME"
+
+DISTRIBUTED_ARGS=(
+    --nproc_per_node 8
+    --nnodes 1
+    --node_rank 0
+    --master_addr 127.0.0.1
+    --master_port 12345
+    --log_dir "$WORK_HOME/output/$CURRENT_TIME"
+    --redirects 3
+)
+
+cmd="PYTHONPATH=$PYTHONPATH:./dualpipe torchrun \
+    ${DISTRIBUTED_ARGS[@]} \
+    example.py"
+echo "$cmd"
+eval "$cmd"
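Usage note: running `bash start.sh` launches eight workers on one node. Because of `--redirects 3`, each rank's stdout and stderr are redirected to files rather than the console; per the `--log_dir` setting, the per-rank logs should appear under output/<timestamp>/.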