From 9490ef76125c3fbfee362039b816702f3241735d Mon Sep 17 00:00:00 2001 From: JonathonLuiten Date: Mon, 10 Jul 2023 18:45:55 -0400 Subject: [PATCH] fixed tensorboard update images by adding batch dimension (#9) --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 6192c8b..f919db2 100644 --- a/train.py +++ b/train.py @@ -162,9 +162,9 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i images = torch.cat((images, image.unsqueeze(0)), dim=0) gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0) if tb_writer and (idx < 5): - tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration) if iteration == testing_iterations[0]: - tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration) + tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration) l1_test = l1_loss(images, gts) psnr_test = psnr(images, gts).mean() @@ -204,4 +204,4 @@ if __name__ == "__main__": training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations) # All done - print("\nTraining complete.") \ No newline at end of file + print("\nTraining complete.")