mirror of
https://github.com/deepseek-ai/DualPipe
synced 2025-04-04 04:40:43 +00:00
Initial commit
This commit is contained in:
commit
fbe0ac0d6e
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
build
|
||||
*.egg-info/
|
||||
__pycache__/
|
||||
dist/
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 DeepSeek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
56
README.md
Normal file
56
README.md
Normal file
@ -0,0 +1,56 @@
|
||||
# DualPipe
|
||||
|
||||
DualPipe is an innovative bidirectional pipeline parallelism algorithm introduced in the [DeepSeek-V3 Technical Report](https://arxiv.org/pdf/2412.19437). It achieves full overlap of forward and backward computation-communication phases, also reducing pipeline bubbles. For detailed information on computation-communication overlap, please refer to the [profile data](https://github.com/deepseek-ai/profile-data).
|
||||
|
||||
### Schedules
|
||||
|
||||

|
||||
|
||||
Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions.
|
||||
The micro-batches in the reverse direction are symmetric to those in the forward direction, so
|
||||
we omit their batch ID for illustration simplicity. Two cells enclosed by a shared black border
|
||||
have mutually overlapped computation and communication.
|
||||
|
||||
### Pipeline Bubbles and Memory Usage Comparison
|
||||
|
||||
| Method | Bubble | Parameter | Activation |
|
||||
|-------------|---------------------------------|-----------|------------|
|
||||
| 1F1B | (PP-1)(𝐹+𝑊) | 1× | PP |
|
||||
| ZB1P | (PP-1)(𝐹+𝐵-2𝑊) | 1× | PP |
|
||||
| DualPipe | (PP/2-1)(𝐹&𝐵+𝐵-3𝑊) | 2× | PP+1 |
|
||||
|
||||
𝐹 denotes the execution time of a forward chunk, 𝐵 denotes the execution time of a
|
||||
full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵
|
||||
denotes the execution time of two mutually overlapped forward and backward chunks.
|
||||
|
||||
## Quick Start
|
||||
|
||||
The usage is shown in the following example:
|
||||
|
||||
```bash
|
||||
python example.py
|
||||
```
|
||||
|
||||
Note: For real-world applications, you will need to implement a custom `overlaped_forward_backward` method (this exact spelling is the attribute name that `DualPipe` looks up on the module class) tailored to your specific module.
|
||||
|
||||
## Requirements
|
||||
|
||||
- PyTorch 2.0 and above
|
||||
|
||||
## Developers
|
||||
|
||||
DualPipe was created and developed by Jiashi Li, Chengqi Deng, and Wenfeng Liang.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{deepseekai2024deepseekv3technicalreport,
|
||||
title={DeepSeek-V3 Technical Report},
|
||||
author={DeepSeek-AI},
|
||||
year={2024},
|
||||
eprint={2412.19437},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL},
|
||||
url={https://arxiv.org/abs/2412.19437},
|
||||
}
|
||||
```
|
17
dualpipe/__init__.py
Normal file
17
dualpipe/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
__version__ = "1.0.0"

from dualpipe.dualpipe import (
    DualPipe,
    WeightGradStore,
)
from dualpipe.comm import (
    set_p2p_tensor_shapes,
    set_p2p_tensor_dtype,
)

# ``__all__`` must contain the *names* of the public objects as strings.
# Listing the objects themselves makes ``from dualpipe import *`` raise
# ``TypeError`` because the import machinery requires ``str`` items.
__all__ = [
    "DualPipe",
    "WeightGradStore",
    "set_p2p_tensor_shapes",
    "set_p2p_tensor_dtype",
]
|
38
dualpipe/comm.py
Normal file
38
dualpipe/comm.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
# Module-level P2P configuration. Both values start unset (``None``) and are
# filled in by ``set_p2p_tensor_shapes`` / ``set_p2p_tensor_dtype``.
TENSOR_SHAPES: List[Tuple[int]] = None  # one shape per tensor exchanged between stages
TENSOR_DTYPE: torch.dtype = None  # dtype used when allocating receive buffers
|
||||
|
||||
|
||||
def set_p2p_tensor_shapes(shapes: List[Tuple[int]]):
    """Record the shapes of the tensors exchanged between pipeline stages.

    Arguments:
        shapes: One shape per tensor; stored as-is in the module-level
            ``TENSOR_SHAPES`` (the list is not copied).
    """
    global TENSOR_SHAPES
    TENSOR_SHAPES = shapes
|
||||
|
||||
|
||||
def set_p2p_tensor_dtype(dtype: torch.dtype):
    """Record the dtype used for tensors exchanged between pipeline stages.

    Arguments:
        dtype: Stored in the module-level ``TENSOR_DTYPE``.
    """
    global TENSOR_DTYPE
    TENSOR_DTYPE = dtype
|
||||
|
||||
|
||||
def build_from_tensor_shapes():
    """Allocate one uninitialized CUDA tensor per configured P2P shape.

    Reads the module-level ``TENSOR_SHAPES`` and ``TENSOR_DTYPE``. Every
    tensor is created with ``requires_grad=True``.

    Returns:
        A list of freshly allocated tensors, in the order of ``TENSOR_SHAPES``.
    """
    buffers = []
    for shape in TENSOR_SHAPES:
        buffers.append(torch.empty(shape, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True))
    return buffers
|
||||
|
||||
|
||||
def append_irecv(ops: List[dist.P2POp], src: int, group: dist.ProcessGroup) -> List[torch.Tensor]:
    """Queue asynchronous receives for one set of P2P tensors.

    Arguments:
        ops: List of pending P2P operations; one ``irecv`` op per buffer is
            appended to it in place.
        src: Rank of the sender *within* ``group``; it is translated to a
            global rank before the ops are built.
        group: Process group used for the rank translation.

    Returns:
        The newly allocated receive buffers (filled once the ops complete).
    """
    recv_buffers = build_from_tensor_shapes()
    global_src = dist.distributed_c10d.get_global_rank(group, src)
    ops.extend(
        dist.P2POp(dist.irecv, buffer, global_src)
        for buffer in recv_buffers
        if buffer is not None
    )
    return recv_buffers
|
||||
|
||||
|
||||
def append_isend(ops: List[dist.P2POp], tensors: List[torch.Tensor], dst: int, group: dist.ProcessGroup) -> None:
    """Queue asynchronous sends for a list of tensors.

    Arguments:
        ops: List of pending P2P operations; one ``isend`` op per non-``None``
            tensor is appended to it in place.
        tensors: Tensors to send; ``None`` entries are skipped.
        dst: Rank of the receiver *within* ``group``; it is translated to a
            global rank before the ops are built.
        group: Process group used for the rank translation.
    """
    global_dst = dist.distributed_c10d.get_global_rank(group, dst)
    ops.extend(
        dist.P2POp(dist.isend, payload, global_dst)
        for payload in tensors
        if payload is not None
    )
|
440
dualpipe/dualpipe.py
Normal file
440
dualpipe/dualpipe.py
Normal file
@ -0,0 +1,440 @@
|
||||
from typing import Tuple, List, Union, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
import dualpipe.comm as comm
|
||||
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather
|
||||
|
||||
|
||||
class DualPipe(nn.Module):
    """Bidirectional pipeline-parallel scheduler wrapping two local stages.

    Each rank holds two modules (``self.module[0]`` and ``self.module[1]``),
    one per pipeline direction. ``step`` splits the batch into micro-batches
    and runs the eight-step schedule implemented below, exchanging
    activations and gradients with the neighbouring ranks through batched
    point-to-point operations.
    """

    def __init__(
        self,
        modules: Tuple[nn.Module, nn.Module],
        batch_dim: int = 0,
        process_group: Optional[dist.ProcessGroup] = None,
        rank_mapping: Optional[List[int]] = None,
    ) -> None:
        """Set up the two local stages and the rank topology.

        Arguments:
            modules: The two local stage modules (phase 0 and phase 1). The
                first module's parameters must live on the current CUDA device.
            batch_dim: Dimension along which inputs/labels are split into
                micro-batches and outputs are concatenated.
            process_group: Pipeline process group; the default group is used
                when omitted.
            rank_mapping: Maps rank in ``process_group`` to the actual pipeline
                rank. Defaults to the identity mapping.
        """
        super().__init__()

        assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
        self.module = nn.ModuleList(modules)
        # The overlapped path is only used when both stages are of the exact
        # same class and that class provides ``overlaped_forward_backward``.
        self.overlaped_forward_backward = (
            type(modules[0]) is type(modules[1])
            and hasattr(type(modules[0]), "overlaped_forward_backward")
        )
        self.batch_dim = batch_dim
        self.group = process_group or dist.distributed_c10d._get_default_group()
        self.num_ranks = self.group.size()

        # rank_mapping: Map rank in process_group to actual pp rank.
        # rank_inverse_mapping: Map actual pp rank to rank in process_group.
        if rank_mapping is None:
            rank_mapping = list(range(self.num_ranks))
        # One extra ``None`` slot at the end: index ``-1`` (previous of pp rank
        # 0) and index ``num_ranks`` (next of the last pp rank) both land on
        # it, so boundary ranks get ``None`` neighbours without special cases.
        rank_inverse_mapping = [None] * (self.num_ranks + 1)
        for i in range(self.num_ranks):
            rank_inverse_mapping[rank_mapping[i]] = i

        self.rank = rank_mapping[self.group.rank()]
        self.first_rank = rank_inverse_mapping[0]
        self.prev_rank = rank_inverse_mapping[self.rank - 1]
        self.next_rank = rank_inverse_mapping[self.rank + 1]
        self.last_rank = rank_inverse_mapping[self.num_ranks - 1]

        self.is_first_rank = self.rank == 0
        self.is_last_rank = self.rank == self.num_ranks - 1
        self.is_in_second_half = self.rank >= self.num_ranks // 2
        self.is_middle_rank = (self.rank == self.num_ranks // 2 - 1) or (self.rank == self.num_ranks // 2)

    def _reset_states(self) -> None:
        """Clear all per-step bookkeeping (chunks, counters, pending ops)."""
        WeightGradStore.clear()

        # Each tuple is indexed by phase (0 or 1); each list by chunk id.
        self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None
        self.loss_chunks: List[torch.Tensor] = []
        self.criterion: Callable = None

        # Next chunk id per phase for compute, send and receive bookkeeping.
        self.current_f_chunk_id: List[int] = [0, 0]
        self.current_b_chunk_id: List[int] = [0, 0]
        self.current_send_f_chunk_id: List[int] = [0, 0]
        self.current_send_b_chunk_id: List[int] = [0, 0]
        self.current_recv_f_chunk_id: List[int] = [0, 0]
        self.current_recv_b_chunk_id: List[int] = [0, 0]
        self.comm_ops: List[dist.P2POp] = []
        self.to_free: List[torch.Tensor] = []

    def _forward_compute_chunk(self, phase: int) -> None:
        """Run the forward pass of the next micro-batch of ``phase``."""
        # Ranks in the second half swap the meaning of phase 0 and phase 1.
        phase ^= self.is_in_second_half
        chunk_id = self.current_f_chunk_id[phase]
        self.current_f_chunk_id[phase] += 1
        inputs = self.input_chunks[phase][chunk_id]
        if self.forward_only:
            # No backward will follow, so drop the stored reference now.
            self.input_chunks[phase][chunk_id] = None

        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)

        outputs = self.module[phase](*inputs)
        outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
        if is_last_stage and self.criterion is not None:
            labels = self.labels[phase][chunk_id]
            loss = self.criterion(*outputs, *labels)
            self.loss_chunks.append(loss)

        if (not is_last_stage) or self.return_outputs:
            self.output_chunks[phase].append(outputs)

    def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
        """Run the backward pass of the next micro-batch of ``phase``.

        Arguments:
            phase: Schedule phase (flipped for ranks in the second half).
            enable_zb: When true, ``WeightGradStore`` is enabled during the
                backward pass and flushed afterwards so the collected weight
                gradient functions can be run later by ``_weight_chunk``.
        """
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        chunk_id = self.current_b_chunk_id[phase]
        self.current_b_chunk_id[phase] += 1

        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)

        WeightGradStore.enabled = enable_zb
        if is_last_stage:
            loss = self.loss_chunks[chunk_id]
            loss.backward()
            loss.detach_()
        else:
            outputs = self.output_chunks[phase][chunk_id]
            if not self.return_outputs:
                self.output_chunks[phase][chunk_id] = None
            output_grads = self.output_grad_chunks[phase][chunk_id]
            self.output_grad_chunks[phase][chunk_id] = None
            non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
            # Unpacking ``zip(*[])`` into two names raises ValueError, so only
            # unzip (and run backward) when at least one gradient is present.
            if non_empty:
                outputs, output_grads = zip(*non_empty)
                run_backward(outputs, output_grads)
        WeightGradStore.enabled = False
        if enable_zb:
            WeightGradStore.flush()

        inputs = self.input_chunks[phase][chunk_id]
        self.input_chunks[phase][chunk_id] = None
        input_grads = [t.grad for t in inputs]
        self.input_grad_chunks[phase].append(input_grads)

    def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
        """Run a forward chunk of ``phase0`` together with a backward chunk of ``phase1``.

        Falls back to forward-only, or to sequential forward then backward,
        when gradients are disabled or the stage class has no
        ``overlaped_forward_backward`` implementation.
        """
        if self.forward_only:
            self._forward_compute_chunk(phase0)
            return

        if not self.overlaped_forward_backward:
            self._forward_compute_chunk(phase0)
            self._backward_compute_chunk(phase1)
            return

        # pre-forward
        phase0 ^= self.is_in_second_half
        chunk_id0 = self.current_f_chunk_id[phase0]
        self.current_f_chunk_id[phase0] += 1
        module0 = self.module[phase0]
        inputs0 = self.input_chunks[phase0][chunk_id0]
        is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0)

        if is_last_stage0 and self.criterion is not None:
            labels0 = self.labels[phase0][chunk_id0]
            criterion0 = self.criterion
        else:
            labels0 = []
            criterion0 = None

        # pre-backward
        phase1 ^= self.is_in_second_half
        chunk_id1 = self.current_b_chunk_id[phase1]
        self.current_b_chunk_id[phase1] += 1
        module1 = self.module[phase1]
        is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0)

        if is_last_stage1:
            loss1 = self.loss_chunks[chunk_id1]
            outputs1 = []
            output_grads1 = []
        else:
            loss1 = None
            outputs1 = self.output_chunks[phase1][chunk_id1]
            if not self.return_outputs:
                self.output_chunks[phase1][chunk_id1] = None
            output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
            self.output_grad_chunks[phase1][chunk_id1] = None
            non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
            # Unpacking ``zip(*[])`` into two names raises ValueError; pass
            # empty tuples on when no gradient was received.
            if non_empty:
                outputs1, output_grads1 = zip(*non_empty)
            else:
                outputs1, output_grads1 = (), ()

        # forward & backward
        outputs0, loss0 = type(module0).overlaped_forward_backward(
            module0, inputs0, criterion0, labels0,
            module1, loss1, outputs1, output_grads1,
        )

        # post-forward
        if (not is_last_stage0) or self.return_outputs:
            self.output_chunks[phase0].append(outputs0)
        if is_last_stage0 and self.criterion is not None:
            self.loss_chunks.append(loss0)

        # post-backward
        inputs = self.input_chunks[phase1][chunk_id1]
        self.input_chunks[phase1][chunk_id1] = None
        input_grads1 = [t.grad for t in inputs]
        self.input_grad_chunks[phase1].append(input_grads1)

    def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
        """Optionally receive, then compute forward, then optionally queue the send."""
        if recv:
            self._recv_forward(phase)
        self._commit_and_wait_comm()

        self._forward_compute_chunk(phase)

        if send:
            self._send_forward(phase)

    def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
        """Optionally receive, then compute backward, then optionally queue the send."""
        if recv:
            self._recv_backward(phase)
        self._commit_and_wait_comm()

        self._backward_compute_chunk(phase, enable_zb)

        if send:
            self._send_backward(phase)

    def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
        """Communicate and compute one forward (``phase0``) + backward (``phase1``) pair."""
        if recv0:
            self._recv_forward(phase0)
        self._recv_backward(phase1)
        self._commit_and_wait_comm()

        self._forward_backward_compute_chunk(phase0, phase1)

        self._send_forward(phase0)
        self._send_backward(phase1)

    def _weight_chunk(self) -> None:
        """Run the oldest flushed batch of deferred weight-gradient functions."""
        if self.forward_only:
            return

        self._commit_and_wait_comm()

        # Assume FIFO
        WeightGradStore.pop()

    def _free_tensors(self) -> None:
        """Release the storage of tensors that have already been sent."""
        for tensor in self.to_free:
            assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
            # Swap in an empty tensor so the original storage can be freed
            # while the tensor object itself stays alive.
            tensor.data = torch.Tensor()
        self.to_free = []

    def _recv_forward(self, phase: int) -> None:
        """Queue the receive of the next forward input of ``phase`` (no-op on the first stage)."""
        phase ^= self.is_in_second_half
        is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
        if is_first_stage:
            return

        self.current_recv_f_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
        self.input_chunks[phase].append(tensors)

    def _send_forward(self, phase: int) -> None:
        """Queue the send of the next forward output of ``phase`` (no-op on the last stage)."""
        phase ^= self.is_in_second_half
        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
        if is_last_stage:
            return

        chunk_id = self.current_send_f_chunk_id[phase]
        self.current_send_f_chunk_id[phase] += 1
        tensors = self.output_chunks[phase][chunk_id]

        comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)

        if not self.return_outputs:
            # Freed in ``_free_tensors`` after the queued ops have completed.
            self.to_free.extend(tensors)

    def _recv_backward(self, phase: int) -> None:
        """Queue the receive of the next output gradient of ``phase`` (no-op on the last stage)."""
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
        if is_last_stage:
            return

        self.current_recv_b_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
        self.output_grad_chunks[phase].append(tensors)

    def _send_backward(self, phase: int) -> None:
        """Queue the send of the next input gradient of ``phase`` (no-op on the first stage)."""
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
        if is_first_stage:
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        """Launch all queued P2P ops as one batch, wait for them, then free sent tensors."""
        if not self.comm_ops:
            return
        reqs = dist.batch_isend_irecv(self.comm_ops)
        for req in reqs:
            req.wait()
        self.comm_ops = []
        self._free_tensors()

    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: List[Optional[torch.Tensor]] = [],
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        Arguments:
            *inputs: Module inputs. Required only on the first/last ranks.
            num_chunks: The number of micro-batches.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first/last ranks.
            labels: Labels of the loss function. Required only on the first/last ranks.
                labels on the first rank corresponds to inputs on the last rank.
                labels on the last rank corresponds to inputs on the first rank.
            return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``.

        Returns: (loss, outputs)
            loss: Loss for the batch.
                loss on the first rank corresponds to inputs on the last rank.
                loss on the last rank corresponds to inputs on the first rank.
                Otherwise: ``None``.
            outputs: Returned only if ``return_outputs=True``.
                outputs on the first rank corresponds to inputs on the last rank.
                outputs on the last rank corresponds to inputs on the first rank.
                Otherwise: ``None``.

        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
            "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before doing a step."
        # Inference mode is inferred from the ambient autograd state.
        self.forward_only = not torch.is_grad_enabled()
        self.return_outputs = return_outputs

        rank = self.rank
        num_ranks = self.num_ranks
        assert num_ranks % 2 == 0
        assert num_chunks > 0 and num_chunks % 2 == 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
        num_half_ranks = num_ranks // 2
        # Distance from the nearest end of the pipeline.
        half_rank = min(rank, num_ranks - 1 - rank)
        half_num_chunks = num_chunks // 2
        self.num_half_ranks = num_half_ranks
        self.half_rank = half_rank

        if not self.forward_only and (self.is_first_rank or self.is_last_rank):
            assert criterion is not None

        self._reset_states()

        # ``labels`` (default ``[]``) is only read here, never mutated.
        inputs = scatter(inputs, half_num_chunks, self.batch_dim)
        labels = scatter(labels, half_num_chunks, self.batch_dim)
        if self.is_first_rank:
            self.input_chunks = (inputs, [])
            self.labels = ([], labels)
        elif self.is_last_rank:
            self.input_chunks = ([], inputs)
            self.labels = (labels, [])
        self.criterion = criterion

        # For the first half of the ranks: phase 0 means forward direction, phase 1 means reverse direction.
        # For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction.

        # Step 1: nF0
        step_1 = (num_half_ranks - half_rank - 1) * 2
        for _ in range(step_1):
            self._forward_chunk(0)

        # Step 2: nF0F1
        step_2 = half_rank + 1
        self._recv_forward(0)
        for i in range(step_2):
            self._forward_chunk(0, recv=False, send=self.is_middle_rank)
            self._recv_forward(0)
            self._forward_chunk(1, send=(not self.is_middle_rank) or (i < step_2 - 1))
            if not self.is_middle_rank:
                self._send_forward(0)

        # Step 3: nB1W1F1 (Use zero bubble)
        step_3 = num_half_ranks - half_rank - 1
        for _ in range(step_3):
            self._backward_chunk(1, enable_zb=True)
            self._recv_forward(1)
            self._weight_chunk()
            self._forward_chunk(1, recv=False)

        # Step 4 (Main step): nF0B1F1B0
        step_4 = half_num_chunks - num_ranks + half_rank + 1
        for i in range(step_4):
            if i == 0:
                if self.is_middle_rank:
                    # NOTE: We don't overlap these two chunks to further reduce bubble size.
                    self._forward_chunk(0, recv=False, send=False)
                    self._send_forward(1)
                    self._backward_chunk(1, send=False)
                    self._send_forward(0)
                    self._send_backward(1)
                else:
                    self._forward_backward_chunk(0, 1, recv0=False)
            else:
                self._forward_backward_chunk(0, 1)
            self._forward_backward_chunk(1, 0)

        # Step 5: nB1F1B0
        step_5 = num_half_ranks - half_rank - 1
        for _ in range(step_5):
            self._backward_chunk(1)
            self._forward_backward_chunk(1, 0)

        # Step 6: nB1B0 (The second half of the chunks use zero bubble)
        step_6 = half_rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and half_rank % 2 == 1:
                enable_zb = True
            self._backward_chunk(1, enable_zb=enable_zb)
            if i == step_6 // 2 and half_rank % 2 == 0:
                enable_zb = True
            self._backward_chunk(0, enable_zb=enable_zb)

        # Step 7: nWB0 (Use zero bubble)
        step_7 = num_half_ranks - half_rank - 1
        for _ in range(step_7):
            self._weight_chunk()
            self._backward_chunk(0, enable_zb=True)

        # Step 8: nW
        step_8 = half_rank + 1
        for _ in range(step_8):
            self._weight_chunk()
        assert WeightGradStore.funcs_queue.empty()

        self._commit_and_wait_comm()

        loss, outputs = None, None
        if self.is_first_rank or self.is_last_rank:
            if criterion is not None:
                loss = torch.stack(self.loss_chunks)
            if return_outputs:
                # The first rank is the last stage of phase 1, the last rank of phase 0.
                outputs = gather(self.output_chunks[self.is_first_rank], self.batch_dim)
                if len(outputs) == 1:
                    outputs = outputs[0]

        self._reset_states()

        return loss, outputs
|
80
dualpipe/utils.py
Normal file
80
dualpipe/utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
import queue
|
||||
from typing import List, Callable
|
||||
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class WeightGradStore:
    """Class-level store for deferred weight-gradient computations.

    Callables are collected in ``cache`` via ``put``, moved as one batch into
    the FIFO ``funcs_queue`` by ``flush``, and executed batch by batch with
    ``pop``. All state lives on the class itself; no instances are needed.
    """

    enabled: bool = False  # flag read by callers to decide whether to defer
    cache: List[Callable] = []  # callables collected since the last flush
    funcs_queue = queue.Queue()  # FIFO of flushed batches

    @classmethod
    def put(cls, func: Callable) -> None:
        """Add ``func`` to the batch currently being collected."""
        cls.cache.append(func)

    @classmethod
    def flush(cls) -> None:
        """Enqueue the collected batch and start a new, empty one."""
        pending = cls.cache
        cls.funcs_queue.put(pending)
        cls.cache = []

    @classmethod
    def pop(cls) -> None:
        """Dequeue the oldest batch and call every function in it, in order."""
        assert not cls.funcs_queue.empty(), "Pop empty queue."
        for deferred in cls.funcs_queue.get():
            deferred()

    @classmethod
    def clear(cls) -> None:
        """Drop the current batch and every queued batch."""
        cls.cache = []
        cls.funcs_queue = queue.Queue()
|
||||
|
||||
|
||||
def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None:
    """Backpropagate ``grad_tensors`` through ``tensors`` via the autograd engine.

    Calls the engine directly, without retaining or creating a graph, and
    accumulates the results into the ``.grad`` of the leaves.
    """
    Variable._execution_engine.run_backward(
        tensors,
        grad_tensors,
        keep_graph=False,
        create_graph=False,
        allow_unreachable=True,
        accumulate_grad=True,
    )
|
||||
|
||||
|
||||
def chunk_tensor(x, chunks, dim):
    """Split ``x`` into ``chunks`` pieces along ``dim``.

    A ``None`` input yields a list of ``chunks`` ``None`` placeholders;
    otherwise the result of ``x.tensor_split`` is returned.
    """
    if x is not None:
        return x.tensor_split(chunks, dim=dim)
    return [None] * chunks
|
||||
|
||||
|
||||
def cat_tensor(x, dim):
    """Concatenate a sequence of tensors along ``dim``.

    For a tuple or list: a single-element sequence returns that element
    unchanged, and a sequence starting with ``None`` must be all ``None`` and
    returns ``None``. Everything else is passed to ``torch.cat``.
    """
    if isinstance(x, (tuple, list)):
        if len(x) == 1:
            return x[0]
        if x[0] is None:
            assert all(item is None for item in x)
            return None
    return torch.cat(x, dim=dim)
|
||||
|
||||
|
||||
def scatter(inputs, chunks, dim):
    """Split every input into ``chunks`` pieces and regroup them per micro-batch.

    Arguments:
        inputs: A tensor, or a tuple/list whose items are tensors or ``None``.
        chunks: Number of micro-batches.
        dim: Dimension along which each tensor is split.

    Returns:
        A list of per-micro-batch tuples. When ``inputs`` is empty, a list of
        ``chunks`` empty tuples is returned instead.
    """
    assert isinstance(inputs, (torch.Tensor, tuple, list))
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)
    assert all(x is None or isinstance(x, torch.Tensor) for x in inputs)
    per_input_chunks = [chunk_tensor(x, chunks, dim) for x in inputs]
    # Transpose: from "chunks of each input" to "inputs of each chunk".
    microbatches = list(zip(*per_input_chunks))
    if not microbatches:
        microbatches = [() for _ in range(chunks)]
    return microbatches
|
||||
|
||||
|
||||
def gather(micro_outputs, dim):
    """Concatenate per-micro-batch outputs back into full-batch outputs.

    Arguments:
        micro_outputs: One entry per micro-batch; each entry is a tensor or a
            tuple/list of outputs.
        dim: Dimension along which matching outputs are concatenated.

    Returns:
        A tuple with one concatenated value per output position.
    """
    assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list))
    if isinstance(micro_outputs[0], torch.Tensor):
        micro_outputs = [(x,) for x in micro_outputs]
    # Transpose: from "outputs of each micro-batch" to "micro-batches of each output".
    per_output = zip(*micro_outputs)
    return tuple(cat_tensor(group, dim=dim) for group in per_output)
|
202
example.py
Normal file
202
example.py
Normal file
@ -0,0 +1,202 @@
|
||||
from typing import List, Optional, Callable, Tuple
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from dualpipe import DualPipe, set_p2p_tensor_shapes, set_p2p_tensor_dtype
|
||||
from dualpipe.utils import WeightGradStore, run_backward
|
||||
|
||||
|
||||
class LinearFunc(torch.autograd.Function):
    """Bias-free linear op whose weight gradient can be deferred.

    When ``WeightGradStore.enabled`` is set during backward, the weight
    gradient computation is handed to ``WeightGradStore`` instead of being run
    immediately; the input gradient is always computed right away.
    """

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return F.linear(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, weight = ctx.saved_tensors
        if weight.grad is None:
            # The gradient is accumulated manually below, so it must exist.
            weight.grad = torch.zeros_like(weight)

        def accumulate_weight_grad():
            # Collapse all leading dims so both operands are 2-D matrices.
            weight.grad += grad_output.flatten(0, -2).T @ saved_input.flatten(0, -2)

        if not WeightGradStore.enabled:
            accumulate_weight_grad()
        else:
            WeightGradStore.put(accumulate_weight_grad)

        grad_input = grad_output @ weight
        # ``None`` for the weight: its gradient is written to ``weight.grad`` directly.
        return grad_input, None
|
||||
|
||||
|
||||
class MyLinear(nn.Linear):
    """``nn.Linear`` whose matmul goes through ``LinearFunc``.

    The weight gradient is produced by ``LinearFunc`` (and may therefore be
    deferred through ``WeightGradStore``). If the layer was built with a bias
    it is added on top of the matmul and handled by regular autograd; with
    ``bias=False`` the result is exactly ``LinearFunc.apply(input, weight)``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = LinearFunc.apply(input, self.weight)
        # ``nn.Linear`` creates a bias by default; do not silently ignore it.
        if self.bias is not None:
            output = output + self.bias
        return output
|
||||
|
||||
|
||||
class PipelineStage(nn.Module):
    """Example pipeline stage: linear -> GELU -> linear, both without bias."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        expanded_size = hidden_size * 4
        self.linear1 = MyLinear(hidden_size, expanded_size, bias=False)
        self.linear2 = MyLinear(expanded_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(F.gelu(self.linear1(x)))

    @classmethod
    def overlaped_forward_backward(
        cls,
        module0: "PipelineStage",
        inputs0: List[torch.Tensor],
        criterion0: Optional[Callable],
        labels0: Optional[List[torch.Tensor]],
        module1: "PipelineStage",
        loss1: Optional[torch.Tensor],
        outputs1: Optional[List[torch.Tensor]],
        output_grads1: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        You should implement custom forward-backward overlap strategy.
        The code below is just an example: it runs the forward of ``module0``
        first and the backward of the other chunk afterwards, sequentially.
        """
        # Forward part.
        outputs0 = module0(*inputs0)
        if isinstance(outputs0, torch.Tensor):
            outputs0 = [outputs0]
        loss0 = criterion0(*outputs0, *labels0) if criterion0 is not None else None

        # Backward part: start from the loss when one is given, otherwise
        # from the provided outputs and their gradients.
        if loss1 is None:
            run_backward(outputs1, output_grads1)
        else:
            loss1.backward()
            loss1.detach_()

        return outputs0, loss0
|
||||
|
||||
|
||||
def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return a cloned mean-squared-error loss between ``output`` and ``target``."""
    mse = F.mse_loss(output, target)
    return mse.clone()
|
||||
|
||||
|
||||
def ref_step(x, l, model, chunks):
    """Reference step: run ``model`` micro-batch by micro-batch.

    Each micro-batch goes through forward, ``criterion`` and backward, so
    gradients accumulate on ``model`` across micro-batches.

    Returns:
        ``(loss, y)``: the stacked per-micro-batch losses and the outputs
        concatenated along dimension 0.
    """
    micro_outputs = []
    micro_losses = []
    for x_part, l_part in zip(x.chunk(chunks), l.chunk(chunks)):
        y_part = model(x_part)
        part_loss = criterion(y_part, l_part)
        part_loss.backward()
        micro_outputs.append(y_part)
        micro_losses.append(part_loss)
    return torch.stack(micro_losses), torch.cat(micro_outputs, 0)
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The value is 0 when ``x`` equals ``y`` and grows as they diverge.
    """
    x64 = x.double()
    y64 = y.double()
    cross = (x64 * y64).sum().item()
    energy = (x64 * x64 + y64 * y64).sum().item()
    return 1 - 2 * cross / energy
|
||||
|
||||
|
||||
def main(rank, pp_size):
    """Per-process entry point: set up the process group and run the checks.

    Arguments:
        rank: Index of this process; also used as the CUDA device index.
        pp_size: Number of pipeline ranks (world size).
    """
    # ``env://`` rendezvous reads MASTER_ADDR / MASTER_PORT from the
    # environment. Provide single-node defaults so a plain
    # ``python example.py`` works; values set by the caller are kept.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
    try:
        torch.cuda.set_device(rank)
        torch.set_default_device(f"cuda:{rank}")
        torch.manual_seed(233)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        _run_dualpipe_checks(rank, pp_size)
    finally:
        # Always tear the process group down, even if a check fails.
        dist.destroy_process_group()


def _run_dualpipe_checks(rank, pp_size):
    """Compare one DualPipe training and inference step against a reference run."""
    is_first_rank = rank == 0
    is_last_rank = rank == pp_size - 1

    num_chunks = 20
    micro_batch_size = 3
    seq_len = 256
    hidden_size = 512
    if is_first_rank:
        print(f"{pp_size=}, {num_chunks=}, {seq_len=}, {hidden_size=}", flush=True)
    set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_size)])
    set_p2p_tensor_dtype(torch.float32)

    # Create a model and partition it for each process
    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)])

    # Full inputs
    full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
    full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)

    # Reference step
    loss_ref, output_ref = ref_step(full_x, full_l, full_modules, num_chunks)

    # DualPipe: each rank holds stage ``rank`` and the mirrored stage
    # ``pp_size - 1 - rank``.
    local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size - 1 - rank])
    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
    local_modules[0].load_state_dict(local_full_modules[0].state_dict())
    local_modules[1].load_state_dict(local_full_modules[1].state_dict())
    dualpipe_model = DualPipe(local_modules)

    # DualPipe inputs: the first rank feeds the first half of the batch and
    # holds the labels of the second half; the last rank does the opposite.
    if is_first_rank:
        x = full_x.chunk(2)[0]
        l = full_l.chunk(2)[1]
    elif is_last_rank:
        x = full_x.chunk(2)[1]
        l = full_l.chunk(2)[0]
    else:
        x = None
        l = None

    # Training step
    loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=False)

    # Check loss
    if is_first_rank:
        assert torch.equal(loss, loss_ref.chunk(2)[1])
    elif is_last_rank:
        assert torch.equal(loss, loss_ref.chunk(2)[0])
    else:
        assert loss is None
    assert outputs is None

    # Check grads: every stage lives on two ranks (``rank`` and
    # ``pp_size - 1 - rank``), so add the mirrored copy's gradient before
    # comparing against the reference model.
    for (p0, p1) in zip(local_modules[0].parameters(), local_modules[1].parameters()):
        p0all = torch.empty(pp_size, *p0.shape)
        p1all = torch.empty(pp_size, *p1.shape)
        dist.all_gather_into_tensor(p0all, p0.grad)
        dist.all_gather_into_tensor(p1all, p1.grad)
        p0.grad += p1all[pp_size - 1 - rank]
        p1.grad += p0all[pp_size - 1 - rank]
    for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
        assert cal_diff(p.grad, p_ref.grad) < 1e-13
    dualpipe_model.zero_grad()

    # Inference step
    with torch.no_grad():
        loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=True)

    # Check loss and outputs
    if is_first_rank:
        assert torch.equal(loss, loss_ref.chunk(2)[1])
        assert torch.equal(outputs, output_ref.chunk(2)[1])
    elif is_last_rank:
        assert torch.equal(loss, loss_ref.chunk(2)[0])
        assert torch.equal(outputs, output_ref.chunk(2)[0])
    else:
        assert loss is None
        assert outputs is None
|
||||
|
||||
|
||||
def test_dualpipe(ngpus):
    """Spawn ``ngpus`` daemon processes, each running ``main(rank, ngpus)``."""
    torch.multiprocessing.spawn(
        main,
        args=(ngpus, ),
        nprocs=ngpus,
        daemon=True,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Round the visible GPU count down to an even number, then run the test
    # for every even pipeline size from that count down to 2.
    num_gpus = torch.cuda.device_count() // 2 * 2
    for ngpus in range(num_gpus, 0, -2):
        test_dualpipe(ngpus)
|
BIN
images/schedules.png
Normal file
BIN
images/schedules.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 80 KiB |
18
setup.py
Normal file
18
setup.py
Normal file
@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
import subprocess
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
# Build the local version suffix: the short git revision when available,
# otherwise a timestamp (e.g. when building outside a git checkout or
# without git installed).
try:
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except Exception:
    now = datetime.now()
    date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
    rev = '+' + date_time_str

setup(
    name="dualpipe",
    version="1.0.0" + rev,
    # Declare the package explicitly instead of relying on setuptools
    # auto-discovery, so the ``dualpipe`` package is always what gets installed.
    packages=["dualpipe"],
)
|
Loading…
Reference in New Issue
Block a user