diff --git a/submodules/diff-gaussian-rasterization b/submodules/diff-gaussian-rasterization index 73917be..8064f52 160000 --- a/submodules/diff-gaussian-rasterization +++ b/submodules/diff-gaussian-rasterization @@ -1 +1 @@ -Subproject commit 73917be7cfd1694e1e61adfa803981925494be5e +Subproject commit 8064f52ca233942bdec2d1a1451c026deedd320b diff --git a/train.py b/train.py index 8e56202..0b7a9c1 100644 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ try: except ImportError: TENSORBOARD_FOUND = False -def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint): +def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): first_iter = 0 tb_writer = prepare_output_and_logger(dataset) gaussians = GaussianModel(dataset.sh_degree) @@ -78,6 +78,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) # Render + if (iteration - 1) == debug_from: + pipe.debug = True render_pkg = render(viewpoint_cam, gaussians, pipe, background) image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] @@ -193,6 +195,7 @@ if __name__ == "__main__": pp = PipelineParams(parser) parser.add_argument('--ip', type=str, default="127.0.0.1") parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) parser.add_argument('--detect_anomaly', action='store_true', default=False) parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000]) parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000]) @@ -210,7 +213,7 @@ if __name__ == "__main__": # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) - training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done print("\nTraining complete.")