Return a scalar via mean

Fix loss dimension issue in ref_step function.
Aggregate into a scalar
This commit is contained in:
A-transformer 2025-03-06 21:21:37 +04:00 committed by GitHub
parent f371022947
commit 7cb1e5e632
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -89,14 +89,15 @@ def criterion(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
def ref_step(x, l, model, chunks):
ys, losses = [], []
model.zero_grad()
for micro_x, micro_l in zip(x.chunk(chunks), l.chunk(chunks)):
micro_y = model(micro_x)
loss = criterion(micro_y, micro_l)
loss.backward()
ys.append(micro_y)
losses.append(loss)
ys.append(micro_y.detach())
losses.append(loss.detach())
y = torch.cat(ys, 0)
loss = torch.stack(losses)
loss = torch.stack(losses).mean()
return loss, y