From 4ea5609081b9c2df83f7774c615972b4bf21a5d1 Mon Sep 17 00:00:00 2001 From: xvdp Date: Sun, 5 Nov 2023 04:15:39 -0800 Subject: [PATCH] fix Conversion pil to torch introduces potential cuda error #430 --- train.py | 5 +---- utils/general_utils.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index 7b73515..d57acae 100644 --- a/train.py +++ b/train.py @@ -208,7 +208,7 @@ if __name__ == "__main__": parser.add_argument("--quiet", action="store_true") parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000]) parser.add_argument("--start_checkpoint", type=str, default = None) - parser.add_argument('--cuda_blocking', action='store_true', default=True) + args = parser.parse_args(sys.argv[1:]) args.save_iterations.append(args.iterations) @@ -217,9 +217,6 @@ if __name__ == "__main__": # Initialize system state (RNG) safe_state(args.quiet) - # CUDA sometimes fails - option to disable asynchronous operations - if args.cuda_blocking: - os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..43f56ee 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -18,13 +18,15 @@ import random def inverse_sigmoid(x): return torch.log(x/(1-x)) -def PILtoTorch(pil_image, resolution): +def PILtoTorch(pil_image, resolution, pin_memory=True): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 - if len(resized_image.shape) == 3: - return resized_image.permute(2, 0, 1) - else: - return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + resized_image = torch.from_numpy(np.array(resized_image_PIL, dtype=np.float32)) / 255.0 + if resized_image.ndim == 2: + resized_image = resized_image[None] + resized_image = resized_image.permute(2, 0, 1).contiguous() + if pin_memory: + resized_image.pin_memory = True + return resized_image def get_expon_lr_func( lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000