diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..d90d442 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -38,7 +38,7 @@ def loadCam(args, id, cam_info, resolution_scale): scale = float(global_down) * float(resolution_scale) resolution = (int(orig_w / scale), int(orig_h / scale)) - resized_image_rgb = PILtoTorch(cam_info.image, resolution) + resized_image_rgb = PILtoTorch(cam_info.image, resolution, args.data_device) gt_image = resized_image_rgb[:3, ...] loaded_mask = None diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..50d6928 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -18,9 +18,9 @@ import random def inverse_sigmoid(x): return torch.log(x/(1-x)) -def PILtoTorch(pil_image, resolution): +def PILtoTorch(pil_image, resolution, data_device="cuda"): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL)).to(data_device) / 255.0 if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: