Added an option to disable asynchronous CUDA operations, which can cause CUDA to fail.

network_gui.listener can hold resources; clean them up on break.
This commit is contained in:
xvdp 2023-10-31 04:48:52 -07:00
parent a0dc5af86f
commit 2311f4e764

View File

@ -62,7 +62,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
if do_training and ((iteration < int(opt.iterations)) or not keep_alive): if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
break break
except Exception as e: except Exception as e:
network_gui.conn.close()
network_gui.conn = None network_gui.conn = None
network_gui.listener.close()
network_gui.listener = None
iter_start.record() iter_start.record()
@ -202,6 +205,7 @@ if __name__ == "__main__":
parser.add_argument("--quiet", action="store_true") parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None) parser.add_argument("--start_checkpoint", type=str, default = None)
parser.add_argument('--cuda_blocking', action='store_true', default=True)
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations) args.save_iterations.append(args.iterations)
@ -210,10 +214,14 @@ if __name__ == "__main__":
# Initialize system state (RNG) # Initialize system state (RNG)
safe_state(args.quiet) safe_state(args.quiet)
# CUDA sometimes fails - option to disable asynchronous operations
if args.cuda_blocking:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Start GUI server, configure and run training # Start GUI server, configure and run training
network_gui.init(args.ip, args.port) network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly) torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations,
args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
# All done # All done
print("\nTraining complete.") print("\nTraining complete.")