mirror of
https://github.com/deepseek-ai/DualPipe
synced 2025-04-26 17:11:46 +00:00
Return a scalar via mean
Fix the loss dimension issue in the ref_step function: aggregate the stacked per-chunk losses into a single scalar by taking their mean.
This commit is contained in:
parent
f371022947
commit
7cb1e5e632
@ -89,14 +89,15 @@ def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def ref_step(x, l, model, chunks):
|
def ref_step(x, l, model, chunks):
|
||||||
ys, losses = [], []
|
ys, losses = [], []
|
||||||
|
model.zero_grad()
|
||||||
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
|
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
|
||||||
micro_y = model(micro_x)
|
micro_y = model(micro_x)
|
||||||
loss = criterion(micro_y, micro_l)
|
loss = criterion(micro_y, micro_l)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
ys.append(micro_y)
|
ys.append(micro_y.detach())
|
||||||
losses.append(loss)
|
losses.append(loss.detach())
|
||||||
y = torch.cat(ys, 0)
|
y = torch.cat(ys, 0)
|
||||||
loss = torch.stack(losses)
|
loss = torch.stack(losses).mean()
|
||||||
return loss, y
|
return loss, y
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user