mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 00:08:02 +00:00
Debugging behavior
This commit is contained in:
parent
d4fa4779d5
commit
91f16deb45
@ -1 +1 @@
|
|||||||
Subproject commit 73917be7cfd1694e1e61adfa803981925494be5e
|
Subproject commit 8064f52ca233942bdec2d1a1451c026deedd320b
|
7
train.py
7
train.py
@ -28,7 +28,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
TENSORBOARD_FOUND = False
|
TENSORBOARD_FOUND = False
|
||||||
|
|
||||||
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
|
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
|
||||||
first_iter = 0
|
first_iter = 0
|
||||||
tb_writer = prepare_output_and_logger(dataset)
|
tb_writer = prepare_output_and_logger(dataset)
|
||||||
gaussians = GaussianModel(dataset.sh_degree)
|
gaussians = GaussianModel(dataset.sh_degree)
|
||||||
@ -78,6 +78,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
|
|||||||
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
||||||
|
|
||||||
# Render
|
# Render
|
||||||
|
if (iteration - 1) == debug_from:
|
||||||
|
pipe.debug = True
|
||||||
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
|
render_pkg = render(viewpoint_cam, gaussians, pipe, background)
|
||||||
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
||||||
|
|
||||||
@ -193,6 +195,7 @@ if __name__ == "__main__":
|
|||||||
pp = PipelineParams(parser)
|
pp = PipelineParams(parser)
|
||||||
parser.add_argument('--ip', type=str, default="127.0.0.1")
|
parser.add_argument('--ip', type=str, default="127.0.0.1")
|
||||||
parser.add_argument('--port', type=int, default=6009)
|
parser.add_argument('--port', type=int, default=6009)
|
||||||
|
parser.add_argument('--debug_from', type=int, default=-1)
|
||||||
parser.add_argument('--detect_anomaly', action='store_true', default=False)
|
parser.add_argument('--detect_anomaly', action='store_true', default=False)
|
||||||
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||||
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
||||||
@ -210,7 +213,7 @@ if __name__ == "__main__":
|
|||||||
# Start GUI server, configure and run training
|
# Start GUI server, configure and run training
|
||||||
network_gui.init(args.ip, args.port)
|
network_gui.init(args.ip, args.port)
|
||||||
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
||||||
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
|
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
|
||||||
|
|
||||||
# All done
|
# All done
|
||||||
print("\nTraining complete.")
|
print("\nTraining complete.")
|
||||||
|
Loading…
Reference in New Issue
Block a user