fixed tensorboard update images by adding batch dimension (#9)

This commit is contained in:
JonathonLuiten 2023-07-10 18:45:55 -04:00 committed by GitHub
parent b89771f0ba
commit 9490ef7612
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -162,9 +162,9 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
images = torch.cat((images, image.unsqueeze(0)), dim=0) images = torch.cat((images, image.unsqueeze(0)), dim=0)
gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
if tb_writer and (idx < 5): if tb_writer and (idx < 5):
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration) tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
if iteration == testing_iterations[0]: if iteration == testing_iterations[0]:
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration) tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
l1_test = l1_loss(images, gts) l1_test = l1_loss(images, gts)
psnr_test = psnr(images, gts).mean() psnr_test = psnr(images, gts).mean()
@ -204,4 +204,4 @@ if __name__ == "__main__":
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
# All done # All done
print("\nTraining complete.") print("\nTraining complete.")