From fc17065cc03c40d302d8cd748aadfecf0d91284b Mon Sep 17 00:00:00 2001 From: GaneshBannur <76954924+GaneshBannur@users.noreply.github.com> Date: Wed, 21 Feb 2024 16:14:53 +0530 Subject: [PATCH] assign image tensors to data_device on creation --- utils/camera_utils.py | 2 +- utils/general_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..d90d442 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -38,7 +38,7 @@ def loadCam(args, id, cam_info, resolution_scale): scale = float(global_down) * float(resolution_scale) resolution = (int(orig_w / scale), int(orig_h / scale)) - resized_image_rgb = PILtoTorch(cam_info.image, resolution) + resized_image_rgb = PILtoTorch(cam_info.image, resolution, args.data_device) gt_image = resized_image_rgb[:3, ...] loaded_mask = None diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..50d6928 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -18,9 +18,9 @@ import random def inverse_sigmoid(x): return torch.log(x/(1-x)) -def PILtoTorch(pil_image, resolution): +def PILtoTorch(pil_image, resolution, data_device="cuda"): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL)).to(data_device) / 255.0 if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: