fix: correct typos

This commit is contained in:
dotrail 2025-02-27 11:36:53 +00:00
parent ebe10fcefe
commit 035a38fa24

View File

@ -20,7 +20,7 @@ class DualPipe(nn.Module):
assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device()) assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
self.module = nn.ModuleList(modules) self.module = nn.ModuleList(modules)
self.overlaped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlaped_forward_backward") self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
self.batch_dim = batch_dim self.batch_dim = batch_dim
self.group = process_group or dist.distributed_c10d._get_default_group() self.group = process_group or dist.distributed_c10d._get_default_group()
self.num_ranks = self.group.size() self.num_ranks = self.group.size()
@ -123,7 +123,7 @@ class DualPipe(nn.Module):
self._forward_compute_chunk(phase0) self._forward_compute_chunk(phase0)
return return
if not self.overlaped_forward_backward: if not self.overlapped_forward_backward:
self._forward_compute_chunk(phase0) self._forward_compute_chunk(phase0)
self._backward_compute_chunk(phase1) self._backward_compute_chunk(phase1)
return return
@ -165,7 +165,7 @@ class DualPipe(nn.Module):
outputs1, output_grads1 = list(zip(*non_empty)) outputs1, output_grads1 = list(zip(*non_empty))
# forward & backward # forward & backward
outputs0, loss0 = type(module0).overlaped_forward_backward( outputs0, loss0 = type(module0).overlapped_forward_backward(
module0, inputs0, criterion0, labels0, module0, inputs0, criterion0, labels0,
module1, loss1, outputs1, output_grads1, module1, loss1, outputs1, output_grads1,
) )
@ -300,7 +300,7 @@ class DualPipe(nn.Module):
return_outputs: bool = False, return_outputs: bool = False,
) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]: ) -> Tuple[Optional[torch.Tensor], Optional[Union[torch.Tensor, Tuple[torch.Tensor]]]]:
""" """
Execute a traning or inference step. Execute a training or inference step.
Arguments: Arguments:
*inputs: Module inputs. Required only on the first/last ranks. *inputs: Module inputs. Required only on the first/last ranks.
@ -352,7 +352,7 @@ class DualPipe(nn.Module):
self.labels = (labels, []) self.labels = (labels, [])
self.criterion = criterion self.criterion = criterion
# For the fisrt half of the ranks: phase 0 means forward direction, phase 1 means reverse direction. # For the first half of the ranks: phase 0 means forward direction, phase 1 means reverse direction.
# For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction. # For the second half of the ranks: phase 0 means reverse direction, phase 1 means forward direction.
# Step 1: nF0 # Step 1: nF0