Mirror of https://github.com/graphdeco-inria/gaussian-splatting, synced 2024-12-02 00:55:46 +00:00.
Added an option to disable asynchronous operations, which can cause CUDA to fail; network_gui.listener can block resources, so it is now cleaned up on break.
This commit is contained in:
parent
a0dc5af86f
commit
2311f4e764
train.py (16 lines changed)
@@ -48,7 +48,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
     ema_loss_for_log = 0.0
     progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
     first_iter += 1
     for iteration in range(first_iter, opt.iterations + 1):
         if network_gui.conn == None:
             network_gui.try_connect()
         while network_gui.conn != None:
@@ -62,7 +62,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
                 if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                     break
             except Exception as e:
+                network_gui.conn.close()
                 network_gui.conn = None
+                network_gui.listener.close()
+                network_gui.listener = None
 
         iter_start.record()
 
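The hunk above matters because setting network_gui.conn and network_gui.listener to None only drops the Python references; without an explicit close() the underlying sockets (and the bound listener port) can stay alive and keep blocking resources, which is what the commit message refers to. A minimal sketch of the same cleanup as a reusable helper (the name close_network_gui is hypothetical; it only assumes conn and listener behave like ordinary socket objects, as the close() calls in the diff imply):

    def close_network_gui(gui):
        # Close and forget both sockets so the OS releases the file
        # descriptors (and the listening port), tolerating the case
        # where one of them is missing or already closed.
        for name in ("conn", "listener"):
            sock = getattr(gui, name, None)
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            setattr(gui, name, None)

    # The except block in the diff would then reduce to:
    #     except Exception as e:
    #         close_network_gui(network_gui)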
@@ -159,7 +162,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
     # Report test and samples of training set
     if iteration in testing_iterations:
         torch.cuda.empty_cache()
         validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
                               {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
 
         for config in validation_configs:
@@ -202,18 +205,23 @@ if __name__ == "__main__":
     parser.add_argument("--quiet", action="store_true")
     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
     parser.add_argument("--start_checkpoint", type=str, default = None)
+    parser.add_argument('--cuda_blocking', action='store_true', default=True)
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)
 
     print("Optimizing " + args.model_path)
 
     # Initialize system state (RNG)
     safe_state(args.quiet)
 
+    # CUDA sometimes fails - option to disable asynchronous operations
+    if args.cuda_blocking:
+        os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
+    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations,
+             args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
 
     # All done
     print("\nTraining complete.")
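For context on the new flag: CUDA_LAUNCH_BLOCKING=1 makes the CUDA runtime execute kernel launches synchronously, so an error is reported at the call that caused it rather than surfacing later from an unrelated asynchronous operation; it is a debugging aid and usually slows training. Note that, as committed, action='store_true' combined with default=True means args.cuda_blocking is True whether or not --cuda_blocking is passed. A sketch of a variant that can also be switched off from the command line (illustrative only, not the code in this commit; the --no_cuda_blocking flag is an assumption):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    # Paired flags sharing one destination: --cuda_blocking turns the
    # behaviour on, --no_cuda_blocking turns it off, default stays on.
    parser.add_argument('--cuda_blocking', dest='cuda_blocking', action='store_true',
                        help="force synchronous CUDA kernel launches (debugging aid)")
    parser.add_argument('--no_cuda_blocking', dest='cuda_blocking', action='store_false')
    parser.set_defaults(cuda_blocking=True)
    args = parser.parse_args()

    if args.cuda_blocking:
        # Set before any CUDA work starts; kernel launches then block
        # until the kernel finishes, so failures point at the real call site.
        os.environ['CUDA_LAUNCH_BLOCKING'] = "1"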
|