commit fbe0ac0d6eedee65495d704aaf1a1b3a95e51dea Author: ljss <450993438@qq.com> Date: Thu Feb 20 16:36:16 2025 +0800 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9a2c0d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +build +*.egg-info/ +__pycache__/ +dist/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5c48bdc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..1b5f0f9 --- /dev/null +++ b/README.md @@ -0,0 +1,56 @@ +# DualPipe + +DualPipe is an innovative bidirectional pipeline parallism algorithm introduced in the [DeepSeek-V3 Technical Report](https://arxiv.org/pdf/2412.19437). It achieves full overlap of forward and backward computation-communication phases, also reducing pipeline bubbles. For detailed information on computation-communication overlap, please refer to the [profile data](https://github.com/deepseek-ai/profile-data). + +### Schedules + +![schedules](images/schedules.png) + +Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions. +The micro-batches in the reverse direction are symmetric to those in the forward direction, so +we omit their batch ID for illustration simplicity. Two cells enclosed by a shared black border +have mutually overlapped computation and communication + +### Pipeline Bubbles and Memory Usage Comparison + +| Method | Bubble | Parameter | Activation | +|-------------|---------------------------------|-----------|------------| +| 1F1B | (PP-1)(𝐹+𝑊) | 1× | PP | +| ZB1P | (PP-1)(𝐹+𝐵-2𝑊) | 1× | PP | +| DualPipe | (PP/2-1)(𝐹&𝐵+𝐵-3𝑊) | 2× | PP+1 | + +𝐹 denotes the execution time of a forward chunk, 𝐵 denotes the execution time of a +full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵 +denotes the execution time of two mutually overlapped forward and backward chunks. + +## Quick Start + +The usage is shown in the following example: + +```bash +python example.py +``` + +Note: For real-world applications, you will need to implement a custom `overlapped_forward_backward` method tailored to your specific module. + +## Requirements + +- PyTorch 2.0 and above + +## Developers + +DualPipe was created and developed by Jiashi Li and Chengqi Deng and Wenfeng Liang. + +## Citation + +```bibtex +@misc{deepseekai2024deepseekv3technicalreport, + title={DeepSeek-V3 Technical Report}, + author={DeepSeek-AI}, + year={2024}, + eprint={2412.19437}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2412.19437}, +} +``` diff --git a/dualpipe/__init__.py b/dualpipe/__init__.py new file mode 100644 index 0000000..5359ba3 --- /dev/null +++ b/dualpipe/__init__.py @@ -0,0 +1,17 @@ +__version__ = "1.0.0" + +from dualpipe.dualpipe import ( + DualPipe, + WeightGradStore, +) +from dualpipe.comm import ( + set_p2p_tensor_shapes, + set_p2p_tensor_dtype, +) + +__all__ = [ + DualPipe, + WeightGradStore, + set_p2p_tensor_shapes, + set_p2p_tensor_dtype, +] diff --git a/dualpipe/comm.py b/dualpipe/comm.py new file mode 100644 index 0000000..e779a77 --- /dev/null +++ b/dualpipe/comm.py @@ -0,0 +1,38 @@ +from typing import List, Tuple + +import torch +import torch.distributed as dist + + +TENSOR_SHAPES: List[Tuple[int]] = None +TENSOR_DTYPE: torch.dtype = None + + +def set_p2p_tensor_shapes(shapes: List[Tuple[int]]): + global TENSOR_SHAPES + TENSOR_SHAPES = shapes + + +def set_p2p_tensor_dtype(dtype: torch.dtype): + global TENSOR_DTYPE + TENSOR_DTYPE = dtype + + +def build_from_tensor_shapes(): + return [torch.empty(s, dtype=TENSOR_DTYPE, device="cuda", requires_grad=True) for s in TENSOR_SHAPES] + + +def append_irecv(ops: List[dist.P2POp], src: int, group: dist.ProcessGroup) -> List[torch.Tensor]: + tensors = build_from_tensor_shapes() + src = dist.distributed_c10d.get_global_rank(group, src) + for tensor in tensors: + if tensor is not None: + ops.append(dist.P2POp(dist.irecv, tensor, src)) + return tensors + + +def append_isend(ops: List[dist.P2POp], tensors: List[torch.Tensor], dst: int, group: dist.ProcessGroup) -> None: + dst = dist.distributed_c10d.get_global_rank(group, dst) + for tensor in tensors: + if tensor is not None: + ops.append(dist.P2POp(dist.isend, tensor, dst)) diff --git a/dualpipe/dualpipe.py b/dualpipe/dualpipe.py new file mode 100644 index 0000000..099f84f --- /dev/null +++ b/dualpipe/dualpipe.py @@ -0,0 +1,440 @@ +from typing import Tuple, List, Union, Callable, Optional + +import torch +import torch.nn as nn +import torch.distributed as dist + +import dualpipe.comm as comm +from dualpipe.utils import WeightGradStore, run_backward, scatter, gather + + +class DualPipe(nn.Module): + def __init__( + self, + modules: Tuple[nn.Module, nn.Module], + batch_dim: int = 0, + process_group: Optional[dist.ProcessGroup] = None, + rank_mapping: Optional[List[int]] = None, + ) -> None: + super().__init__() + + assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device()) + self.module = nn.ModuleList(modules) + self.overlaped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlaped_forward_backward") + self.batch_dim = batch_dim + self.group = process_group or dist.distributed_c10d._get_default_group() + self.num_ranks = self.group.size() + + # rank_mapping: Map rank in process_group to actual pp rank. + # rank_inverse_mapping: Map actual pp rank to rank in process_group. + if rank_mapping is None: + rank_mapping = list(range(self.num_ranks)) + rank_inverse_mapping = [None] * (self.num_ranks + 1) + for i in range(self.num_ranks): + rank_inverse_mapping[rank_mapping[i]] = i + + self.rank = rank_mapping[self.group.rank()] + self.first_rank = rank_inverse_mapping[0] + self.prev_rank = rank_inverse_mapping[self.rank - 1] + self.next_rank = rank_inverse_mapping[self.rank + 1] + self.last_rank = rank_inverse_mapping[self.num_ranks - 1] + + self.is_first_rank = self.rank == 0 + self.is_last_rank = self.rank == self.num_ranks - 1 + self.is_in_second_half = self.rank >= self.num_ranks // 2 + self.is_middle_rank = (self.rank == self.num_ranks // 2 - 1) or (self.rank == self.num_ranks // 2) + + def _reset_states(self) -> None: + WeightGradStore.clear() + + self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], []) + self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None + self.loss_chunks: List[torch.Tensor] = [] + self.criterion: Callable = None + + self.current_f_chunk_id: List[int] = [0, 0] + self.current_b_chunk_id: List[int] = [0, 0] + self.current_send_f_chunk_id: List[int] = [0, 0] + self.current_send_b_chunk_id: List[int] = [0, 0] + self.current_recv_f_chunk_id: List[int] = [0, 0] + self.current_recv_b_chunk_id: List[int] = [0, 0] + self.comm_ops: List[dist.P2POp] = [] + self.to_free: List[torch.Tensor] = [] + + def _forward_compute_chunk(self, phase: int) -> None: + phase ^= self.is_in_second_half + chunk_id = self.current_f_chunk_id[phase] + self.current_f_chunk_id[phase] += 1 + inputs = self.input_chunks[phase][chunk_id] + if self.forward_only: + self.input_chunks[phase][chunk_id] = None + + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + + outputs = self.module[phase](*inputs) + outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs + if is_last_stage and self.criterion is not None: + labels = self.labels[phase][chunk_id] + loss = self.criterion(*outputs, *labels) + self.loss_chunks.append(loss) + + if (not is_last_stage) or self.return_outputs: + self.output_chunks[phase].append(outputs) + + def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + chunk_id = self.current_b_chunk_id[phase] + self.current_b_chunk_id[phase] += 1 + + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + + WeightGradStore.enabled = enable_zb + if is_last_stage: + loss = self.loss_chunks[chunk_id] + loss.backward() + loss.detach_() + else: + outputs = self.output_chunks[phase][chunk_id] + if not self.return_outputs: + self.output_chunks[phase][chunk_id] = None + output_grads = self.output_grad_chunks[phase][chunk_id] + self.output_grad_chunks[phase][chunk_id] = None + non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None] + outputs, output_grads = list(zip(*non_empty)) + if len(outputs) > 0: + run_backward(outputs, output_grads) + WeightGradStore.enabled = False + if enable_zb: + WeightGradStore.flush() + + inputs = self.input_chunks[phase][chunk_id] + self.input_chunks[phase][chunk_id] = None + input_grads = [t.grad for t in inputs] + self.input_grad_chunks[phase].append(input_grads) + + def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None: + if self.forward_only: + self._forward_compute_chunk(phase0) + return + + if not self.overlaped_forward_backward: + self._forward_compute_chunk(phase0) + self._backward_compute_chunk(phase1) + return + + # pre-forward + phase0 ^= self.is_in_second_half + chunk_id0 = self.current_f_chunk_id[phase0] + self.current_f_chunk_id[phase0] += 1 + module0 = self.module[phase0] + inputs0 = self.input_chunks[phase0][chunk_id0] + is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0) + + if is_last_stage0 and self.criterion is not None: + labels0 = self.labels[phase0][chunk_id0] + criterion0 = self.criterion + else: + labels0 = [] + criterion0 = None + + # pre-backward + phase1 ^= self.is_in_second_half + chunk_id1 = self.current_b_chunk_id[phase1] + self.current_b_chunk_id[phase1] += 1 + module1 = self.module[phase1] + is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0) + + if is_last_stage1: + loss1 = self.loss_chunks[chunk_id1] + outputs1 = [] + output_grads1 = [] + else: + loss1 = None + outputs1 = self.output_chunks[phase1][chunk_id1] + if not self.return_outputs: + self.output_chunks[phase1][chunk_id1] = None + output_grads1 = self.output_grad_chunks[phase1][chunk_id1] + self.output_grad_chunks[phase1][chunk_id1] = None + non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None] + outputs1, output_grads1 = list(zip(*non_empty)) + + # forward & backward + outputs0, loss0 = type(module0).overlaped_forward_backward( + module0, inputs0, criterion0, labels0, + module1, loss1, outputs1, output_grads1, + ) + + # post-forward + if (not is_last_stage0) or self.return_outputs: + self.output_chunks[phase0].append(outputs0) + if is_last_stage0 and self.criterion is not None: + self.loss_chunks.append(loss0) + + # post-backward + inputs = self.input_chunks[phase1][chunk_id1] + self.input_chunks[phase1][chunk_id1] = None + input_grads1 = [t.grad for t in inputs] + self.input_grad_chunks[phase1].append(input_grads1) + + def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_forward(phase) + self._commit_and_wait_comm() + + self._forward_compute_chunk(phase) + + if send: + self._send_forward(phase) + + def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None: + if recv: + self._recv_backward(phase) + self._commit_and_wait_comm() + + self._backward_compute_chunk(phase, enable_zb) + + if send: + self._send_backward(phase) + + def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None: + if recv0: + self._recv_forward(phase0) + self._recv_backward(phase1) + self._commit_and_wait_comm() + + self._forward_backward_compute_chunk(phase0, phase1) + + self._send_forward(phase0) + self._send_backward(phase1) + + def _weight_chunk(self) -> None: + if self.forward_only: + return + + self._commit_and_wait_comm() + + # Assume FIFO + WeightGradStore.pop() + + def _free_tensors(self) -> None: + for tensor in self.to_free: + assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}" + tensor.data = torch.Tensor() + self.to_free = [] + + def _recv_forward(self, phase: int) -> None: + phase ^= self.is_in_second_half + is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1) + if is_first_stage: + return + + self.current_recv_f_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group) + self.input_chunks[phase].append(tensors) + + def _send_forward(self, phase: int) -> None: + phase ^= self.is_in_second_half + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + if is_last_stage: + return + + chunk_id = self.current_send_f_chunk_id[phase] + self.current_send_f_chunk_id[phase] += 1 + tensors = self.output_chunks[phase][chunk_id] + + comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group) + + if not self.return_outputs: + self.to_free.extend(tensors) + + def _recv_backward(self, phase: int) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0) + if is_last_stage: + return + + self.current_recv_b_chunk_id[phase] += 1 + tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group) + self.output_grad_chunks[phase].append(tensors) + + def _send_backward(self, phase: int) -> None: + if self.forward_only: + return + + phase ^= self.is_in_second_half + is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1) + if is_first_stage: + return + + chunk_id = self.current_send_b_chunk_id[phase] + self.current_send_b_chunk_id[phase] += 1 + tensors = self.input_grad_chunks[phase][chunk_id] + self.input_grad_chunks[phase][chunk_id] = None + + comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group) + + def _commit_and_wait_comm(self) -> None: + if not self.comm_ops: + return + reqs = dist.batch_isend_irecv(self.comm_ops) + for req in reqs: + req.wait() + self.comm_ops = [] + self._free_tensors() + + def step( + self, + *inputs: Optional[torch.Tensor], + num_chunks: int = 0, + criterion: Optional[Callable] = None, + labels: List[Optional[torch.Tensor]] = [], + return_outputs: bool = False, + ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]: + """ + Execute a traning or inference step. + + Arguments: + *inputs: Module inputs. Required only on the first/last ranks. + num_chunks: The number of micro-batches. + criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first/last ranks. + labels: Labels of the loss function. Required only on the first/last ranks. + labels on the first rank corresponds to inputs on the last rank. + labels on the last rank corresponds to inputs on the first rank. + return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``. + + Returns: (loss, outputs) + loss: Loss for the batch. + loss on the first rank corresponds to inputs on the last rank. + loss on the last rank corresponds to inputs on the first rank. + Otherwise: ``None``. + outputs: Returned only if ``return_outputs=True``. + outputs on the first rank corresponds to inputs on the last rank. + outputs on the last rank corresponds to inputs on the first rank. + Otherwise: ``None``. + + """ + assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \ + "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before doing a step." + self.forward_only = not torch.is_grad_enabled() + self.return_outputs = return_outputs + + rank = self.rank + num_ranks = self.num_ranks + assert num_ranks % 2 == 0 + assert num_chunks > 0 and num_chunks % 2 == 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}" + num_half_ranks = num_ranks // 2 + half_rank = min(rank, num_ranks - 1 - rank) + half_num_chunks = num_chunks // 2 + self.num_half_ranks = num_half_ranks + self.half_rank = half_rank + + if not self.forward_only and (self.is_first_rank or self.is_last_rank): + assert criterion is not None + + self._reset_states() + + inputs = scatter(inputs, half_num_chunks, self.batch_dim) + labels = scatter(labels, half_num_chunks, self.batch_dim) + if self.is_first_rank: + self.input_chunks = (inputs, []) + self.labels = ([], labels) + elif self.is_last_rank: + self.input_chunks = ([], inputs) + self.labels = (labels, []) + self.criterion = criterion + + # For the fisrt half of the ranks: phase 0 means forward direction, phase 1 means reverse direction. + # For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction. + + # Step 1: nF0 + step_1 = (num_half_ranks - half_rank - 1) * 2 + for i in range(step_1): + self._forward_chunk(0) + + # Step 2: nF0F1 + step_2 = half_rank + 1 + self._recv_forward(0) + for i in range(step_2): + self._forward_chunk(0, recv=False, send=self.is_middle_rank) + self._recv_forward(0) + self._forward_chunk(1, send=(not self.is_middle_rank) or (i < step_2 - 1)) + if not self.is_middle_rank: + self._send_forward(0) + + # Step 3: nB1W1F1 (Use zero bubble) + step_3 = num_half_ranks - half_rank - 1 + for i in range(step_3): + self._backward_chunk(1, enable_zb=True) + self._recv_forward(1) + self._weight_chunk() + self._forward_chunk(1, recv=False) + + # Step 4 (Main step): nF0B1F1B0 + step_4 = half_num_chunks - num_ranks + half_rank + 1 + for i in range(step_4): + if i == 0: + if self.is_middle_rank: + # NOTE: We don't overlap these two chunks to further reduce bubble size. + self._forward_chunk(0, recv=False, send=False) + self._send_forward(1) + self._backward_chunk(1, send=False) + self._send_forward(0) + self._send_backward(1) + else: + self._forward_backward_chunk(0, 1, recv0=False) + else: + self._forward_backward_chunk(0, 1) + self._forward_backward_chunk(1, 0) + + # Step 5: nB1F1B0 + step_5 = num_half_ranks - half_rank - 1 + for i in range(step_5): + self._backward_chunk(1) + self._forward_backward_chunk(1, 0) + + # Step 6: nB1B0 (The second half of the chunks use zero bubble) + step_6 = half_rank + 1 + enable_zb = False + for i in range(step_6): + if i == step_6 // 2 and half_rank % 2 == 1: + enable_zb = True + self._backward_chunk(1, enable_zb=enable_zb) + if i == step_6 // 2 and half_rank % 2 == 0: + enable_zb = True + self._backward_chunk(0, enable_zb=enable_zb) + + # Step 7: nWB0 (Use zero bubble) + step_7 = num_half_ranks - half_rank - 1 + for i in range(step_7): + self._weight_chunk() + self._backward_chunk(0, enable_zb=True) + + # Step 8: nW + step_8 = half_rank + 1 + for i in range(step_8): + self._weight_chunk() + assert WeightGradStore.funcs_queue.empty() + + self._commit_and_wait_comm() + + loss, outputs = None, None + if self.is_first_rank or self.is_last_rank: + if criterion is not None: + loss = torch.stack(self.loss_chunks) + if return_outputs: + outputs = gather(self.output_chunks[self.is_first_rank], self.batch_dim) + if len(outputs) == 1: + outputs = outputs[0] + + self._reset_states() + + return loss, outputs diff --git a/dualpipe/utils.py b/dualpipe/utils.py new file mode 100644 index 0000000..cefc52b --- /dev/null +++ b/dualpipe/utils.py @@ -0,0 +1,80 @@ +import queue +from typing import List, Callable + +import torch +from torch.autograd import Variable + + +class WeightGradStore: + + enabled: bool = False + cache: List[Callable] = [] + funcs_queue = queue.Queue() + + @classmethod + def put(cls, func: Callable) -> None: + cls.cache.append(func) + + @classmethod + def flush(cls) -> None: + cls.funcs_queue.put(cls.cache) + cls.cache = [] + + @classmethod + def pop(cls) -> None: + assert not cls.funcs_queue.empty(), "Pop empty queue." + funcs = cls.funcs_queue.get() + for func in funcs: + func() + + @classmethod + def clear(cls) -> None: + cls.cache = [] + cls.funcs_queue = queue.Queue() + + +def run_backward(tensors: List[torch.Tensor], grad_tensors: List[torch.Tensor]) -> None: + kwargs = dict( + keep_graph=False, + create_graph=False, + allow_unreachable=True, + accumulate_grad=True, + ) + Variable._execution_engine.run_backward(tensors, grad_tensors, **kwargs) + + +def chunk_tensor(x, chunks, dim): + if x is None: + return [None for _ in range(chunks)] + return x.tensor_split(chunks, dim=dim) + + +def cat_tensor(x, dim): + if (isinstance(x, tuple) or isinstance(x, list)): + if len(x) == 1: + return x[0] + elif x[0] is None: + assert all(y is None for y in x) + return None + return torch.cat(x, dim=dim) + + +def scatter(inputs, chunks, dim): + assert isinstance(inputs, (torch.Tensor, tuple, list)) + if isinstance(inputs, torch.Tensor): + inputs = (inputs,) + assert all(x is None or isinstance(x, torch.Tensor) for x in inputs) + inputs = [chunk_tensor(x, chunks, dim) for x in inputs] + microbatches = [microbatch for microbatch in zip(*inputs)] + if len(microbatches) == 0: + microbatches = [() for _ in range(chunks)] + return microbatches + + +def gather(micro_outputs, dim): + assert isinstance(micro_outputs[0], (torch.Tensor, tuple, list)) + if isinstance(micro_outputs[0], torch.Tensor): + micro_outputs = [(x,) for x in micro_outputs] + outputs = [x for x in zip(*micro_outputs)] + outputs = tuple(cat_tensor(x, dim=dim) for x in outputs) + return outputs diff --git a/example.py b/example.py new file mode 100644 index 0000000..e78984a --- /dev/null +++ b/example.py @@ -0,0 +1,202 @@ +from typing import List, Optional, Callable, Tuple +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from dualpipe import DualPipe, set_p2p_tensor_shapes, set_p2p_tensor_dtype +from dualpipe.utils import WeightGradStore, run_backward + + +class LinearFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight): + ctx.save_for_backward(input, weight) + output = F.linear(input, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + if weight.grad is None: + weight.grad = torch.zeros_like(weight) + + def grad_weight_fn(): + weight.grad += grad_output.flatten(0, -2).T @ input.flatten(0, -2) + + if WeightGradStore.enabled: + WeightGradStore.put(grad_weight_fn) + else: + grad_weight_fn() + grad_input = grad_output @ weight + return grad_input, None + + +class MyLinear(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return LinearFunc.apply(input, self.weight) + + +class PipelineStage(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.linear1 = MyLinear(hidden_size, hidden_size * 4, bias=False) + self.linear2 = MyLinear(hidden_size * 4, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = F.gelu(x) + x = self.linear2(x) + return x + + @classmethod + def overlaped_forward_backward( + cls, + module0: "PipelineStage", + inputs0: List[torch.Tensor], + criterion0: Optional[Callable], + labels0: Optional[List[torch.Tensor]], + module1: "PipelineStage", + loss1: Optional[torch.Tensor], + outputs1: Optional[List[torch.Tensor]], + output_grads1: Optional[List[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + You should implement custom forward-backward overlap strategy. + The code below is just an example. + """ + outputs0 = module0(*inputs0) + outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0 + if criterion0 is not None: + loss0 = criterion0(*outputs0, *labels0) + else: + loss0 = None + + if loss1 is not None: + loss1.backward() + loss1.detach_() + else: + run_backward(outputs1, output_grads1) + + return outputs0, loss0 + + +def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return F.mse_loss(output, target).clone() + + +def ref_step(x, l, model, chunks): + ys, losses = [], [] + for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)): + micro_y = model(micro_x) + loss = criterion(micro_y, micro_l) + loss.backward() + ys.append(micro_y) + losses.append(loss) + y = torch.cat(ys, 0) + loss = torch.stack(losses) + return loss, y + + +def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / (x * x + y * y).sum().item() + return cos_diff + + +def main(rank, pp_size): + is_first_rank = rank == 0 + is_last_rank = rank == pp_size - 1 + dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank) + torch.cuda.set_device(rank) + torch.set_default_device(f"cuda:{rank}") + torch.manual_seed(233) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + num_chunks = 20 + micro_batch_size = 3 + seq_len = 256 + hidden_size = 512 + if is_first_rank: + print(f"{pp_size=}, {num_chunks=}, {seq_len=}, {hidden_size=}", flush=True) + set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_size)]) + set_p2p_tensor_dtype(torch.float32) + + # Create a model and partition it for each process + full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size)]) + + # Full inputs + full_x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size) + full_l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size) + + # Reference step + loss_ref, output_ref = ref_step(full_x, full_l, full_modules, num_chunks) + + # DualPipe + local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size - 1 - rank]) + local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size)) + local_modules[0].load_state_dict(local_full_modules[0].state_dict()) + local_modules[1].load_state_dict(local_full_modules[1].state_dict()) + dualpipe_model = DualPipe(local_modules) + + # DualPipe inputs + if is_first_rank: + x = full_x.chunk(2)[0] + l = full_l.chunk(2)[1] + elif is_last_rank: + x = full_x.chunk(2)[1] + l = full_l.chunk(2)[0] + else: + x = None + l = None + + # Training step + loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=False) + + # Check loss + if is_first_rank: + assert torch.equal(loss, loss_ref.chunk(2)[1]) + elif is_last_rank: + assert torch.equal(loss, loss_ref.chunk(2)[0]) + else: + assert loss is None + assert outputs is None + + # Check grads + for (p0, p1) in zip(local_modules[0].parameters(), local_modules[1].parameters()): + p0all = torch.empty(pp_size, *p0.shape) + p1all = torch.empty(pp_size, *p1.shape) + dist.all_gather_into_tensor(p0all, p0.grad) + dist.all_gather_into_tensor(p1all, p1.grad) + p0.grad += p1all[pp_size - 1 - rank] + p1.grad += p0all[pp_size - 1 - rank] + for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()): + assert cal_diff(p.grad, p_ref.grad) < 1e-13 + dualpipe_model.zero_grad() + + # Inference step + with torch.no_grad(): + loss, outputs = dualpipe_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=True) + + # Check loss and outputs + if is_first_rank: + assert torch.equal(loss, loss_ref.chunk(2)[1]) + assert torch.equal(outputs, output_ref.chunk(2)[1]) + elif is_last_rank: + assert torch.equal(loss, loss_ref.chunk(2)[0]) + assert torch.equal(outputs, output_ref.chunk(2)[0]) + else: + assert loss is None + assert outputs is None + + +def test_dualpipe(ngpus): + torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True) + + +if __name__ == "__main__": + num_gpus = torch.cuda.device_count() // 2 * 2 + for ngpus in range(num_gpus, 0, -2): + test_dualpipe(ngpus) diff --git a/images/schedules.png b/images/schedules.png new file mode 100644 index 0000000..88734e8 Binary files /dev/null and b/images/schedules.png differ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..04a5488 --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +from datetime import datetime +import subprocess + +from setuptools import setup + + +try: + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() +except Exception as _: + now = datetime.now() + date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S") + rev = '+' + date_time_str + +setup( + name="dualpipe", + version="1.0.0" + rev, +)