mirror of https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 16:28:32 +00:00
b5a5f72eda
Images are now loaded onto the target device as uint8 tensors and only then converted to the target data type (e.g. fp32 or fp16), which speeds up loading. Users can also opt to keep the stored image as uint8 rather than as the target data type, which further reduces memory usage.
95 lines
3.3 KiB
Python
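A rough, standalone sketch (not part of the repository) of the storage trade-off described in the commit message above; the 1080p resolution is an arbitrary example.

import torch

h, w = 1080, 1920  # example resolution, arbitrary choice
img_u8 = torch.zeros((3, h, w), dtype=torch.uint8)   # stored as loaded: 1 byte per value
img_f32 = img_u8.to(torch.float32) / 255.0           # target dtype fp32: 4 bytes per value
img_f16 = img_u8.to(torch.float16) / 255.0           # target dtype fp16: 2 bytes per value

for name, t in [("uint8", img_u8), ("fp32", img_f32), ("fp16", img_f16)]:
    print(f"{name}: {t.element_size() * t.nelement() / 2**20:.1f} MiB")
# uint8: ~5.9 MiB, fp32: ~23.7 MiB, fp16: ~11.9 MiB for a 1080p RGB image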
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix

class Camera(nn.Module):
    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                 image_name, uid,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda", data_dtype=torch.float32,
                 store_images_as_uint8=True,
                 ):
        super(Camera, self).__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.image_name = image_name
        self.data_dtype = data_dtype

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
            self.data_device = torch.device("cuda")

        self.store_images_as_uint8 = store_images_as_uint8

        # The image arrives as uint8 and is moved to the target device as-is,
        # which keeps loading fast and memory usage low.
        self._original_image = image.to(self.data_device)
        self._gt_alpha_mask = gt_alpha_mask
        if self._gt_alpha_mask is not None:
            self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)

        # Optionally convert to the target data type (e.g. fp32 or fp16) up front;
        # otherwise the uint8 tensor is kept and converted on access.
        if not store_images_as_uint8:
            self._original_image = self.convert_image(self._original_image)

        self.image_width = self._original_image.shape[2]
        self.image_height = self._original_image.shape[1]

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    def convert_image(self, image):
        # Scale the uint8 image to [0, 1] in the target data type and apply
        # the ground-truth alpha mask if one was provided.
        image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
        gt_alpha_mask = self._gt_alpha_mask

        if gt_alpha_mask is not None:
            gt_alpha_mask = gt_alpha_mask / 255.0
            image *= gt_alpha_mask.to(self.data_dtype)

        return image

    @property
    def original_image(self):
        # When stored as uint8, conversion happens lazily on every access.
        if self.store_images_as_uint8:
            return self.convert_image(self._original_image)
        else:
            return self._original_image

class MiniCam:
    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        self.image_width = width
        self.image_height = height
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform
        view_inv = torch.inverse(self.world_view_transform)
        self.camera_center = view_inv[3][:3]
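A hypothetical usage sketch, not part of the file above: it assumes a CUDA device is available and uses placeholder pose, intrinsics, and image values; in the real pipeline these come from the COLMAP scene loaders.

import numpy as np
import torch

R = np.eye(3)                      # placeholder rotation
T = np.zeros(3)                    # placeholder translation
image_u8 = torch.randint(0, 256, (3, 1080, 1920), dtype=torch.uint8)  # C, H, W

cam = Camera(colmap_id=0, R=R, T=T, FoVx=1.2, FoVy=0.9,
             image=image_u8, gt_alpha_mask=None,
             image_name="00000.png", uid=0,
             data_device="cuda", data_dtype=torch.float16,
             store_images_as_uint8=True)

# The image stays on the device as uint8; each access converts it to fp16 in [0, 1].
gt = cam.original_image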