Fix typo: overlaped_forward_backward -> overlapped_forward_backward

This commit is contained in:
RK 2025-02-27 12:42:25 +08:00
parent ebe10fcefe
commit bb214e82b5
2 changed files with 4 additions and 4 deletions

View File

@ -20,7 +20,7 @@ class DualPipe(nn.Module):
assert next(modules[0].parameters()).device == torch.device(torch.cuda.current_device())
self.module = nn.ModuleList(modules)
self.overlaped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlaped_forward_backward")
self.overlapped_forward_backward = type(modules[0]) == type(modules[1]) and hasattr(type(modules[0]), "overlapped_forward_backward")
self.batch_dim = batch_dim
self.group = process_group or dist.distributed_c10d._get_default_group()
self.num_ranks = self.group.size()
@ -123,7 +123,7 @@ class DualPipe(nn.Module):
self._forward_compute_chunk(phase0)
return
if not self.overlaped_forward_backward:
if not self.overlapped_forward_backward:
self._forward_compute_chunk(phase0)
self._backward_compute_chunk(phase1)
return
@ -165,7 +165,7 @@ class DualPipe(nn.Module):
outputs1, output_grads1 = list(zip(*non_empty))
# forward & backward
outputs0, loss0 = type(module0).overlaped_forward_backward(
outputs0, loss0 = type(module0).overlapped_forward_backward(
module0, inputs0, criterion0, labels0,
module1, loss1, outputs1, output_grads1,
)

View File

@ -52,7 +52,7 @@ class PipelineStage(nn.Module):
return x
@classmethod
def overlaped_forward_backward(
def overlapped_forward_backward(
cls,
module0: "PipelineStage",
inputs0: List[torch.Tensor],