mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 08:18:17 +00:00
assign image tensors to data_device on creation
This commit is contained in:
parent
d9fad7b345
commit
fc17065cc0
@ -38,7 +38,7 @@ def loadCam(args, id, cam_info, resolution_scale):
|
|||||||
scale = float(global_down) * float(resolution_scale)
|
scale = float(global_down) * float(resolution_scale)
|
||||||
resolution = (int(orig_w / scale), int(orig_h / scale))
|
resolution = (int(orig_w / scale), int(orig_h / scale))
|
||||||
|
|
||||||
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
|
resized_image_rgb = PILtoTorch(cam_info.image, resolution, args.data_device)
|
||||||
|
|
||||||
gt_image = resized_image_rgb[:3, ...]
|
gt_image = resized_image_rgb[:3, ...]
|
||||||
loaded_mask = None
|
loaded_mask = None
|
||||||
|
@ -18,9 +18,9 @@ import random
|
|||||||
def inverse_sigmoid(x):
|
def inverse_sigmoid(x):
|
||||||
return torch.log(x/(1-x))
|
return torch.log(x/(1-x))
|
||||||
|
|
||||||
def PILtoTorch(pil_image, resolution):
|
def PILtoTorch(pil_image, resolution, data_device="cuda"):
|
||||||
resized_image_PIL = pil_image.resize(resolution)
|
resized_image_PIL = pil_image.resize(resolution)
|
||||||
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
|
resized_image = torch.from_numpy(np.array(resized_image_PIL)).to(data_device) / 255.0
|
||||||
if len(resized_image.shape) == 3:
|
if len(resized_image.shape) == 3:
|
||||||
return resized_image.permute(2, 0, 1)
|
return resized_image.permute(2, 0, 1)
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user