gaussian-splatting/scene/cameras.py
Stefan Saraev b5a5f72eda load speedup: refactored image loading
Images are now loaded on the target device as uint8s.
Then they are converted to the target data type (e.g. fp32 or fp16).
This speeds up the loading time.

Also, users can opt to store the image either as uint8 or as the target data type.
Storing it as uint8 further reduces memory usage.
2024-05-15 22:32:58 +03:00

95 lines
3.3 KiB
Python

#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
    """A single view: camera pose, field of view and its ground-truth image.

    The image is moved to ``data_device`` as given and is either kept in that
    raw form (``store_images_as_uint8=True``) and converted on every access to
    ``original_image``, or converted once at construction time.  The view and
    projection transforms are always placed on the default CUDA device.
    """

    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid,
                 trans=None, scale=1.0, data_device="cuda", data_dtype=torch.float32,
                 store_images_as_uint8=True,
                 ):
        """Build the camera and precompute its transforms.

        Args:
            colmap_id: Identifier stored as ``self.colmap_id``.
            R: Rotation passed to ``getWorld2View2``.
            T: Translation passed to ``getWorld2View2``.
            FoVx: Horizontal field of view passed to ``getProjectionMatrix``.
            FoVy: Vertical field of view passed to ``getProjectionMatrix``.
            image: Image tensor with values on a 0..255 scale (presumably
                uint8 -- confirm against the loader); dimension 1 is read as
                the height and dimension 2 as the width.
            gt_alpha_mask: Optional mask tensor on a 0..255 scale that is
                multiplied into the converted image, or ``None``.
            image_name: Name stored as ``self.image_name``.
            uid: Identifier stored as ``self.uid``.
            trans: Translation offset passed to ``getWorld2View2``.  ``None``
                (the default) means a fresh ``[0.0, 0.0, 0.0]`` array.
            scale: Scale passed to ``getWorld2View2``.
            data_device: Device on which the image and mask are stored.
            data_dtype: Floating-point dtype of the converted image.
            store_images_as_uint8: When true, keep the image as passed in and
                convert it lazily in ``original_image``; otherwise convert it
                once here.
        """
        super().__init__()

        # A default array in the signature would be created once and shared by
        # every Camera through ``self.trans``; build a private one instead.
        if trans is None:
            trans = np.array([0.0, 0.0, 0.0])

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name
        self.data_dtype = data_dtype

        # Best effort: an unusable device specification falls back to CUDA
        # instead of aborting the load.
        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
            self.data_device = torch.device("cuda")

        self.store_images_as_uint8 = store_images_as_uint8

        self._original_image = image.to(self.data_device)
        self._gt_alpha_mask = gt_alpha_mask
        if self._gt_alpha_mask is not None:
            self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)

        # The mask must already be on the device: convert_image reads it.
        if not store_images_as_uint8:
            self._original_image = self.convert_image(self._original_image)

        self.image_width = self._original_image.shape[2]
        self.image_height = self._original_image.shape[1]

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        # Both matrices are stored transposed and live on the CUDA device.
        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
        ).transpose(0, 1).cuda()
        # Product of the two transposed matrices, computed as a batch of one.
        self.full_proj_transform = (
            self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))
        ).squeeze(0)
        # Row 3, first three columns of the inverted (transposed) view matrix.
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    def convert_image(self, image):
        """Convert a 0..255 image into ``self.data_dtype`` values in [0, 1].

        The image is divided by 255, clamped to [0, 1] and cast.  If an alpha
        mask is stored on the camera, the result is multiplied by the mask
        divided by 255.

        Args:
            image: Image tensor on a 0..255 scale.

        Returns:
            The converted (and, if a mask is present, masked) image tensor.
        """
        image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
        gt_alpha_mask = self._gt_alpha_mask
        if gt_alpha_mask is not None:
            gt_alpha_mask = gt_alpha_mask / 255.0
            image *= gt_alpha_mask.to(self.data_dtype)
        return image

    @property
    def original_image(self):
        """Ground-truth image in ``self.data_dtype``.

        Converted on each access when the image is stored in its raw form,
        otherwise the tensor converted at construction time is returned as is.
        """
        if self.store_images_as_uint8:
            return self.convert_image(self._original_image)
        else:
            return self._original_image
class MiniCam:
    """Lightweight camera built from already-computed transforms.

    Stores the image size, fields of view, clipping planes and the two
    supplied transform matrices, and derives ``camera_center`` from the
    inverse of ``world_view_transform``.
    """

    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        # Image dimensions.
        self.image_width, self.image_height = width, height
        # Fields of view.
        self.FoVy, self.FoVx = fovy, fovx
        # Clipping planes.
        self.znear, self.zfar = znear, zfar
        # Transforms are kept exactly as supplied by the caller.
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform
        # Row 3, first three entries of the inverted view transform.
        inverse_view = torch.inverse(world_view_transform)
        self.camera_center = inverse_view[3, :3]