fix: PIL-to-torch conversion introduces potential CUDA error

#430
This commit is contained in:
xvdp 2023-11-05 04:15:39 -08:00
parent f934e701b2
commit 4ea5609081
2 changed files with 9 additions and 10 deletions

View File

@ -208,7 +208,7 @@ if __name__ == "__main__":
parser.add_argument("--quiet", action="store_true") parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000]) parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--start_checkpoint", type=str, default = None) parser.add_argument("--start_checkpoint", type=str, default = None)
parser.add_argument('--cuda_blocking', action='store_true', default=True)
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations) args.save_iterations.append(args.iterations)
@ -217,9 +217,6 @@ if __name__ == "__main__":
# Initialize system state (RNG) # Initialize system state (RNG)
safe_state(args.quiet) safe_state(args.quiet)
# CUDA sometimes fails - option to disable asynchronous operations
if args.cuda_blocking:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Start GUI server, configure and run training # Start GUI server, configure and run training
network_gui.init(args.ip, args.port) network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly) torch.autograd.set_detect_anomaly(args.detect_anomaly)

View File

@ -18,13 +18,15 @@ import random
def inverse_sigmoid(x): def inverse_sigmoid(x):
return torch.log(x/(1-x)) return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution): def PILtoTorch(pil_image, resolution, pin_memory=True):
resized_image_PIL = pil_image.resize(resolution) resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 resized_image = torch.from_numpy(np.array(resized_image_PIL, dtype=np.float32)) / 255.0
if len(resized_image.shape) == 3: if resized_image.ndim == 2:
return resized_image.permute(2, 0, 1) resized_image = resized_image[None]
else: resized_image = resized_image.permute(2, 0, 1).contiguous()
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) if pin_memory:
resized_image.pin_memory = True
return resized_image
def get_expon_lr_func( def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000