mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-21 15:57:45 +00:00
Debugging behavior
This commit is contained in:
parent
d4fa4779d5
commit
91f16deb45
@ -1 +1 @@
|
||||
Subproject commit 73917be7cfd1694e1e61adfa803981925494be5e
|
||||
Subproject commit 8064f52ca233942bdec2d1a1451c026deedd320b
|
7
train.py
7
train.py
@ -28,7 +28,7 @@ try:
|
||||
except ImportError:
|
||||
TENSORBOARD_FOUND = False
|
||||
|
||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
|
||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
|
||||
first_iter = 0
|
||||
tb_writer = prepare_output_and_logger(dataset)
|
||||
gaussians = GaussianModel(dataset.sh_degree)
|
||||
@ -78,6 +78,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
||||
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
||||
|
||||
# Render
|
||||
if (iteration - 1) == debug_from:
|
||||
pipe.debug = True
|
||||
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
|
||||
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
||||
|
||||
@ -193,6 +195,7 @@ if __name__ == "__main__":
|
||||
pp = PipelineParams(parser)
|
||||
parser.add_argument('--ip', type=str, default="127.0.0.1")
|
||||
parser.add_argument('--port', type=int, default=6009)
|
||||
parser.add_argument('--debug_from', type=int, default=-1)
|
||||
parser.add_argument('--detect_anomaly', action='store_true', default=False)
|
||||
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||
@ -210,7 +213,7 @@ if __name__ == "__main__":
|
||||
# Start GUI server, configure and run training
|
||||
network_gui.init(args.ip, args.port)
|
||||
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
|
||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
|
||||
|
||||
# All done
|
||||
print("\nTraining complete.")
|
||||
|
Loading…
Reference in New Issue
Block a user