add dualpipev

ljss 2025-03-04 17:50:00 +08:00
parent 036c14e7f6
commit 8ec5883b30
8 changed files with 609 additions and 6 deletions

README.md

@@ -4,7 +4,7 @@ DualPipe is an innovative bidirectional pipeline parallelism algorithm introduce
 ### Schedules
 
-![schedules](images/schedules.png)
+![dualpipe](images/dualpipe.png)
 
 Example DualPipe scheduling for 8 PP ranks and 20 micro-batches in two directions.
 The micro-batches in the reverse direction are symmetric to those in the forward direction, so
@@ -23,12 +23,21 @@ have mutually overlapped computation and communication
 full backward chunk, 𝑊 denotes the execution time of a "backward for weights" chunk, and 𝐹&𝐵
 denotes the execution time of two mutually overlapped forward and backward chunks.
 
+## DualPipeV
+
+DualPipeV is a concise V-shape schedule derived from DualPipe using a "cut-in-half" procedure, introduced by Sea AI Lab as "Cut-in-half" in their [blog post](https://hackmd.io/@ufotalent/r1lVXsa9Jg). Thanks to them for this efficient schedule!
+
+![dualpipev](images/dualpipev.png)
+
+Example DualPipeV scheduling for 4 PP ranks and 10 micro-batches.
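
For intuition, here is a minimal illustrative sketch (not part of the library; the function name is made up) of how the 2·P model stages fold onto P ranks in the V-shape, matching the stage assignment used in `examples/example_dualpipev.py`:

```python
# Illustrative only: with P pipeline ranks and 2*P stages, rank r hosts stage r
# (phase 0, going "down" the V) and stage 2*P - 1 - r (phase 1, coming back "up"),
# so rank 0 holds both the first and the last stage and can compute the loss locally.
def v_shape_stage_assignment(num_ranks: int) -> list[tuple[int, int]]:
    return [(r, 2 * num_ranks - 1 - r) for r in range(num_ranks)]

print(v_shape_stage_assignment(4))  # [(0, 7), (1, 6), (2, 5), (3, 4)]
```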

 ## Quick Start
 
 The usage is shown in the following example:
 
 ```bash
-python example.py
+python examples/example_dualpipe.py
+python examples/example_dualpipev.py
 ```
 
 Note: For real-world applications, you will need to implement a custom `overlapped_forward_backward` method tailored to your specific module.
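
As a reference for the expected shape of that method, here is a minimal sketch mirroring the classmethod in `examples/example_dualpipev.py`; the class and layer are placeholders, and a tuned implementation would interleave the kernels of the forward and backward chunks rather than running them back to back:

```python
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn

from dualpipe.utils import run_backward


class MyStage(nn.Module):
    """Placeholder pipeline stage; replace with your own module."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

    @classmethod
    def overlapped_forward_backward(
        cls,
        module0: "MyStage", inputs0: List[torch.Tensor],
        criterion0: Optional[Callable], labels0: Optional[List[torch.Tensor]],
        module1: "MyStage", loss1: Optional[torch.Tensor],
        outputs1: Optional[List[torch.Tensor]], output_grads1: Optional[List[torch.Tensor]],
    ) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]:
        # Un-overlapped fallback: forward one chunk, then backward the other.
        outputs0 = module0(*inputs0)
        outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
        loss0 = criterion0(*outputs0, *labels0) if criterion0 is not None else None

        if loss1 is not None:
            loss1.backward()
            loss1.detach_()
        else:
            run_backward(outputs1, output_grads1)

        return outputs0, loss0
```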

dualpipe/__init__.py

@@ -1,16 +1,16 @@
 __version__ = "1.0.0"
 
-from dualpipe.dualpipe import (
-    DualPipe,
-    WeightGradStore,
-)
+from dualpipe.dualpipe import DualPipe
+from dualpipe.dualpipev import DualPipeV
 from dualpipe.comm import (
     set_p2p_tensor_shapes,
     set_p2p_tensor_dtype,
 )
+from dualpipe.utils import WeightGradStore
 
 __all__ = [
     DualPipe,
+    DualPipeV,
     WeightGradStore,
     set_p2p_tensor_shapes,
     set_p2p_tensor_dtype,

dualpipe/dualpipev.py Normal file · 411 lines

@@ -0,0 +1,411 @@
from typing import Tuple, List, Union, Callable, Optional

import torch
import torch.nn as nn
import torch.distributed as dist

import dualpipe.comm as comm
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather


class DualPipeV(nn.Module):
    def __init__(
        self,
        modules: Tuple[nn.Module, nn.Module],
        batch_dim: int = 0,
        process_group: Optional[dist.ProcessGroup] = None,
        rank_mapping: Optional[List[int]] = None,
    ) -> None:
        super().__init__()

        assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
        self.module = nn.ModuleList(modules)
        self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
        self.batch_dim = batch_dim
        self.group = process_group or dist.distributed_c10d._get_default_group()
        self.num_ranks = self.group.size()

        # rank_mapping: Map rank in process_group to actual pp rank.
        # rank_inverse_mapping: Map actual pp rank to rank in process_group.
        if rank_mapping is None:
            rank_mapping = list(range(self.num_ranks))
        rank_inverse_mapping = [None] * (self.num_ranks + 1)
        for i in range(self.num_ranks):
            rank_inverse_mapping[rank_mapping[i]] = i

        self.rank = rank_mapping[self.group.rank()]
        self.prev_rank = rank_inverse_mapping[self.rank - 1]
        self.next_rank = rank_inverse_mapping[self.rank + 1]
        self.is_first_rank = self.rank == 0
        self.is_last_rank = self.rank == self.num_ranks - 1
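
    # Phase semantics: phase 0 is the "down" leg of the V (this rank's earlier stage) and
    # phase 1 is the "up" leg (this rank's later stage). Hence the first rank's phase-1
    # module is the final pipeline stage and computes the loss, while the last rank feeds
    # its phase-0 output directly into its own phase-1 input.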
    def _reset_states(self) -> None:
        WeightGradStore.clear()

        self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.labels: List[List[torch.Tensor]] = None
        self.loss_chunks: List[torch.Tensor] = []
        self.criterion: Callable = None

        self.current_f_chunk_id: List[int] = [0, 0]
        self.current_b_chunk_id: List[int] = [0, 0]
        self.current_send_f_chunk_id: List[int] = [0, 0]
        self.current_send_b_chunk_id: List[int] = [0, 0]
        self.current_recv_f_chunk_id: List[int] = [0, 0]
        self.current_recv_b_chunk_id: List[int] = [0, 0]
        self.comm_ops: List[dist.P2POp] = []
        self.to_free: List[torch.Tensor] = []

    def _forward_compute_chunk(self, phase: int) -> None:
        chunk_id = self.current_f_chunk_id[phase]
        self.current_f_chunk_id[phase] += 1
        inputs = self.input_chunks[phase][chunk_id]
        if self.forward_only:
            self.input_chunks[phase][chunk_id] = None

        is_last_stage = (self.is_first_rank and phase == 1)

        outputs = self.module[phase](*inputs)
        outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
        if is_last_stage and self.criterion is not None:
            labels = self.labels[chunk_id]
            loss = self.criterion(*outputs, *labels)
            self.loss_chunks.append(loss)

        if self.is_last_rank and phase == 0:
            self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs])
        if (not is_last_stage) or self.return_outputs:
            self.output_chunks[phase].append(outputs)

    def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
        if self.forward_only:
            return

        chunk_id = self.current_b_chunk_id[phase]
        self.current_b_chunk_id[phase] += 1

        is_last_stage = (self.is_first_rank and phase == 1)

        WeightGradStore.enabled = enable_zb
        if is_last_stage:
            loss = self.loss_chunks[chunk_id]
            loss.backward()
            loss.detach_()
        else:
            outputs = self.output_chunks[phase][chunk_id]
            if not self.return_outputs:
                self.output_chunks[phase][chunk_id] = None
            output_grads = self.output_grad_chunks[phase][chunk_id]
            self.output_grad_chunks[phase][chunk_id] = None
            non_empty = [(t, g) for t, g in zip(outputs, output_grads) if g is not None]
            outputs, output_grads = list(zip(*non_empty))
            if len(outputs) > 0:
                run_backward(outputs, output_grads)
        WeightGradStore.enabled = False
        if enable_zb:
            WeightGradStore.flush()

        inputs = self.input_chunks[phase][chunk_id]
        self.input_chunks[phase][chunk_id] = None
        input_grads = [t.grad for t in inputs]
        if self.is_last_rank and phase == 1:
            self.output_grad_chunks[0].append(input_grads)
        else:
            self.input_grad_chunks[phase].append(input_grads)
    def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
        if self.forward_only:
            self._forward_compute_chunk(phase0)
            return

        if not self.overlapped_forward_backward:
            self._forward_compute_chunk(phase0)
            self._backward_compute_chunk(phase1)
            return

        # pre-forward
        chunk_id0 = self.current_f_chunk_id[phase0]
        self.current_f_chunk_id[phase0] += 1
        module0 = self.module[phase0]
        inputs0 = self.input_chunks[phase0][chunk_id0]
        is_last_stage0 = (self.is_first_rank and phase0 == 1)

        if is_last_stage0 and self.criterion is not None:
            labels0 = self.labels[chunk_id0]
            criterion0 = self.criterion
        else:
            labels0 = []
            criterion0 = None

        # pre-backward
        chunk_id1 = self.current_b_chunk_id[phase1]
        self.current_b_chunk_id[phase1] += 1
        module1 = self.module[phase1]
        is_last_stage1 = (self.is_first_rank and phase1 == 1)

        if is_last_stage1:
            loss1 = self.loss_chunks[chunk_id1]
            outputs1 = []
            output_grads1 = []
        else:
            loss1 = None
            outputs1 = self.output_chunks[phase1][chunk_id1]
            if not self.return_outputs:
                self.output_chunks[phase1][chunk_id1] = None
            output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
            self.output_grad_chunks[phase1][chunk_id1] = None
            non_empty = [(t, g) for t, g in zip(outputs1, output_grads1) if g is not None]
            outputs1, output_grads1 = list(zip(*non_empty))

        # forward & backward
        outputs0, loss0 = type(module0).overlapped_forward_backward(
            module0, inputs0, criterion0, labels0,
            module1, loss1, outputs1, output_grads1,
        )

        # post-forward
        if self.is_last_rank and phase0 == 0:
            self.input_chunks[1].append([output.detach().requires_grad_() for output in outputs0])
        if (not is_last_stage0) or self.return_outputs:
            self.output_chunks[phase0].append(outputs0)
        if is_last_stage0 and self.criterion is not None:
            self.loss_chunks.append(loss0)

        # post-backward
        inputs = self.input_chunks[phase1][chunk_id1]
        self.input_chunks[phase1][chunk_id1] = None
        input_grads1 = [t.grad for t in inputs]
        if self.is_last_rank and phase1 == 1:
            self.output_grad_chunks[0].append(input_grads1)
        else:
            self.input_grad_chunks[phase1].append(input_grads1)

    def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
        if recv:
            self._recv_forward(phase)
        self._commit_and_wait_comm()

        self._forward_compute_chunk(phase)

        if send:
            self._send_forward(phase)

    def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
        if recv:
            self._recv_backward(phase)
        self._commit_and_wait_comm()

        self._backward_compute_chunk(phase, enable_zb)

        if send:
            self._send_backward(phase)

    def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
        if recv0:
            self._recv_forward(phase0)
        self._recv_backward(phase1)
        self._commit_and_wait_comm()

        self._forward_backward_compute_chunk(phase0, phase1)

        self._send_forward(phase0)
        self._send_backward(phase1)

    def _weight_chunk(self) -> None:
        if self.forward_only:
            return

        self._commit_and_wait_comm()

        # Assume FIFO
        WeightGradStore.pop()

    def _free_tensors(self) -> None:
        for tensor in self.to_free:
            assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
            tensor.data = torch.Tensor()
        self.to_free = []
    def _recv_forward(self, phase: int) -> None:
        if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
            return

        self.current_recv_f_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
        self.input_chunks[phase].append(tensors)

    def _send_forward(self, phase: int) -> None:
        if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
            return

        chunk_id = self.current_send_f_chunk_id[phase]
        self.current_send_f_chunk_id[phase] += 1
        tensors = self.output_chunks[phase][chunk_id]

        comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)

        if not self.return_outputs:
            self.to_free.extend(tensors)

    def _recv_backward(self, phase: int) -> None:
        if self.forward_only:
            return

        if (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0):
            return

        self.current_recv_b_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
        self.output_grad_chunks[phase].append(tensors)

    def _send_backward(self, phase: int) -> None:
        if self.forward_only:
            return

        if (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1):
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        if not self.comm_ops:
            return
        reqs = dist.batch_isend_irecv(self.comm_ops)
        for req in reqs:
            req.wait()
        self.comm_ops = []
        self._free_tensors()
    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: List[Optional[torch.Tensor]] = [],
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        Arguments:
            *inputs: Module inputs. Required only on the first rank.
            num_chunks: The number of micro-batches.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``. Required only on the first rank.
            labels: Labels of the loss function. Required only on the first rank.
            return_outputs: Whether to return outputs on the first rank. Default: ``False``.

        Returns: (loss, outputs)
            loss: Loss for the batch. Returned only on the first rank.
            outputs: Module outputs. Returned only if ``return_outputs=True`` and on the first rank.
        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
            "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before executing a step."
        self.forward_only = not torch.is_grad_enabled()
        self.return_outputs = return_outputs

        rank = self.rank
        num_ranks = self.num_ranks
        assert num_chunks > 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
        if not self.forward_only and self.is_first_rank:
            assert criterion is not None

        self._reset_states()

        if self.is_first_rank:
            self.input_chunks = (scatter(inputs, num_chunks, self.batch_dim), [])
            self.labels = scatter(labels, num_chunks, self.batch_dim)
            self.criterion = criterion
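
        # The eight steps below realize the V-shaped schedule: the ramp-up (Steps 1-3,
        # forwards first, then mixed with zero-bubble backwards), the steady state that
        # interleaves forward and backward chunks of both phases (Step 4), the ramp-down
        # (Steps 5-6), and the flush of deferred weight gradients (Steps 7-8).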
        # Step 1: nF0
        step_1 = (num_ranks - rank - 1) * 2
        for i in range(step_1):
            self._forward_chunk(0)

        # Step 2: nF0F1
        step_2 = rank + 1
        self._recv_forward(0)
        for i in range(step_2):
            self._forward_chunk(0, recv=False, send=False)
            self._recv_forward(0)
            self._forward_chunk(1, send=(not self.is_last_rank) or (i < step_2 - 1))
            self._send_forward(0)

        # Step 3: nB1W1F1 (Use zero bubble)
        step_3 = num_ranks - rank - 1
        for i in range(step_3):
            self._backward_chunk(1, enable_zb=True)
            self._recv_forward(1)
            self._weight_chunk()
            self._forward_chunk(1, recv=False)

        # Step 4 (Main step): nF0B1F1B0
        step_4 = num_chunks - num_ranks * 2 + rank + 1
        for i in range(step_4):
            if i == 0:
                if self.is_last_rank:
                    # NOTE: We don't overlap these two chunks to further reduce bubble size.
                    self._forward_chunk(0, recv=False, send=False)
                    self._send_forward(1)
                    self._backward_chunk(1, send=False)
                    self._send_forward(0)
                    self._send_backward(1)
                else:
                    self._forward_backward_chunk(0, 1, recv0=False)
            else:
                self._forward_backward_chunk(0, 1)
            self._forward_backward_chunk(1, 0)

        # Step 5: nB1F1B0
        step_5 = num_ranks - rank - 1
        for i in range(step_5):
            self._backward_chunk(1)
            self._forward_backward_chunk(1, 0)

        # Step 6: nB1B0 (The second half of the chunks use zero bubble)
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True
            self._backward_chunk(1, enable_zb=enable_zb)
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True
            self._backward_chunk(0, enable_zb=enable_zb)

        # Step 7: nWB0 (Use zero bubble)
        step_7 = num_ranks - rank - 1
        for i in range(step_7):
            self._weight_chunk()
            self._backward_chunk(0, enable_zb=True)

        # Step 8: nW
        step_8 = rank + 1
        for i in range(step_8):
            self._weight_chunk()

        assert WeightGradStore.funcs_queue.empty()

        self._commit_and_wait_comm()

        loss, outputs = None, None
        if self.is_first_rank:
            if criterion is not None:
                loss = torch.stack(self.loss_chunks)
            if return_outputs:
                outputs = gather(self.output_chunks[1], self.batch_dim)
                if len(outputs) == 1:
                    outputs = outputs[0]

        self._reset_states()

        return loss, outputs

examples/example_dualpipev.py Normal file · 183 lines

@@ -0,0 +1,183 @@
from typing import List, Optional, Callable, Tuple
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from dualpipe import DualPipeV, set_p2p_tensor_shapes, set_p2p_tensor_dtype
from dualpipe.utils import WeightGradStore, run_backward


class LinearFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        output = F.linear(input, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        if weight.grad is None:
            weight.grad = torch.zeros_like(weight)

        def grad_weight_fn():
            weight.grad += grad_output.flatten(0, -2).T @ input.flatten(0, -2)
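
        # Zero-bubble hook: while the scheduler enables WeightGradStore, the weight gradient
        # is deferred and executed later in a dedicated "W" slot; otherwise it is computed
        # immediately, as in ordinary backpropagation.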
        if WeightGradStore.enabled:
            WeightGradStore.put(grad_weight_fn)
        else:
            grad_weight_fn()
        grad_input = grad_output @ weight
        return grad_input, None


class MyLinear(nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return LinearFunc.apply(input, self.weight)


class PipelineStage(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.linear1 = MyLinear(hidden_size, hidden_size * 4, bias=False)
        self.linear2 = MyLinear(hidden_size * 4, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = F.gelu(x)
        x = self.linear2(x)
        return x

    @classmethod
    def overlapped_forward_backward(
        cls,
        module0: "PipelineStage",
        inputs0: List[torch.Tensor],
        criterion0: Optional[Callable],
        labels0: Optional[List[torch.Tensor]],
        module1: "PipelineStage",
        loss1: Optional[torch.Tensor],
        outputs1: Optional[List[torch.Tensor]],
        output_grads1: Optional[List[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        You should implement custom forward-backward overlap strategy.
        The code below is just an example.
        """
        outputs0 = module0(*inputs0)
        outputs0 = [outputs0] if isinstance(outputs0, torch.Tensor) else outputs0
        if criterion0 is not None:
            loss0 = criterion0(*outputs0, *labels0)
        else:
            loss0 = None

        if loss1 is not None:
            loss1.backward()
            loss1.detach_()
        else:
            run_backward(outputs1, output_grads1)

        return outputs0, loss0


def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(output, target).clone()


def ref_step(x, l, model, chunks):
    ys, losses = [], []
    for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
        micro_y = model(micro_x)
        loss = criterion(micro_y, micro_l)
        loss.backward()
        ys.append(micro_y)
        losses.append(loss)
    y = torch.cat(ys, 0)
    loss = torch.stack(losses)
    return loss, y
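

# cal_diff computes 1 - 2*<x, y> / (||x||^2 + ||y||^2): it is 0 for identical tensors and
# grows with the mismatch, which makes it a strict closeness check for the gradients below.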
def cal_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    x, y = x.double(), y.double()
    cos_diff = 1 - 2 * (x * y).sum().item() / (x * x + y * y).sum().item()
    return cos_diff


def main(rank, pp_size):
    is_first_rank = rank == 0
    dist.init_process_group(backend='nccl', init_method="env://", world_size=pp_size, rank=rank)
    torch.cuda.set_device(rank)
    torch.set_default_device(f"cuda:{rank}")
    torch.manual_seed(233)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    num_chunks = 20
    micro_batch_size = 3
    seq_len = 256
    hidden_size = 512
    if is_first_rank:
        print(f"{pp_size=}, {num_chunks=}, {seq_len=}, {hidden_size=}", flush=True)
    set_p2p_tensor_shapes([(micro_batch_size, seq_len, hidden_size)])
    set_p2p_tensor_dtype(torch.float32)

    # Create a model and partition it for each process
    full_modules = nn.Sequential(*[PipelineStage(hidden_size) for _ in range(pp_size * 2)])

    # Full inputs
    x = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)
    l = torch.randn(num_chunks * micro_batch_size, seq_len, hidden_size)

    # Reference step
    loss_ref, output_ref = ref_step(x, l, full_modules, num_chunks)

    # DualPipeV
    local_full_modules = nn.Sequential(full_modules[rank], full_modules[pp_size * 2 - 1 - rank])
    local_modules = nn.Sequential(PipelineStage(hidden_size), PipelineStage(hidden_size))
    local_modules[0].load_state_dict(local_full_modules[0].state_dict())
    local_modules[1].load_state_dict(local_full_modules[1].state_dict())
    dualpipev_model = DualPipeV(local_modules)

    # DualPipeV inputs
    if not is_first_rank:
        x = None
        l = None

    # Training step
    loss, outputs = dualpipev_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=False)

    # Check loss
    if is_first_rank:
        assert torch.equal(loss, loss_ref)
    else:
        assert loss is None
    assert outputs is None

    # Check grads
    for ((n, p), p_ref) in zip(local_modules.named_parameters(), local_full_modules.parameters()):
        assert cal_diff(p.grad, p_ref.grad) < 1e-13
    dualpipev_model.zero_grad()

    # Inference step
    with torch.no_grad():
        loss, outputs = dualpipev_model.step(x, num_chunks=num_chunks, criterion=criterion, labels=(l,), return_outputs=True)

    # Check loss and outputs
    if is_first_rank:
        assert torch.equal(loss, loss_ref)
        assert torch.equal(outputs, output_ref)
    else:
        assert loss is None
        assert outputs is None


def test_dualpipev(ngpus):
    torch.multiprocessing.spawn(main, args=(ngpus, ), nprocs=ngpus, daemon=True)


if __name__ == "__main__":
    num_gpus = torch.cuda.device_count()
    for ngpus in range(num_gpus, 0, -1):
        test_dualpipev(ngpus)

BIN images/dualpipe.png Normal file · 787 KiB (binary file not shown)

BIN images/dualpipev.png Normal file · 523 KiB (binary file not shown)

BIN Binary file not shown · Before: 80 KiB