assign image tensors to data_device on creation

This commit is contained in:
GaneshBannur 2024-02-21 16:14:53 +05:30
parent d9fad7b345
commit fc17065cc0
2 changed files with 3 additions and 3 deletions

View File

@ -38,7 +38,7 @@ def loadCam(args, id, cam_info, resolution_scale):
scale = float(global_down) * float(resolution_scale) scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale)) resolution = (int(orig_w / scale), int(orig_h / scale))
resized_image_rgb = PILtoTorch(cam_info.image, resolution) resized_image_rgb = PILtoTorch(cam_info.image, resolution, args.data_device)
gt_image = resized_image_rgb[:3, ...] gt_image = resized_image_rgb[:3, ...]
loaded_mask = None loaded_mask = None

View File

@ -18,9 +18,9 @@ import random
def inverse_sigmoid(x): def inverse_sigmoid(x):
return torch.log(x/(1-x)) return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution): def PILtoTorch(pil_image, resolution, data_device="cuda"):
resized_image_PIL = pil_image.resize(resolution) resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 resized_image = torch.from_numpy(np.array(resized_image_PIL)).to(data_device) / 255.0
if len(resized_image.shape) == 3: if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1) return resized_image.permute(2, 0, 1)
else: else: