mirror of
https://github.com/deepseek-ai/DualPipe
synced 2025-06-26 18:16:46 +00:00
add dualpipev
This commit is contained in:
parent
036c14e7f6
commit
8ec5883b30
13
README.md
13
README.md
@ -4,7 +4,7 @@ DualPipe is an innovative bidirectional pipeline parallelism algorithm introduce
|
|||||||
|
|
||||||
### Schedules
|
### Schedules
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions.
|
Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions.
|
||||||
The micro-batches in the reverse direction are symmetric to those in the forward direction, so
|
The micro-batches in the reverse direction are symmetric to those in the forward direction, so
|
||||||
@ -23,12 +23,21 @@ have mutually overlapped computation and communication
|
|||||||
full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵
|
full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵
|
||||||
denotes the execution time of two mutually overlapped forward and backward chunks.
|
denotes the execution time of two mutually overlapped forward and backward chunks.
|
||||||
|
|
||||||
|
# DualPipeV
|
||||||
|
|
||||||
|
DualPipeV is a concise V-shape schedule derived from DualPipe using a "cut-in-half" procedure, introduced by Sea AI Lab as "Cut-in-half" in their [blog post](https://hackmd.io/@ufotalent/r1lVXsa9Jg). Thanks to them for this efficient schedule!
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Example DualPipeV scheduling for 4 PP ranks and 10 micro-batches.
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
The usage is shown in the following example:
|
The usage is shown in the following example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python example.py
|
python examples/example_dualpipe.py
|
||||||
|
python examples/example_dualpipev.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Note: For real-world applications, you will need to implement a custom `overlapped_forward_backward` method tailored to your specific module.
|
Note: For real-world applications, you will need to implement a custom `overlapped_forward_backward` method tailored to your specific module.
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
__version__ = "1.0.0"
|
__version__ = "1.0.0"
|
||||||
|
|
||||||
from dualpipe.dualpipe import (
|
from dualpipe.dualpipe import DualPipe
|
||||||
DualPipe,
|
from dualpipe.dualpipev import DualPipeV
|
||||||
WeightGradStore,
|
|
||||||
)
|
|
||||||
from dualpipe.comm import (
|
from dualpipe.comm import (
|
||||||
set_p2p_tensor_shapes,
|
set_p2p_tensor_shapes,
|
||||||
set_p2p_tensor_dtype,
|
set_p2p_tensor_dtype,
|
||||||
)
|
)
|
||||||
|
from dualpipe.utils import WeightGradStore
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
DualPipe,
|
DualPipe,
|
||||||
|
DualPipeV,
|
||||||
WeightGradStore,
|
WeightGradStore,
|
||||||
set_p2p_tensor_shapes,
|
set_p2p_tensor_shapes,
|
||||||
set_p2p_tensor_dtype,
|
set_p2p_tensor_dtype,
|
||||||
|
411
dualpipe/dualpipev.py
Normal file
411
dualpipe/dualpipev.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
from typing import Tuple, List, Union, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
import dualpipe.comm as comm
|
||||||
|
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather
|
||||||
|
|
||||||
|
|
||||||
|
class DualPipeV(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
modules: Tuple[nn.Module, nn.Module],
|
||||||
|
batch_dim: int = 0,
|
||||||
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
rank_mapping: Optional[List[int]] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
|
||||||
|
self.module = nn.ModuleList(modules)
|
||||||
|
self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
|
||||||
|
self.batch_dim = batch_dim
|
||||||
|
self.group = process_group or dist.distributed_c10d._get_default_group()
|
||||||
|
self.num_ranks = self.group.size()
|
||||||
|
|
||||||
|
# rank_mapping: Map rank in process_group to actual pp rank.
|
||||||
|
# rank_inverse_mapping: Map actual pp rank to rank in process_group.
|
||||||
|
if rank_mapping is None:
|
||||||
|
rank_mapping = list(range(self.num_ranks))
|
||||||
|
rank_inverse_mapping = [None] * (self.num_ranks + 1)
|
||||||
|
for i in range(self.num_ranks):
|
||||||
|
rank_inverse_mapping[rank_mapping[i]] = i
|
||||||
|
|
||||||
|
self.rank = rank_mapping[self.group.rank()]
|
||||||
|
self.prev_rank = rank_inverse_mapping[self.rank - 1]
|
||||||
|
self.next_rank = rank_inverse_mapping[self.rank + 1]
|
||||||
|
|
||||||
|
self.is_first_rank = self.rank == 0
|
||||||
|
self.is_last_rank = self.rank == self.num_ranks - 1
|
||||||
|
|
||||||
|
def _reset_states(self) -> None:
|
||||||
|
WeightGradStore.clear()
|
||||||
|
|
||||||
|
self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
|
||||||
|
self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
|
||||||
|
self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
|
||||||
|
self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
|
||||||
|
self.labels: List[List[torch.Tensor]] = None
|
||||||
|
self.loss_chunks: List[torch.Tensor] = []
|
||||||
|
self.criterion: Callable = None
|
||||||
|
|
||||||
|
self.current_f_chunk_id: List[int] = [0, 0]
|
||||||
|
self.current_b_chunk_id: List[int] = [0, 0]
|
||||||
|
self.current_send_f_chunk_id: List[int] = [0, 0]
|
||||||
|
self.current_send_b_chunk_id: List[int] = [0, 0]
|
||||||
|
self.current_recv_f_chunk_id: List[int] = [0, 0]
|
||||||
|
self.current_recv_b_chunk_id: List[int] = [0, 0]
|
||||||
|
self.comm_ops: List[dist.P2POp] = []
|
||||||
|
self.to_free: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
def _forward_compute_chunk(self, phase: int) -> None:
|
||||||
|
chunk_id = self.current_f_chunk_id[phase]
|
||||||
|
self.current_f_chunk_id[phase] += 1
|
||||||
|
inputs = self.input_chunks[phase][chunk_id]
|
||||||
|
if self.forward_only:
|
||||||
|
self.input_chunks[phase][chunk_id] = None
|
||||||
|
|
||||||
|
is_last_stage = (self.is_first_rank and phase == 1)
|
||||||
|
|
||||||
|
outputs = self.module[phase](*inputs)
|
||||||
|
outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
|
||||||
|
if is_last_stage and self.criterion is not None:
|
||||||
|
labels = self.labels[chunk_id]
|
||||||
|
loss = self.criterion(*outputs, *labels)
|
||||||
|
self.loss_chunks.append(loss)
|
||||||
|
|
||||||
|
if self.is_last_rank and phase == 0:
|
||||||
|
self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs])
|
||||||
|
if (not is_last_stage) or self.return_outputs:
|
||||||
|
self.output_chunks[phase].append(outputs)
|
||||||
|
|
||||||
|
def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
|
||||||
|
if self.forward_only:
|
||||||
|
return
|
||||||
|
|
||||||
|
chunk_id = self.current_b_chunk_id[phase]
|
||||||
|
self.current_b_chunk_id[phase] += 1
|
||||||
|
|
||||||
|
is_last_stage = (self.is_first_rank and phase == 1)
|
||||||
|
|
||||||
|
WeightGradStore.enabled = enable_zb
|
||||||
|
if is_last_stage:
|
||||||
|
loss = self.loss_chunks[chunk_id]
|
||||||
|
loss.backward()
|
||||||
|
loss.detach_()
|
||||||
|
else:
|
||||||
|
outputs = self.output_chunks[phase][chunk_id]
|
||||||
|
if not self.return_outputs:
|
||||||
|
self.output_chunks[phase][chunk_id] = None
|
||||||
|
output_grads = self.output_grad_chunks[phase][chunk_id]
|
||||||
|
self.output_grad_chunks[phase][chunk_id] = None
|
||||||
|
non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
|
||||||
|
outputs, output_grads = list(zip(*non_empty))
|
||||||
|
if len(outputs) > 0:
|
||||||
|
run_backward(outputs, output_grads)
|
||||||
|
WeightGradStore.enabled = False
|
||||||
|
if enable_zb:
|
||||||
|
WeightGradStore.flush()
|
||||||
|
|
||||||
|
inputs = self.input_chunks[phase][chunk_id]
|
||||||
|
self.input_chunks[phase][chunk_id] = None
|
||||||
|
input_grads = [t.grad for t in inputs]
|
||||||
|
if self.is_last_rank and phase == 1:
|
||||||
|
self.output_grad_chunks[0].append(input_grads)
|
||||||
|
else:
|
||||||
|
self.input_grad_chunks[phase].append(input_grads)
|
||||||
|
|
||||||
|
def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
|
||||||
|
if self.forward_only:
|
||||||
|
self._forward_compute_chunk(phase0)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.overlapped_forward_backward:
|
||||||
|
self._forward_compute_chunk(phase0)
|
||||||
|
self._backward_compute_chunk(phase1)
|
||||||
|
return
|
||||||
|
|
||||||
|
# pre-forward
|
||||||
|
chunk_id0 = self.current_f_chunk_id[phase0]
|
||||||
|
self.current_f_chunk_id[phase0] += 1
|
||||||
|
module0 = self.module[phase0]
|
||||||
|
inputs0 = self.input_chunks[phase0][chunk_id0]
|
||||||
|
is_last_stage0 = (self.is_first_rank and phase0 == 1)
|
||||||
|
|
||||||
|
if is_last_stage0 and self.criterion is not None:
|
||||||
|
labels0 = self.labels[chunk_id0]
|
||||||
|
criterion0 = self.criterion
|
||||||
|
else:
|
||||||
|
labels0 = []
|
||||||
|
criterion0 = None
|
||||||
|
|
||||||
|
# pre-backward
|
||||||
|
chunk_id1 = self.current_b_chunk_id[phase1]
|
||||||
|
self.current_b_chunk_id[phase1] += 1
|
||||||
|
module1 = self.module[phase1]
|
||||||
|
is_last_stage1 = (self.is_first_rank and phase1 == 1)
|
||||||
|
|
||||||
|
if is_last_stage1:
|
||||||
|
loss1 = self.loss_chunks[chunk_id1]
|
||||||
|
outputs1 = []
|
||||||
|
output_grads1 = []
|
||||||
|
else:
|
||||||
|
loss1 = None
|
||||||
|
outputs1 = self.output_chunks[phase1][chunk_id1]
|
||||||
|
if not self.return_outputs:
|
||||||
|
self.output_chunks[phase1][chunk_id1] = None
|
||||||
|
output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
|
||||||
|
self.output_grad_chunks[phase1][chunk_id1] = None
|
||||||
|
non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
|
||||||
|
outputs1, output_grads1 = list(zip(*non_empty))
|
||||||
|
|
||||||
|
# forward & backward
|
||||||
|
outputs0, loss0 = type(module0).overlapped_forward_backward(
|
||||||
|
module0, inputs0, criterion0, labels0,
|
||||||
|
module1, loss1, outputs1, output_grads1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# post-forward
|
||||||
|
if self.is_last_rank and phase0 == 0:
|
||||||
|
self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs0])
|
||||||
|
if (not is_last_stage0) or self.return_outputs:
|
||||||
|
self.output_chunks[phase0].append(outputs0)
|
||||||
|
if is_last_stage0 and self.criterion is not None:
|
||||||
|
self.loss_chunks.append(loss0)
|
||||||
|
|
||||||
|
# post-backward
|
||||||
|
inputs = self.input_chunks[phase1][chunk_id1]
|
||||||
|
self.input_chunks[phase1][chunk_id1] = None
|
||||||
|
input_grads1 = [t.grad for t in inputs]
|
||||||
|
if self.is_last_rank and phase1 == 1:
|
||||||
|
self.output_grad_chunks[0].append(input_grads1)
|
||||||
|
else:
|
||||||
|
self.input_grad_chunks[phase1].append(input_grads1)
|
||||||
|
|
||||||
|
def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
|
||||||
|
if recv:
|
||||||
|
self._recv_forward(phase)
|
||||||
|
self._commit_and_wait_comm()
|
||||||
|
|
||||||
|
self._forward_compute_chunk(phase)
|
||||||
|
|
||||||
|
if send:
|
||||||
|
self._send_forward(phase)
|
||||||
|
|
||||||
|
def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
|
||||||
|
if recv:
|
||||||
|
self._recv_backward(phase)
|
||||||
|
self._commit_and_wait_comm()
|
||||||
|
|
||||||
|
self._backward_compute_chunk(phase, enable_zb)
|
||||||
|
|
||||||
|
if send:
|
||||||
|
self._send_backward(phase)
|
||||||
|
|
||||||
|
def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
|
||||||
|
if recv0:
|
||||||
|
self._recv_forward(phase0)
|
||||||
|
self._recv_backward(phase1)
|
||||||
|
self._commit_and_wait_comm()
|
||||||
|
|
||||||
|
self._forward_backward_compute_chunk(phase0, phase1)
|
||||||
|
|
||||||
|
self._send_forward(phase0)
|
||||||
|
self._send_backward(phase1)
|
||||||
|
|
||||||
|
def _weight_chunk(self) -> None:
|
||||||
|
if self.forward_only:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._commit_and_wait_comm()
|
||||||
|
|
||||||
|
# Assume FIFO
|
||||||
|
WeightGradStore.pop()
|
||||||
|
|
||||||
|
def _free_tensors(self) -> None:
|
||||||
|
for tensor in self.to_free:
|
||||||
|
assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
|
||||||
|
tensor.data = torch.Tensor()
|
||||||
|
self.to_free = []
|
||||||
|
|
||||||
|
def _recv_forward(self, phase: int) -> None:
|
||||||
|
if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_recv_f_chunk_id[phase] += 1
|
||||||
|
tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
|
||||||
|
self.input_chunks[phase].append(tensors)
|
||||||
|
|
||||||
|
def _send_forward(self, phase: int) -> None:
|
||||||
|
if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
|
||||||
|
return
|
||||||
|
|
||||||
|
chunk_id = self.current_send_f_chunk_id[phase]
|
||||||
|
self.current_send_f_chunk_id[phase] += 1
|
||||||
|
tensors = self.output_chunks[phase][chunk_id]
|
||||||
|
|
||||||
|
comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)
|
||||||
|
|
||||||
|
if not self.return_outputs:
|
||||||
|
self.to_free.extend(tensors)
|
||||||
|
|
||||||
|
def _recv_backward(self, phase: int) -> None:
|
||||||
|
if self.forward_only:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_recv_b_chunk_id[phase] += 1
|
||||||
|
tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
|
||||||
|
self.output_grad_chunks[phase].append(tensors)
|
||||||
|
|
||||||
|
def _send_backward(self, phase: int) -> None:
|
||||||
|
if self.forward_only:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
|
||||||
|
return
|
||||||
|
|
||||||
|
chunk_id = self.current_send_b_chunk_id[phase]
|
||||||
|
self.current_send_b_chunk_id[phase] += 1
|
||||||
|
tensors = self.input_grad_chunks[phase][chunk_id]
|
||||||
|
self.input_grad_chunks[phase][chunk_id] = None
|
||||||
|
|
||||||
|
comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)
|
||||||
|
|
||||||
|
def _commit_and_wait_comm(self) -> None:
|
||||||
|
if not self.comm_ops:
|
||||||
|
return
|
||||||
|
reqs = dist.batch_isend_irecv(self.comm_ops)
|
||||||
|
for req in reqs:
|
||||||
|
req.wait()
|
||||||
|
self.comm_ops = []
|
||||||
|
self._free_tensors()
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self,
|
||||||
|
*inputs: Optional[torch.Tensor],
|
||||||
|
num_chunks: int = 0,
|
||||||
|
criterion: Optional[Callable] = None,
|
||||||
|
labels: List[Optional[torch.Tensor]] = [],
|
||||||
|
return_outputs: bool = False,
|
||||||
|
) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
|
||||||
|
"""
|
||||||
|
Execute a training or inference step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
*inputs: Module inputs. Required only on the first rank.
|
||||||
|
num_chunks: The number of micro-batches.
|
||||||
|
criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank.
|
||||||
|
labels: Labels of the loss function. Required only on the first rank.
|
||||||
|
return_outputs: Whether to return outputs on the first rank. Default: ``False``.
|
||||||
|
|
||||||
|
Returns: (loss, outputs)
|
||||||
|
loss: Loss for the batch. Returned only on the first rank.
|
||||||
|
outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
|
||||||
|
"You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before executing a step."
|
||||||
|
self.forward_only = not torch.is_grad_enabled()
|
||||||
|
self.return_outputs = return_outputs
|
||||||
|
|
||||||
|
rank = self.rank
|
||||||
|
num_ranks = self.num_ranks
|
||||||
|
assert num_chunks > 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
|
||||||
|
|
||||||
|
if not self.forward_only and self.is_first_rank:
|
||||||
|
assert criterion is not None
|
||||||
|
|
||||||
|
self._reset_states()
|
||||||
|
|
||||||
|
if self.is_first_rank:
|
||||||
|
self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])
|
||||||
|
self.labels = scatter(labels, num_chunks, self.batch_dim)
|
||||||
|
self.criterion = criterion
|
||||||
|
|
||||||
|
# Step 1: nF0
|
||||||
|
step_1 = (num_ranks - rank - 1) * 2
|
||||||
|
for i in range(step_1):
|
||||||
|
self._forward_chunk(0)
|
||||||
|
|
||||||
|
# Step 2: nF0F1
|
||||||
|
step_2 = rank + 1
|
||||||
|
self._recv_forward(0)
|
||||||
|
for i in range(step_2):
|
||||||
|
self._forward_chunk(0, recv=False, send=False)
|
||||||
|
self._recv_forward(0)
|
||||||
|
self._forward_chunk(1, send=(not self.is_last_rank) or (i < step_2 - 1))
|
||||||
|
self._send_forward(0)
|
||||||
|
|
||||||
|
# Step 3: nB1W1F1 (Use zero bubble)
|
||||||
|
step_3 = num_ranks - rank - 1
|
||||||
|
for i in range(step_3):
|
||||||
|
self._backward_chunk(1, enable_zb=True)
|
||||||
|
self._recv_forward(1)
|
||||||
|
self._weight_chunk()
|
||||||
|
self._forward_chunk(1, recv=False)
|
||||||
|
|
||||||
|
# Step 4 (Main step): nF0B1F1B0
|
||||||
|
step_4 = num_chunks - num_ranks * 2 + rank + 1
|
||||||
|
for i in range(step_4):
|
||||||
|
if i == 0:
|
||||||
|
if self.is_last_rank:
|
||||||
|
# NOTE: We don't overlap these two chunks to further reduce bubble size.
|
||||||
|
self._forward_chunk(0, recv=False, send=False)
|
||||||
|
self._send_forward(1)
|
||||||
|
self._backward_chunk(1, send=False)
|
||||||
|
self._send_forward(0)
|
||||||
|
self._send_backward(1)
|
||||||
|
else:
|
||||||
|
self._forward_backward_chunk(0, 1, recv0=False)
|
||||||
|
else:
|
||||||
|
self._forward_backward_chunk(0, 1)
|
||||||
|
self._forward_backward_chunk(1, 0)
|
||||||
|
|
||||||
|
# Step 5: nB1F1B0
|
||||||
|
step_5 = num_ranks - rank - 1
|
||||||
|
for i in range(step_5):
|
||||||
|
self._backward_chunk(1)
|
||||||
|
self._forward_backward_chunk(1, 0)
|
||||||
|
|
||||||
|
# Step 6: nB1B0 (The second half of the chunks use zero bubble)
|
||||||
|
step_6 = rank + 1
|
||||||
|
enable_zb = False
|
||||||
|
for i in range(step_6):
|
||||||
|
if i == step_6 // 2 and rank % 2 == 1:
|
||||||
|
enable_zb = True
|
||||||
|
self._backward_chunk(1, enable_zb=enable_zb)
|
||||||
|
if i == step_6 // 2 and rank % 2 == 0:
|
||||||
|
enable_zb = True
|
||||||
|
self._backward_chunk(0, enable_zb=enable_zb)
|
||||||
|
|
||||||
|
# Step 7: nWB0 (Use zero bubble)
|
||||||
|
step_7 = num_ranks - rank - 1
|
||||||
|
for i in range(step_7):
|
||||||
|
self._weight_chunk()
|
||||||
|
self._backward_chunk(0, enable_zb=True)
|
||||||
|
|
||||||
|
# Step 8: nW
|
||||||
|
step_8 = rank + 1
|
||||||
|
for i in range(step_8):
|
||||||
|
self._weight_chunk()
|
||||||
|
assert WeightGradStore.funcs_queue.empty()
|
||||||
|
|
||||||
|
self._commit_and_wait_comm()
|
||||||
|
|
||||||
|
loss, outputs = None, None
|
||||||
|
if self.is_first_rank:
|
||||||
|
if criterion is not None:
|
||||||
|
loss = torch.stack(self.loss_chunks)
|
||||||
|
if return_outputs:
|
||||||
|
outputs = gather(self.output_chunks[1], self.batch_dim)
|
||||||
|
if len(outputs) == 1:
|
||||||
|
outputs = outputs[0]
|
||||||
|
|
||||||
|
self._reset_states()
|
||||||
|
|
||||||
|
return loss, outputs
|
183
examples/example_dualpipev.py
Normal file
183
examples/example_dualpipev.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
from typing import List, Optional, Callable, Tuple
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from dualpipe import DualPipeV, set_p2p_tensor_shapes, set_p2p_tensor_dtype
|
||||||
|
from dualpipe.utils import WeightGradStore, run_backward
|
||||||
|
|
||||||
|
|
||||||
|
class LinearFunc(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input, weight):
|
||||||
|
ctx.save_for_backward(input, weight)
|
||||||
|
output = F.linear(input, weight)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input, weight = ctx.saved_tensors
|
||||||
|
if weight.grad is None:
|
||||||
|
weight.grad = torch.zeros_like(weight)
|
||||||
|
|
||||||
|
def grad_weight_fn():
|
||||||
|
weight.grad += grad_output.flatten(0, -2).T @ input.flatten(0, -2)
|
||||||
|
|
||||||
|
if WeightGradStore.enabled:
|
||||||
|
WeightGradStore.put(grad_weight_fn)
|
||||||
|
else:
|
||||||
|
grad_weight_fn()
|
||||||
|
grad_input = grad_output @ weight
|
||||||
|
return grad_input, None
|
||||||
|
|
||||||
|
|
||||||
|
class MyLinear(nn.Linear):
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return LinearFunc.apply(input, self.weight)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = MyLinear(hidden_size, hidden_size * 4, bias=False)
|
||||||
|
self.linear2 = MyLinear(hidden_size * 4, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.linear1(x)
|
||||||
|
x = F.gelu(x)
|
||||||
|
x = self.linear2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def overlapped_forward_backward(
|
||||||
|
cls,
|
||||||
|
module0: "PipelineStage",
|
||||||
|
inputs0: List[torch.Tensor],
|
||||||
|
criterion0: Optional[Callable],
|
||||||
|
labels0: Optional[List[torch.Tensor]],
|
||||||
|
module1: "PipelineStage",
|
||||||
|
loss1: Optional[torch.Tensor],
|
||||||
|
outputs1: Optional[List[torch.Tensor]],
|
||||||
|
output_grads1: Optional[List[torch.Tensor]],
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
You should implement custom forward-backward overlap strategy.
|
||||||
|
The code below is just an example.
|
||||||
|
"""
|
||||||
|
outputs0 = module0(*inputs0)
|
||||||
|
outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
|
||||||
|
if criterion0 is not None:
|
||||||
|
loss0 = criterion0(*outputs0, *labels0)
|
||||||
|
else:
|
||||||
|
loss0 = None
|
||||||
|
|
||||||
|
if loss1 is not None:
|
||||||
|
loss1.backward()
|
||||||
|
loss1.detach_()
|
||||||
|
else:
|
||||||
|
run_backward(outputs1, output_grads1)
|
||||||
|
|
||||||
|
return outputs0, loss0
|
||||||
|
|
||||||
|
|
||||||
|
def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.mse_loss(output, target).clone()
|
||||||
|
|
||||||
|
|
||||||
|
def ref_step(x, l, model, chunks):
|
||||||
|
ys, losses = [], []
|
||||||
|
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
|
||||||
|
micro_y = model(micro_x)
|
||||||
|
loss = criterion(micro_y, micro_l)
|
||||||
|
loss.backward()
|
||||||
|
ys.append(micro_y)
|
||||||
|
losses.append(loss)
|
||||||
|
y = torch.cat(ys, 0)
|
||||||
|
loss = torch.stack(losses)
|
||||||
|
return loss, y
|
||||||
|
|
||||||
|
|
||||||
|
def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
|
||||||
|
x, y = x.double(), y.double()
|
||||||
|
cos_diff = 1 - 2 * (x * y).sum().item() / (x * x + y * y).sum().item()
|
||||||
|
return cos_diff
|
||||||
|
|
||||||
|
|
||||||
|
def main(rank, pp_size):
|
||||||
|
is_first_rank = rank == 0
|
||||||
|
dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
torch.set_default_device(f"cuda:{rank}")
|
||||||
|
torch.manual_seed(233)
|
||||||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||||
|
|
||||||
|
num_chunks = 20
|
||||||
|
micro_batch_size = 3
|
||||||
|
seq_len = 256
|
||||||
|
hidden_size = 512
|
||||||
|
if is_first_rank:
|
||||||
|
print(f"{pp_size=}, {num_chunks=}, {seq_len=}, {hidden_size=}", flush=True)
|
||||||
|
set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_size)])
|
||||||
|
set_p2p_tensor_dtype(torch.float32)
|
||||||
|
|
||||||
|
# Create a model and partition it for each process
|
||||||
|
full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size * 2)])
|
||||||
|
|
||||||
|
# Full inputs
|
||||||
|
x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
|
||||||
|
l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
|
||||||
|
|
||||||
|
# Reference step
|
||||||
|
loss_ref, output_ref = ref_step(x, l, full_modules, num_chunks)
|
||||||
|
|
||||||
|
# DualPipeV
|
||||||
|
local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size * 2 - 1 - rank])
|
||||||
|
local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
|
||||||
|
local_modules[0].load_state_dict(local_full_modules[0].state_dict())
|
||||||
|
local_modules[1].load_state_dict(local_full_modules[1].state_dict())
|
||||||
|
dualpipev_model = DualPipeV(local_modules)
|
||||||
|
|
||||||
|
# DualPipeV inputs
|
||||||
|
if not is_first_rank:
|
||||||
|
x = None
|
||||||
|
l = None
|
||||||
|
|
||||||
|
# Training step
|
||||||
|
loss, outputs = dualpipev_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=False)
|
||||||
|
|
||||||
|
# Check loss
|
||||||
|
if is_first_rank:
|
||||||
|
assert torch.equal(loss, loss_ref)
|
||||||
|
else:
|
||||||
|
assert loss is None
|
||||||
|
assert outputs is None
|
||||||
|
|
||||||
|
# Check grads
|
||||||
|
for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
|
||||||
|
assert cal_diff(p.grad, p_ref.grad) < 1e-13
|
||||||
|
dualpipev_model.zero_grad()
|
||||||
|
|
||||||
|
# Inference step
|
||||||
|
with torch.no_grad():
|
||||||
|
loss, outputs = dualpipev_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=True)
|
||||||
|
|
||||||
|
# Check loss and outputs
|
||||||
|
if is_first_rank:
|
||||||
|
assert torch.equal(loss, loss_ref)
|
||||||
|
assert torch.equal(outputs, output_ref)
|
||||||
|
else:
|
||||||
|
assert loss is None
|
||||||
|
assert outputs is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_dualpipev(ngpus):
|
||||||
|
torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
for ngpus in range(num_gpus, 0, -1):
|
||||||
|
test_dualpipev(ngpus)
|
BIN
images/dualpipe.png
Normal file
BIN
images/dualpipe.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 787 KiB |
BIN
images/dualpipev.png
Normal file
BIN
images/dualpipev.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 523 KiB |
Binary file not shown.
Before Width: | Height: | Size: 80 KiB |
Loading…
Reference in New Issue
Block a user