mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-24 21:13:46 +00:00
fixed tensorboard update images by adding batch dimension (#9)
This commit is contained in:
parent
b89771f0ba
commit
9490ef7612
6
train.py
6
train.py
@ -162,9 +162,9 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
|
|||||||
images = torch.cat((images, image.unsqueeze(0)), dim=0)
|
images = torch.cat((images, image.unsqueeze(0)), dim=0)
|
||||||
gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
|
gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
|
||||||
if tb_writer and (idx < 5):
|
if tb_writer and (idx < 5):
|
||||||
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration)
|
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
|
||||||
if iteration == testing_iterations[0]:
|
if iteration == testing_iterations[0]:
|
||||||
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration)
|
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
|
||||||
|
|
||||||
l1_test = l1_loss(images, gts)
|
l1_test = l1_loss(images, gts)
|
||||||
psnr_test = psnr(images, gts).mean()
|
psnr_test = psnr(images, gts).mean()
|
||||||
@ -204,4 +204,4 @@ if __name__ == "__main__":
|
|||||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
|
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
|
||||||
|
|
||||||
# All done
|
# All done
|
||||||
print("\nTraining complete.")
|
print("\nTraining complete.")
|
||||||
|
Loading…
Reference in New Issue
Block a user