Initial commit

ljss 2025-02-20 16:36:16 +08:00
commit fbe0ac0d6e
10 changed files with 876 additions and 0 deletions

.gitignore (vendored, new file, 4 lines added)

@@ -0,0 +1,4 @@
build
*.egg-info/
__pycache__/
dist/

LICENSE (new file, 21 lines added)

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file, 56 lines added)

@@ -0,0 +1,56 @@
# DualPipe
DualPipe is an innovative bidirectional pipeline parallelism algorithm introduced in the [DeepSeek-V3 Technical Report](https://arxiv.org/pdf/2412.19437). It achieves full overlap of forward and backward computation-communication phases while also reducing pipeline bubbles. For detailed information on computation-communication overlap, please refer to the [profile data](https://github.com/deepseek-ai/profile-data).
## Schedules
![schedules](images/schedules.png)
Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions.
The micro-batches in the reverse direction are symmetric to those in the forward direction, so
we omit their batch ID for illustration simplicity. Two cells enclosed by a shared black border
have mutually overlapped computation and communication.
## Pipeline Bubbles and Memory Usage Comparison
| Method | Bubble | Parameter | Activation |
|-------------|---------------------------------|-----------|------------|
| 1F1B        | (PP-1)(𝐹+𝐵)                     | 1×        | PP         |
| ZB1P | (PP-1)(𝐹+𝐵-2𝑊) | 1× | PP |
| DualPipe | (PP/2-1)(𝐹&𝐵+𝐵-3𝑊) | 2× | PP+1 |

𝐹 denotes the execution time of a forward chunk, 𝐵 denotes the execution time of a
full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵
denotes the execution time of two mutually overlapped forward and backward chunks.
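For example, substituting the 8 PP ranks of the schedule above into these formulas gives bubbles of 7(𝐹+𝐵) for 1F1B, 7(𝐹+𝐵-2𝑊) for ZB1P, and 3(𝐹&𝐵+𝐵-3𝑊) for DualPipe.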
## Quick Start
The usage is shown in the following example:
```bash
python example.py
```
Note: For real-world applications, you will need to implement a custom `overlapped_forward_backward` method tailored to your specific module.
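A minimal sketch of that hook is shown below; it mirrors the reference implementation in `example.py` from this commit, and the `MyStage` class name is illustrative:
```python
import torch
import torch.nn as nn

from dualpipe.utils import run_backward


class MyStage(nn.Module):
    # Illustrative stage: only the classmethod below and its signature matter to DualPipe.
    @classmethod
    def overlapped_forward_backward(cls, module0, inputs0, criterion0, labels0,
                                    module1, loss1, outputs1, output_grads1):
        # Run module0's forward together with module1's backward; a real implementation
        # would interleave them so communication of one hides behind computation of the other.
        outputs0 = module0(*inputs0)
        outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
        loss0 = criterion0(*outputs0, *labels0) if criterion0 is not None else None
        if loss1 is not None:   # last pipeline stage: backpropagate from the stored loss
            loss1.backward()
        else:                   # other stages: backpropagate from the received output grads
            run_backward(outputs1, output_grads1)
        return outputs0, loss0
```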
## Requirements
- PyTorch 2.0 and above
## Developers
DualPipe was created and developed by Jiashi Li, Chengqi Deng, and Wenfeng Liang.
## Citation
```bibtex
@misc{deepseekai2024deepseekv3technicalreport,
title={DeepSeek-V3 Technical Report},
author={DeepSeek-AI},
year={2024},
eprint={2412.19437},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.19437},
}
```

dualpipe/__init__.py (new file, 17 lines added)

@@ -0,0 +1,17 @@
__version__ = "1.0.0"
from dualpipe.dualpipe import (
DualPipe,
WeightGradStore,
)
from dualpipe.comm import (
set_p2p_tensor_shapes,
set_p2p_tensor_dtype,
)
__all__ = [
"DualPipe",
"WeightGradStore",
"set_p2p_tensor_shapes",
"set_p2p_tensor_dtype",
]

dualpipe/comm.py (new file, 38 lines added)

@@ -0,0 +1,38 @@
from typing import List, Tuple
import torch
import torch.distributed as dist
TENSOR_SHAPES: List[Tuple[int]] = None
TENSOR_DTYPE: torch.dtype = None
def set_p2p_tensor_shapes(shapes: List[Tuple[int]]):
global TENSOR_SHAPES
TENSOR_SHAPES = shapes
def set_p2p_tensor_dtype(dtype: torch.dtype):
global TENSOR_DTYPE
TENSOR_DTYPE = dtype
def build_from_tensor_shapes():
return [torch.empty(s, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True) for s in TENSOR_SHAPES]
def append_irecv(ops: List[dist.P2POp], src: int, group: dist.ProcessGroup) -> List[torch.Tensor]:
tensors = build_from_tensor_shapes()
src = dist.distributed_c10d.get_global_rank(group, src)
for tensor in tensors:
if tensor is not None:
ops.append(dist.P2POp(dist.irecv, tensor, src))
return tensors
def append_isend(ops: List[dist.P2POp], tensors: List[torch.Tensor], dst: int, group: dist.ProcessGroup) -> None:
dst = dist.distributed_c10d.get_global_rank(group, dst)
for tensor in tensors:
if tensor is not None:
ops.append(dist.P2POp(dist.isend, tensor, dst))

dualpipe/dualpipe.py (new file, 440 lines added)

@@ -0,0 +1,440 @@
from typing import Tuple, List, Union, Callable, Optional
import torch
import torch.nn as nn
import torch.distributed as dist
import dualpipe.comm as comm
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather
class DualPipe(nn.Module):
def __init__(
self,
modules: Tuple[nn.Module, nn.Module],
batch_dim: int = 0,
process_group: Optional[dist.ProcessGroup] = None,
rank_mapping: Optional[List[int]] = None,
) -> None:
super().__init__()
assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
self.module = nn.ModuleList(modules)
self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
self.batch_dim = batch_dim
self.group = process_group or dist.distributed_c10d._get_default_group()
self.num_ranks = self.group.size()
# rank_mapping: Map rank in process_group to actual pp rank.
# rank_inverse_mapping: Map actual pp rank to rank in process_group.
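# Example (illustrative): with rank_mapping=[3, 2, 1, 0], the process at group rank 0 acts as pp rank 3.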
if rank_mapping is None:
rank_mapping = list(range(self.num_ranks))
rank_inverse_mapping = [None] * (self.num_ranks + 1)
for i in range(self.num_ranks):
rank_inverse_mapping[rank_mapping[i]] = i
self.rank = rank_mapping[self.group.rank()]
self.first_rank = rank_inverse_mapping[0]
self.prev_rank = rank_inverse_mapping[self.rank - 1]
self.next_rank = rank_inverse_mapping[self.rank + 1]
self.last_rank = rank_inverse_mapping[self.num_ranks - 1]
self.is_first_rank = self.rank == 0
self.is_last_rank = self.rank == self.num_ranks - 1
self.is_in_second_half = self.rank >= self.num_ranks // 2
self.is_middle_rank = (self.rank == self.num_ranks // 2 - 1) or (self.rank == self.num_ranks // 2)
def _reset_states(self) -> None:
WeightGradStore.clear()
self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None
self.loss_chunks: List[torch.Tensor] = []
self.criterion: Callable = None
self.current_f_chunk_id: List[int] = [0, 0]
self.current_b_chunk_id: List[int] = [0, 0]
self.current_send_f_chunk_id: List[int] = [0, 0]
self.current_send_b_chunk_id: List[int] = [0, 0]
self.current_recv_f_chunk_id: List[int] = [0, 0]
self.current_recv_b_chunk_id: List[int] = [0, 0]
self.comm_ops: List[dist.P2POp] = []
self.to_free: List[torch.Tensor] = []
def _forward_compute_chunk(self, phase: int) -> None:
phase ^= self.is_in_second_half
chunk_id = self.current_f_chunk_id[phase]
self.current_f_chunk_id[phase] += 1
inputs = self.input_chunks[phase][chunk_id]
if self.forward_only:
self.input_chunks[phase][chunk_id] = None
is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
outputs = self.module[phase](*inputs)
outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
if is_last_stage and self.criterion is not None:
labels = self.labels[phase][chunk_id]
loss = self.criterion(*outputs, *labels)
self.loss_chunks.append(loss)
if (not is_last_stage) or self.return_outputs:
self.output_chunks[phase].append(outputs)
def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
if self.forward_only:
return
phase ^= self.is_in_second_half
chunk_id = self.current_b_chunk_id[phase]
self.current_b_chunk_id[phase] += 1
is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
WeightGradStore.enabled = enable_zb
if is_last_stage:
loss = self.loss_chunks[chunk_id]
loss.backward()
loss.detach_()
else:
outputs = self.output_chunks[phase][chunk_id]
if not self.return_outputs:
self.output_chunks[phase][chunk_id] = None
output_grads = self.output_grad_chunks[phase][chunk_id]
self.output_grad_chunks[phase][chunk_id] = None
non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
outputs, output_grads = list(zip(*non_empty))
if len(outputs) > 0:
run_backward(outputs, output_grads)
WeightGradStore.enabled = False
if enable_zb:
WeightGradStore.flush()
inputs = self.input_chunks[phase][chunk_id]
self.input_chunks[phase][chunk_id] = None
input_grads = [t.grad for t in inputs]
self.input_grad_chunks[phase].append(input_grads)
def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
if self.forward_only:
self._forward_compute_chunk(phase0)
return
if not self.overlapped_forward_backward:
self._forward_compute_chunk(phase0)
self._backward_compute_chunk(phase1)
return
# pre-forward
phase0 ^= self.is_in_second_half
chunk_id0 = self.current_f_chunk_id[phase0]
self.current_f_chunk_id[phase0] += 1
module0 = self.module[phase0]
inputs0 = self.input_chunks[phase0][chunk_id0]
is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0)
if is_last_stage0 and self.criterion is not None:
labels0 = self.labels[phase0][chunk_id0]
criterion0 = self.criterion
else:
labels0 = []
criterion0 = None
# pre-backward
phase1 ^= self.is_in_second_half
chunk_id1 = self.current_b_chunk_id[phase1]
self.current_b_chunk_id[phase1] += 1
module1 = self.module[phase1]
is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0)
if is_last_stage1:
loss1 = self.loss_chunks[chunk_id1]
outputs1 = []
output_grads1 = []
else:
loss1 = None
outputs1 = self.output_chunks[phase1][chunk_id1]
if not self.return_outputs:
self.output_chunks[phase1][chunk_id1] = None
output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
self.output_grad_chunks[phase1][chunk_id1] = None
non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
outputs1, output_grads1 = list(zip(*non_empty))
# forward & backward
outputs0, loss0 = type(module0).overlapped_forward_backward(
module0, inputs0, criterion0, labels0,
module1, loss1, outputs1, output_grads1,
)
# post-forward
if (not is_last_stage0) or self.return_outputs:
self.output_chunks[phase0].append(outputs0)
if is_last_stage0 and self.criterion is not None:
self.loss_chunks.append(loss0)
# post-backward
inputs = self.input_chunks[phase1][chunk_id1]
self.input_chunks[phase1][chunk_id1] = None
input_grads1 = [t.grad for t in inputs]
self.input_grad_chunks[phase1].append(input_grads1)
def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
if recv:
self._recv_forward(phase)
self._commit_and_wait_comm()
self._forward_compute_chunk(phase)
if send:
self._send_forward(phase)
def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
if recv:
self._recv_backward(phase)
self._commit_and_wait_comm()
self._backward_compute_chunk(phase, enable_zb)
if send:
self._send_backward(phase)
def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
if recv0:
self._recv_forward(phase0)
self._recv_backward(phase1)
self._commit_and_wait_comm()
self._forward_backward_compute_chunk(phase0, phase1)
self._send_forward(phase0)
self._send_backward(phase1)
def _weight_chunk(self) -> None:
if self.forward_only:
return
self._commit_and_wait_comm()
# Assume FIFO
WeightGradStore.pop()
def _free_tensors(self) -> None:
for tensor in self.to_free:
assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
tensor.data = torch.Tensor()
self.to_free = []
def _recv_forward(self, phase: int) -> None:
phase ^= self.is_in_second_half
is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
if is_first_stage:
return
self.current_recv_f_chunk_id[phase] += 1
tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
self.input_chunks[phase].append(tensors)
def _send_forward(self, phase: int) -> None:
phase ^= self.is_in_second_half
is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
if is_last_stage:
return
chunk_id = self.current_send_f_chunk_id[phase]
self.current_send_f_chunk_id[phase] += 1
tensors = self.output_chunks[phase][chunk_id]
comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)
if not self.return_outputs:
self.to_free.extend(tensors)
def _recv_backward(self, phase: int) -> None:
if self.forward_only:
return
phase ^= self.is_in_second_half
is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
if is_last_stage:
return
self.current_recv_b_chunk_id[phase] += 1
tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
self.output_grad_chunks[phase].append(tensors)
def _send_backward(self, phase: int) -> None:
if self.forward_only:
return
phase ^= self.is_in_second_half
is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
if is_first_stage:
return
chunk_id = self.current_send_b_chunk_id[phase]
self.current_send_b_chunk_id[phase] += 1
tensors = self.input_grad_chunks[phase][chunk_id]
self.input_grad_chunks[phase][chunk_id] = None
comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)
def _commit_and_wait_comm(self) -> None:
if not self.comm_ops:
return
reqs = dist.batch_isend_irecv(self.comm_ops)
for req in reqs:
req.wait()
self.comm_ops = []
self._free_tensors()
def step(
self,
*inputs: Optional[torch.Tensor],
num_chunks: int = 0,
criterion: Optional[Callable] = None,
labels: List[Optional[torch.Tensor]] = [],
return_outputs: bool = False,
) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
"""
Execute a training or inference step.
Arguments:
*inputs: Module inputs. Required only on the first/last ranks.
num_chunks: The number of micro-batches.
criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first/last ranks.
labels: Labels of the loss function. Required only on the first/last ranks.
labels on the first rank correspond to inputs on the last rank.
labels on the last rank correspond to inputs on the first rank.
return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``.
Returns: (loss, outputs)
loss: Loss for the batch.
loss on the first rank corresponds to inputs on the last rank.
loss on the last rank corresponds to inputs on the first rank.
``None`` on other ranks.
outputs: Returned only if ``return_outputs=True``.
outputs on the first rank correspond to inputs on the last rank.
outputs on the last rank correspond to inputs on the first rank.
``None`` on other ranks.
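Example (mirroring the call in ``example.py`` from this commit):
``loss, outputs = dualpipe_model.step(x, num_chunks=20, criterion=criterion, labels=(l,), return_outputs=False)``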
"""
assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
"You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before doing a step."
self.forward_only = not torch.is_grad_enabled()
self.return_outputs = return_outputs
rank = self.rank
num_ranks = self.num_ranks
assert num_ranks % 2 == 0
assert num_chunks > 0 and num_chunks % 2 == 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
num_half_ranks = num_ranks // 2
half_rank = min(rank, num_ranks - 1 - rank)
half_num_chunks = num_chunks // 2
self.num_half_ranks = num_half_ranks
self.half_rank = half_rank
if not self.forward_only and (self.is_first_rank or self.is_last_rank):
assert criterion is not None
self._reset_states()
inputs = scatter(inputs, half_num_chunks, self.batch_dim)
labels = scatter(labels, half_num_chunks, self.batch_dim)
if self.is_first_rank:
self.input_chunks = (inputs, [])
self.labels = ([], labels)
elif self.is_last_rank:
self.input_chunks = ([], inputs)
self.labels = (labels, [])
self.criterion = criterion
# For the first half of the ranks: phase 0 means forward direction, phase 1 means reverse direction.
# For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction.
# Step 1: nF0
step_1 = (num_half_ranks - half_rank - 1) * 2
for i in range(step_1):
self._forward_chunk(0)
# Step 2: nF0F1
step_2 = half_rank + 1
self._recv_forward(0)
for i in range(step_2):
self._forward_chunk(0, recv=False, send=self.is_middle_rank)
self._recv_forward(0)
self._forward_chunk(1, send=(not self.is_middle_rank) or (i < step_2 - 1))
if not self.is_middle_rank:
self._send_forward(0)
# Step 3: nB1W1F1 (Use zero bubble)
step_3 = num_half_ranks - half_rank - 1
for i in range(step_3):
self._backward_chunk(1, enable_zb=True)
self._recv_forward(1)
self._weight_chunk()
self._forward_chunk(1, recv=False)
# Step 4 (Main step): nF0B1F1B0
step_4 = half_num_chunks - num_ranks + half_rank + 1
for i in range(step_4):
if i == 0:
if self.is_middle_rank:
# NOTE: We don't overlap these two chunks to further reduce bubble size.
self._forward_chunk(0, recv=False, send=False)
self._send_forward(1)
self._backward_chunk(1, send=False)
self._send_forward(0)
self._send_backward(1)
else:
self._forward_backward_chunk(0, 1, recv0=False)
else:
self._forward_backward_chunk(0, 1)
self._forward_backward_chunk(1, 0)
# Step 5: nB1F1B0
step_5 = num_half_ranks - half_rank - 1
for i in range(step_5):
self._backward_chunk(1)
self._forward_backward_chunk(1, 0)
# Step 6: nB1B0 (The second half of the chunks use zero bubble)
step_6 = half_rank + 1
enable_zb = False
for i in range(step_6):
if i == step_6 // 2 and half_rank % 2 == 1:
enable_zb = True
self._backward_chunk(1, enable_zb=enable_zb)
if i == step_6 // 2 and half_rank % 2 == 0:
enable_zb = True
self._backward_chunk(0, enable_zb=enable_zb)
# Step 7: nWB0 (Use zero bubble)
step_7 = num_half_ranks - half_rank - 1
for i in range(step_7):
self._weight_chunk()
self._backward_chunk(0, enable_zb=True)
# Step 8: nW
step_8 = half_rank + 1
for i in range(step_8):
self._weight_chunk()
assert WeightGradStore.funcs_queue.empty()
self._commit_and_wait_comm()
loss, outputs = None, None
if self.is_first_rank or self.is_last_rank:
if criterion is not None:
loss = torch.stack(self.loss_chunks)
if return_outputs:
outputs = gather(self.output_chunks[self.is_first_rank], self.batch_dim)
if len(outputs) == 1:
outputs = outputs[0]
self._reset_states()
return loss, outputs

dualpipe/utils.py (new file, 80 lines added)

@@ -0,0 +1,80 @@
import queue
from typing import List, Callable
import torch
from torch.autograd import Variable
class WeightGradStore:
enabled: bool = False
cache: List[Callable] = []
funcs_queue = queue.Queue()
@classmethod
def put(cls, func: Callable) -> None:
cls.cache.append(func)
@classmethod
def flush(cls) -> None:
cls.funcs_queue.put(cls.cache)
cls.cache = []
@classmethod
def pop(cls) -> None:
assert not cls.funcs_queue.empty(), "Pop empty queue."
funcs = cls.funcs_queue.get()
for func in funcs:
func()
@classmethod
def clear(cls) -> None:
cls.cache = []
cls.funcs_queue = queue.Queue()
def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None:
kwargs = dict(
keep_graph=False,
create_graph=False,
allow_unreachable=True,
accumulate_grad=True,
)
Variable._execution_engine.run_backward(tensors, grad_tensors, **kwargs)
def chunk_tensor(x, chunks, dim):
if x is None:
return [None for _ in range(chunks)]
return x.tensor_split(chunks, dim=dim)
def cat_tensor(x, dim):
if (isinstance(x, tuple) or isinstance(x, list)):
if len(x) == 1:
return x[0]
elif x[0] is None:
assert all(y is None for y in x)
return None
return torch.cat(x, dim=dim)
def scatter(inputs, chunks, dim):
assert isinstance(inputs, (torch.Tensor, tuple, list))
if isinstance(inputs, torch.Tensor):
inputs = (inputs,)
assert all(x is None or isinstance(x, torch.Tensor) for x in inputs)
inputs = [chunk_tensor(x, chunks, dim) for x in inputs]
microbatches = [microbatch for microbatch in zip(*inputs)]
if len(microbatches) == 0:
microbatches = [() for _ in range(chunks)]
return microbatches
def gather(micro_outputs, dim):
assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list))
if isinstance(micro_outputs[0], torch.Tensor):
micro_outputs = [(x,) for x in micro_outputs]
outputs = [x for x in zip(*micro_outputs)]
outputs = tuple(cat_tensor(x, dim=dim) for x in outputs)
return outputs

example.py (new file, 202 lines added)

@@ -0,0 +1,202 @@
from typing import List, Optional, Callable, Tuple
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from dualpipe import DualPipe, set_p2p_tensor_shapes, set_p2p_tensor_dtype
from dualpipe.utils import WeightGradStore, run_backward
class LinearFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
output = F.linear(input, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
if weight.grad is None:
weight.grad = torch.zeros_like(weight)
def grad_weight_fn():
weight.grad += grad_output.flatten(0, -2).T @ input.flatten(0, -2)
if WeightGradStore.enabled:
WeightGradStore.put(grad_weight_fn)
else:
grad_weight_fn()
grad_input = grad_output @ weight
return grad_input, None
class MyLinear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return LinearFunc.apply(input, self.weight)
class PipelineStage(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.linear1 = MyLinear(hidden_size, hidden_size * 4, bias=False)
self.linear2 = MyLinear(hidden_size * 4, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = F.gelu(x)
x = self.linear2(x)
return x
@classmethod
def overlapped_forward_backward(
cls,
module0: "PipelineStage",
inputs0: List[torch.Tensor],
criterion0: Optional[Callable],
labels0: Optional[List[torch.Tensor]],
module1: "PipelineStage",
loss1: Optional[torch.Tensor],
outputs1: Optional[List[torch.Tensor]],
output_grads1: Optional[List[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
You should implement a custom forward-backward overlap strategy.
The code below is just an example.
"""
outputs0 = module0(*inputs0)
outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
if criterion0 is not None:
loss0 = criterion0(*outputs0, *labels0)
else:
loss0 = None
if loss1 is not None:
loss1.backward()
loss1.detach_()
else:
run_backward(outputs1, output_grads1)
return outputs0, loss0
def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return F.mse_loss(output, target).clone()
def ref_step(x, l, model, chunks):
ys, losses = [], []
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
micro_y = model(micro_x)
loss = criterion(micro_y, micro_l)
loss.backward()
ys.append(micro_y)
losses.append(loss)
y = torch.cat(ys, 0)
loss = torch.stack(losses)
return loss, y
def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / (x * x + y * y).sum().item()
return cos_diff
def main(rank, pp_size):
is_first_rank = rank == 0
is_last_rank = rank == pp_size - 1
dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
torch.cuda.set_device(rank)
torch.set_default_device(f"cuda:{rank}")
torch.manual_seed(233)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
num_chunks = 20
micro_batch_size = 3
seq_len = 256
hidden_size = 512
if is_first_rank:
print(f"{pp_size=}, {num_chunks=}, {seq_len=}, {hidden_size=}", flush=True)
set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_size)])
set_p2p_tensor_dtype(torch.float32)
# Create a model and partition it for each process
full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)])
# Full inputs
full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
# Reference step
loss_ref, output_ref = ref_step(full_x, full_l, full_modules, num_chunks)
# DualPipe
local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size - 1 - rank])
local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
local_modules[0].load_state_dict(local_full_modules[0].state_dict())
local_modules[1].load_state_dict(local_full_modules[1].state_dict())
dualpipe_model = DualPipe(local_modules)
# DualPipe inputs
if is_first_rank:
x = full_x.chunk(2)[0]
l = full_l.chunk(2)[1]
elif is_last_rank:
x = full_x.chunk(2)[1]
l = full_l.chunk(2)[0]
else:
x = None
l = None
# Training step
loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=False)
# Check loss
if is_first_rank:
assert torch.equal(loss, loss_ref.chunk(2)[1])
elif is_last_rank:
assert torch.equal(loss, loss_ref.chunk(2)[0])
else:
assert loss is None
assert outputs is None
# Check grads
for (p0, p1) in zip(local_modules[0].parameters(), local_modules[1].parameters()):
p0all = torch.empty(pp_size, *p0.shape)
p1all = torch.empty(pp_size, *p1.shape)
dist.all_gather_into_tensor(p0all, p0.grad)
dist.all_gather_into_tensor(p1all, p1.grad)
p0.grad += p1all[pp_size - 1 - rank]
p1.grad += p0all[pp_size - 1 - rank]
for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
assert cal_diff(p.grad, p_ref.grad) < 1e-13
dualpipe_model.zero_grad()
# Inference step
with torch.no_grad():
loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=True)
# Check loss and outputs
if is_first_rank:
assert torch.equal(loss, loss_ref.chunk(2)[1])
assert torch.equal(outputs, output_ref.chunk(2)[1])
elif is_last_rank:
assert torch.equal(loss, loss_ref.chunk(2)[0])
assert torch.equal(outputs, output_ref.chunk(2)[0])
else:
assert loss is None
assert outputs is None
def test_dualpipe(ngpus):
torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True)
if __name__ == "__main__":
num_gpus = torch.cuda.device_count() // 2 * 2
for ngpus in range(num_gpus, 0, -2):
test_dualpipe(ngpus)

images/schedules.png (new binary file, 80 KiB; not shown)

setup.py (new file, 18 lines added)

@@ -0,0 +1,18 @@
from datetime import datetime
import subprocess
from setuptools import setup
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except Exception as _:
now = datetime.now()
date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
rev = '+' + date_time_str
setup(
name="dualpipe",
version="1.0.0" + rev,
)