mirror of
https://github.com/deepseek-ai/DualPipe
synced 2025-04-23 07:34:27 +00:00
441 lines
17 KiB
Python
441 lines
17 KiB
Python
from typing import Tuple, List, Union, Callable, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.distributed as dist
|
|
|
|
import dualpipe.comm as comm
|
|
from dualpipe.utils import WeightGradStore, run_backward, scatter, gather
|
|
|
|
|
|
class DualPipe(nn.Module):
    """DualPipe bidirectional pipeline-parallel schedule over two local stages.

    Every rank hosts two stages, ``modules[0]`` and ``modules[1]``. Internally,
    phase 0 denotes traffic travelling from the first pipeline rank towards the
    last (it is received from ``prev_rank``), and phase 1 denotes traffic
    travelling from the last rank back towards the first (received from
    ``next_rank``); ``self.module[phase]`` runs the stage for that phase.
    Micro-batches are injected at both ends: the first rank feeds phase 0 and
    the last rank feeds phase 1 (see :meth:`step`).
    """

    def __init__(
        self,
        modules: Tuple[nn.Module, nn.Module],
        batch_dim: int = 0,
        process_group: Optional[dist.ProcessGroup] = None,
        rank_mapping: Optional[List[int]] = None,
    ) -> None:
        """Wrap the two local pipeline stages.

        Args:
            modules: The two stages hosted on this rank. ``modules[0]`` serves
                phase 0 (first -> last rank) and ``modules[1]`` serves phase 1
                (last -> first rank).
            batch_dim: Dimension handed to the ``scatter``/``gather`` helpers
                when splitting inputs/labels into micro-batches and when
                reassembling returned outputs.
            process_group: Group spanning the pipeline ranks. Defaults to the
                default process group.
            rank_mapping: ``rank_mapping[group_rank]`` is the pipeline position
                of that rank within ``process_group``. Defaults to identity.
        """
        super().__init__()

        # The stages are expected to already live on the current CUDA device.
        assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
        self.module = nn.ModuleList(modules)
        # The fused forward/backward path is only used when both stages share a
        # class that exposes an ``overlapped_forward_backward`` hook.
        self.overlapped_forward_backward = (
            type(modules[0]) is type(modules[1])
            and hasattr(type(modules[0]), "overlapped_forward_backward")
        )
        self.batch_dim = batch_dim
        self.group = process_group or dist.distributed_c10d._get_default_group()
        self.num_ranks = self.group.size()

        # rank_mapping: Map rank in process_group to actual pp rank.
        # rank_inverse_mapping: Map actual pp rank to rank in process_group.
        if rank_mapping is None:
            rank_mapping = list(range(self.num_ranks))
        # One extra trailing ``None`` slot: index -1 (the "previous" of pp rank 0)
        # and index ``num_ranks`` (the "next" of the last pp rank) both resolve
        # to it, so the two pipeline ends have no neighbour on their open side.
        rank_inverse_mapping = [None] * (self.num_ranks + 1)
        for i in range(self.num_ranks):
            rank_inverse_mapping[rank_mapping[i]] = i

        self.rank = rank_mapping[self.group.rank()]
        self.first_rank = rank_inverse_mapping[0]
        self.prev_rank = rank_inverse_mapping[self.rank - 1]
        self.next_rank = rank_inverse_mapping[self.rank + 1]
        self.last_rank = rank_inverse_mapping[self.num_ranks - 1]

        self.is_first_rank = self.rank == 0
        self.is_last_rank = self.rank == self.num_ranks - 1
        self.is_in_second_half = self.rank >= self.num_ranks // 2
        # The two ranks adjacent to the pipeline centre get special handling in
        # step() (Step 2 and the first iteration of Step 4).
        self.is_middle_rank = (self.rank == self.num_ranks // 2 - 1) or (self.rank == self.num_ranks // 2)

    def _reset_states(self) -> None:
        """Clear ``WeightGradStore`` and all per-step buffers, counters and queues."""
        WeightGradStore.clear()

        # Every pair below is indexed by the internal phase (0: first -> last,
        # 1: last -> first); each inner list is indexed by micro-batch id.
        self.input_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.input_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        self.output_grad_chunks: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = ([], [])
        # Populated by step() on the first/last rank only.
        self.labels: Tuple[List[List[torch.Tensor]], List[List[torch.Tensor]]] = None
        # One loss per micro-batch that completed its forward on a last stage.
        self.loss_chunks: List[torch.Tensor] = []
        self.criterion: Callable = None

        # Per-phase cursors: next micro-batch to forward / backward / send / receive.
        self.current_f_chunk_id: List[int] = [0, 0]
        self.current_b_chunk_id: List[int] = [0, 0]
        self.current_send_f_chunk_id: List[int] = [0, 0]
        self.current_send_b_chunk_id: List[int] = [0, 0]
        self.current_recv_f_chunk_id: List[int] = [0, 0]
        self.current_recv_b_chunk_id: List[int] = [0, 0]
        # P2P operations queued but not yet launched (see _commit_and_wait_comm).
        self.comm_ops: List[dist.P2POp] = []
        # Sent output tensors whose storage is released once communication finishes.
        self.to_free: List[torch.Tensor] = []

    @staticmethod
    def _drop_missing_grads(tensors, grads):
        """Pair outputs with their received gradients, skipping ``None`` gradients.

        Returns:
            ``(kept_tensors, kept_grads)`` as tuples. Both are empty when no
            gradient is present, so callers can test ``len(...) > 0`` instead
            of crashing on an empty unpack.
        """
        kept = [(t, g) for t, g in zip(tensors, grads) if g is not None]
        if not kept:
            return (), ()
        kept_tensors, kept_grads = zip(*kept)
        return kept_tensors, kept_grads

    def _forward_compute_chunk(self, phase: int) -> None:
        """Run the forward pass of the next micro-batch for schedule ``phase``.

        On the last stage of that direction the loss is computed with
        ``self.criterion`` (if set). Outputs are kept for sending/backward unless
        this is the last stage and outputs are not being returned.
        """
        # Convert the schedule-relative phase into the internal direction.
        phase ^= self.is_in_second_half
        chunk_id = self.current_f_chunk_id[phase]
        self.current_f_chunk_id[phase] += 1
        inputs = self.input_chunks[phase][chunk_id]
        if self.forward_only:
            # No backward will need the inputs; drop the reference early.
            self.input_chunks[phase][chunk_id] = None

        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)

        outputs = self.module[phase](*inputs)
        outputs = [outputs] if isinstance(outputs, torch.Tensor) else outputs
        if is_last_stage and self.criterion is not None:
            labels = self.labels[phase][chunk_id]
            loss = self.criterion(*outputs, *labels)
            self.loss_chunks.append(loss)

        if (not is_last_stage) or self.return_outputs:
            self.output_chunks[phase].append(outputs)

    def _backward_compute_chunk(self, phase: int, enable_zb: bool = False) -> None:
        """Run the backward pass of the next micro-batch for schedule ``phase``.

        Args:
            phase: Schedule-relative phase (flipped on second-half ranks).
            enable_zb: Enable ``WeightGradStore`` during the backward and flush
                it afterwards ("zero bubble" mode). The queued work is later
                consumed by :meth:`_weight_chunk`.

        The gradients w.r.t. this micro-batch's inputs are appended to
        ``input_grad_chunks`` for :meth:`_send_backward`.
        """
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        chunk_id = self.current_b_chunk_id[phase]
        self.current_b_chunk_id[phase] += 1

        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)

        WeightGradStore.enabled = enable_zb
        try:
            if is_last_stage:
                loss = self.loss_chunks[chunk_id]
                loss.backward()
                loss.detach_()
            else:
                outputs = self.output_chunks[phase][chunk_id]
                if not self.return_outputs:
                    self.output_chunks[phase][chunk_id] = None
                output_grads = self.output_grad_chunks[phase][chunk_id]
                self.output_grad_chunks[phase][chunk_id] = None
                # Previously an all-None gradient list crashed on the unpack.
                outputs, output_grads = self._drop_missing_grads(outputs, output_grads)
                if len(outputs) > 0:
                    run_backward(outputs, output_grads)
        finally:
            # Never leave the store enabled if the backward raised.
            WeightGradStore.enabled = False
        if enable_zb:
            WeightGradStore.flush()

        inputs = self.input_chunks[phase][chunk_id]
        self.input_chunks[phase][chunk_id] = None
        input_grads = [t.grad for t in inputs]
        self.input_grad_chunks[phase].append(input_grads)

    def _forward_backward_compute_chunk(self, phase0: int, phase1: int) -> None:
        """Forward one micro-batch of ``phase0`` and backward one of ``phase1``.

        When both stages provide ``overlapped_forward_backward``, the two passes
        are handed to that hook together; otherwise they run one after another.
        """
        if self.forward_only:
            self._forward_compute_chunk(phase0)
            return

        if not self.overlapped_forward_backward:
            self._forward_compute_chunk(phase0)
            self._backward_compute_chunk(phase1)
            return

        # pre-forward
        phase0 ^= self.is_in_second_half
        chunk_id0 = self.current_f_chunk_id[phase0]
        self.current_f_chunk_id[phase0] += 1
        module0 = self.module[phase0]
        inputs0 = self.input_chunks[phase0][chunk_id0]
        is_last_stage0 = (self.is_first_rank and phase0 == 1) or (self.is_last_rank and phase0 == 0)

        if is_last_stage0 and self.criterion is not None:
            labels0 = self.labels[phase0][chunk_id0]
            criterion0 = self.criterion
        else:
            labels0 = []
            criterion0 = None

        # pre-backward
        phase1 ^= self.is_in_second_half
        chunk_id1 = self.current_b_chunk_id[phase1]
        self.current_b_chunk_id[phase1] += 1
        module1 = self.module[phase1]
        is_last_stage1 = (self.is_first_rank and phase1 == 1) or (self.is_last_rank and phase1 == 0)

        if is_last_stage1:
            loss1 = self.loss_chunks[chunk_id1]
            outputs1 = []
            output_grads1 = []
        else:
            loss1 = None
            outputs1 = self.output_chunks[phase1][chunk_id1]
            if not self.return_outputs:
                self.output_chunks[phase1][chunk_id1] = None
            output_grads1 = self.output_grad_chunks[phase1][chunk_id1]
            self.output_grad_chunks[phase1][chunk_id1] = None
            # Empty tuples (instead of a crash) when no gradient was received.
            outputs1, output_grads1 = self._drop_missing_grads(outputs1, output_grads1)

        # forward & backward
        outputs0, loss0 = type(module0).overlapped_forward_backward(
            module0, inputs0, criterion0, labels0,
            module1, loss1, outputs1, output_grads1,
        )

        # post-forward
        if (not is_last_stage0) or self.return_outputs:
            self.output_chunks[phase0].append(outputs0)
        if is_last_stage0 and self.criterion is not None:
            self.loss_chunks.append(loss0)

        # post-backward
        inputs = self.input_chunks[phase1][chunk_id1]
        self.input_chunks[phase1][chunk_id1] = None
        input_grads1 = [t.grad for t in inputs]
        self.input_grad_chunks[phase1].append(input_grads1)

    def _forward_chunk(self, phase: int, recv: bool = True, send: bool = True) -> None:
        """Optionally receive, compute, then optionally queue the send for one forward micro-batch."""
        if recv:
            self._recv_forward(phase)
        # Flushes every queued op, including sends queued by earlier chunks.
        self._commit_and_wait_comm()

        self._forward_compute_chunk(phase)

        if send:
            self._send_forward(phase)

    def _backward_chunk(self, phase: int, enable_zb: bool = False, recv: bool = True, send: bool = True) -> None:
        """Optionally receive, compute, then optionally queue the send for one backward micro-batch."""
        if recv:
            self._recv_backward(phase)
        self._commit_and_wait_comm()

        self._backward_compute_chunk(phase, enable_zb)

        if send:
            self._send_backward(phase)

    def _forward_backward_chunk(self, phase0: int, phase1: int, recv0: bool = True) -> None:
        """Paired forward (``phase0``) and backward (``phase1``) step with its communication."""
        if recv0:
            self._recv_forward(phase0)
        self._recv_backward(phase1)
        self._commit_and_wait_comm()

        self._forward_backward_compute_chunk(phase0, phase1)

        self._send_forward(phase0)
        self._send_backward(phase1)

    def _weight_chunk(self) -> None:
        """Run one deferred weight-gradient computation from ``WeightGradStore``."""
        if self.forward_only:
            return

        self._commit_and_wait_comm()

        # Assume FIFO
        WeightGradStore.pop()

    def _free_tensors(self) -> None:
        """Release the storage of output tensors that have already been sent.

        Only ``.data`` is swapped for an empty tensor; the tensor objects remain
        referenced from ``output_chunks`` for the later backward pass. Views are
        rejected because their storage belongs to their base tensor.
        """
        for tensor in self.to_free:
            assert tensor._base is None, f"pipeline stage should not return view tensors {dist.get_rank(), tensor.shape}"
            tensor.data = torch.Tensor()
        self.to_free = []

    def _recv_forward(self, phase: int) -> None:
        """Queue a receive of the next forward input for ``phase`` (no-op on that direction's first stage)."""
        phase ^= self.is_in_second_half
        is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
        if is_first_stage:
            return

        self.current_recv_f_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.prev_rank if phase == 0 else self.next_rank, self.group)
        self.input_chunks[phase].append(tensors)

    def _send_forward(self, phase: int) -> None:
        """Queue a send of the next forward output for ``phase`` (no-op on that direction's last stage)."""
        phase ^= self.is_in_second_half
        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
        if is_last_stage:
            return

        chunk_id = self.current_send_f_chunk_id[phase]
        self.current_send_f_chunk_id[phase] += 1
        tensors = self.output_chunks[phase][chunk_id]

        comm.append_isend(self.comm_ops, tensors, self.next_rank if phase == 0 else self.prev_rank, self.group)

        if not self.return_outputs:
            # Storage is released after the send completes (see _free_tensors).
            self.to_free.extend(tensors)

    def _recv_backward(self, phase: int) -> None:
        """Queue a receive of output gradients for ``phase`` (no-op when forward-only or on the last stage)."""
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        is_last_stage = (self.is_first_rank and phase == 1) or (self.is_last_rank and phase == 0)
        if is_last_stage:
            return

        self.current_recv_b_chunk_id[phase] += 1
        tensors = comm.append_irecv(self.comm_ops, self.next_rank if phase == 0 else self.prev_rank, self.group)
        self.output_grad_chunks[phase].append(tensors)

    def _send_backward(self, phase: int) -> None:
        """Queue a send of input gradients for ``phase`` (no-op when forward-only or on the first stage)."""
        if self.forward_only:
            return

        phase ^= self.is_in_second_half
        is_first_stage = (self.is_first_rank and phase == 0) or (self.is_last_rank and phase == 1)
        if is_first_stage:
            return

        chunk_id = self.current_send_b_chunk_id[phase]
        self.current_send_b_chunk_id[phase] += 1
        tensors = self.input_grad_chunks[phase][chunk_id]
        self.input_grad_chunks[phase][chunk_id] = None

        comm.append_isend(self.comm_ops, tensors, self.prev_rank if phase == 0 else self.next_rank, self.group)

    def _commit_and_wait_comm(self) -> None:
        """Launch all queued P2P ops as one batch, wait for them, then free sent outputs."""
        if not self.comm_ops:
            return
        reqs = dist.batch_isend_irecv(self.comm_ops)
        for req in reqs:
            req.wait()
        self.comm_ops = []
        self._free_tensors()

    def step(
        self,
        *inputs: Optional[torch.Tensor],
        num_chunks: int = 0,
        criterion: Optional[Callable] = None,
        labels: Optional[List[Optional[torch.Tensor]]] = None,
        return_outputs: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
        """
        Execute a training or inference step.

        The step runs forward-only when autograd is disabled
        (``torch.is_grad_enabled()`` is False).

        Arguments:
            *inputs: Module inputs. Required only on the first/last ranks.
            num_chunks: The number of micro-batches. Must be even and at least
                twice the number of pipeline ranks.
            criterion: Loss function, invoked as ``criterion(*outputs, *labels)``.
                Required on the first/last ranks when training.
            labels: Labels of the loss function. Required only on the first/last ranks.
                Labels on the first rank correspond to the inputs fed on the last rank;
                labels on the last rank correspond to the inputs fed on the first rank.
                ``None`` (the default) means no labels.
            return_outputs: Whether to return outputs on the first/last ranks. Default: ``False``.

        Returns: (loss, outputs)
            loss: Stacked per-micro-batch losses, on the first/last ranks when a
                ``criterion`` is given. The loss on the first rank corresponds to the
                inputs fed on the last rank, and vice versa. Otherwise: ``None``.
            outputs: Returned only if ``return_outputs=True``, on the first/last ranks.
                Outputs on the first rank correspond to the inputs fed on the last rank,
                and vice versa. A single output is unwrapped from its sequence.
                Otherwise: ``None``.
        """
        assert comm.TENSOR_SHAPES is not None and comm.TENSOR_DTYPE is not None, \
            "You need to call set_p2p_tensor_shapes and set_p2p_tensor_dtype before doing a step."
        if labels is None:
            labels = []
        self.forward_only = not torch.is_grad_enabled()
        self.return_outputs = return_outputs

        rank = self.rank
        num_ranks = self.num_ranks
        assert num_ranks % 2 == 0
        assert num_chunks > 0 and num_chunks % 2 == 0 and num_chunks >= num_ranks * 2, f"{num_chunks=}, {num_ranks=}"
        num_half_ranks = num_ranks // 2
        # Distance from the nearer pipeline end (0 on the first and last rank).
        half_rank = min(rank, num_ranks - 1 - rank)
        # Each end injects half of the micro-batches.
        half_num_chunks = num_chunks // 2
        self.num_half_ranks = num_half_ranks
        self.half_rank = half_rank

        if not self.forward_only and (self.is_first_rank or self.is_last_rank):
            assert criterion is not None

        self._reset_states()

        inputs = scatter(inputs, half_num_chunks, self.batch_dim)
        labels = scatter(labels, half_num_chunks, self.batch_dim)
        if self.is_first_rank:
            self.input_chunks = (inputs, [])
            self.labels = ([], labels)
        elif self.is_last_rank:
            self.input_chunks = ([], inputs)
            self.labels = (labels, [])
        self.criterion = criterion

        # For the first half of the ranks: phase 0 means forward direction, phase 1 means reverse direction.
        # For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction.

        # Step 1: nF0
        step_1 = (num_half_ranks - half_rank - 1) * 2
        for i in range(step_1):
            self._forward_chunk(0)

        # Step 2: nF0F1
        step_2 = half_rank + 1
        self._recv_forward(0)
        for i in range(step_2):
            self._forward_chunk(0, recv=False, send=self.is_middle_rank)
            self._recv_forward(0)
            self._forward_chunk(1, send=(not self.is_middle_rank) or (i < step_2 - 1))
            if not self.is_middle_rank:
                self._send_forward(0)

        # Step 3: nB1W1F1 (Use zero bubble)
        step_3 = num_half_ranks - half_rank - 1
        for i in range(step_3):
            self._backward_chunk(1, enable_zb=True)
            self._recv_forward(1)
            self._weight_chunk()
            self._forward_chunk(1, recv=False)

        # Step 4 (Main step): nF0B1F1B0
        step_4 = half_num_chunks - num_ranks + half_rank + 1
        for i in range(step_4):
            if i == 0:
                if self.is_middle_rank:
                    # NOTE: We don't overlap these two chunks to further reduce bubble size.
                    self._forward_chunk(0, recv=False, send=False)
                    self._send_forward(1)
                    self._backward_chunk(1, send=False)
                    self._send_forward(0)
                    self._send_backward(1)
                else:
                    self._forward_backward_chunk(0, 1, recv0=False)
            else:
                self._forward_backward_chunk(0, 1)
            self._forward_backward_chunk(1, 0)

        # Step 5: nB1F1B0
        step_5 = num_half_ranks - half_rank - 1
        for i in range(step_5):
            self._backward_chunk(1)
            self._forward_backward_chunk(1, 0)

        # Step 6: nB1B0 (The second half of the chunks use zero bubble)
        step_6 = half_rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and half_rank % 2 == 1:
                enable_zb = True
            self._backward_chunk(1, enable_zb=enable_zb)
            if i == step_6 // 2 and half_rank % 2 == 0:
                enable_zb = True
            self._backward_chunk(0, enable_zb=enable_zb)

        # Step 7: nWB0 (Use zero bubble)
        step_7 = num_half_ranks - half_rank - 1
        for i in range(step_7):
            self._weight_chunk()
            self._backward_chunk(0, enable_zb=True)

        # Step 8: nW
        step_8 = half_rank + 1
        for i in range(step_8):
            self._weight_chunk()
        # Every deferred weight-gradient computation must have been consumed.
        assert WeightGradStore.funcs_queue.empty()

        self._commit_and_wait_comm()

        loss, outputs = None, None
        if self.is_first_rank or self.is_last_rank:
            if criterion is not None:
                loss = torch.stack(self.loss_chunks)
            if return_outputs:
                # First rank ends phase 1, last rank ends phase 0.
                outputs = gather(self.output_chunks[self.is_first_rank], self.batch_dim)
                if len(outputs) == 1:
                    outputs = outputs[0]

        self._reset_states()

        return loss, outputs
|