Return a scalar via mean

Fix loss dimension issue in ref_step function.
Aggregate into a scalar
This commit is contained in:
A-transformer
2025-03-06 21:21:37 +04:00
committed by GitHub
parent f371022947
commit 7cb1e5e632

View File

@@ -89,14 +89,15 @@ def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def ref_step(x, l, model, chunks):
ys, losses = [], []
model.zero_grad()
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
micro_y = model(micro_x)
loss = criterion(micro_y, micro_l)
loss.backward()
ys.append(micro_y)
losses.append(loss)
ys.append(micro_y.detach())
losses.append(loss.detach())
y = torch.cat(ys, 0)
loss = torch.stack(losses)
loss = torch.stack(losses).mean()
return loss, y