diff --git a/train.py b/train.py index 36faf0d..cfaf369 100644 --- a/train.py +++ b/train.py @@ -48,7 +48,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi ema_loss_for_log = 0.0 progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") first_iter += 1 - for iteration in range(first_iter, opt.iterations + 1): + for iteration in range(first_iter, opt.iterations + 1): if network_gui.conn == None: network_gui.try_connect() while network_gui.conn != None: @@ -62,7 +62,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi if do_training and ((iteration < int(opt.iterations)) or not keep_alive): break except Exception as e: + network_gui.conn.close() network_gui.conn = None + network_gui.listener.close() + network_gui.listener = None iter_start.record() @@ -159,7 +162,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i # Report test and samples of training set if iteration in testing_iterations: torch.cuda.empty_cache() - validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, + validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]}) for config in validation_configs: @@ -202,18 +205,23 @@ if __name__ == "__main__": parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--start_checkpoint", type=str, default = None) + parser.add_argument('--cuda_blocking', action='store_true', default=True) args = parser.parse_args(sys.argv[1:]) args.save_iterations.append(args.iterations) - + print("Optimizing " + args.model_path) # Initialize system state (RNG) safe_state(args.quiet) + # CUDA sometimes fails - option to disable asynchronous operations + if args.cuda_blocking: + os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) - training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, + args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done print("\nTraining complete.")