Debugging behavior

This commit is contained in:
bkerbl 2023-07-23 22:41:32 +02:00
parent d4fa4779d5
commit 91f16deb45
2 changed files with 6 additions and 3 deletions

@ -1 +1 @@
Subproject commit 73917be7cfd1694e1e61adfa803981925494be5e
Subproject commit 8064f52ca233942bdec2d1a1451c026deedd320b

View File

@ -28,7 +28,7 @@ try:
except ImportError:
TENSORBOARD_FOUND = False
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
@ -78,6 +78,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
# Render
if (iteration - 1) == debug_from:
pipe.debug = True
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
@ -193,6 +195,7 @@ if __name__ == "__main__":
pp = PipelineParams(parser)
parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=int, default=6009)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
@ -210,7 +213,7 @@ if __name__ == "__main__":
# Start GUI server, configure and run training
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
# All done
print("\nTraining complete.")