This commit is contained in:
Stefan Saraev 2024-09-02 14:35:15 -06:00 committed by GitHub
commit 99a3ce85e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 55 additions and 11 deletions

View File

@ -206,6 +206,10 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
#### --percent_dense
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
#### --data_dtype
The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
#### --store_images_as_uint8
Flag that describes how to store images in memory. If set, the images will be stored as uint8, and will be converted to the target data type on demand.
</details>
<br>

View File

@ -53,6 +53,8 @@ class ModelParams(ParamGroup):
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.data_dtype = "float32"
self.store_images_as_uint8 = False
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

View File

@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
# after rasterization, we convert the resulting image to the target dtype
# The rasterizer expects parameters as float32, so the result is also float32.
rendered_image = rendered_image.to(viewpoint_camera.data_dtype)
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,

View File

@ -17,7 +17,8 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32,
store_images_as_uint8=True,
):
super(Camera, self).__init__()
@ -28,6 +29,7 @@ class Camera(nn.Module):
self.FoVx = FoVx
self.FoVy = FoVy
self.image_name = image_name
self.data_dtype = data_dtype
try:
self.data_device = torch.device(data_device)
@ -36,14 +38,18 @@ class Camera(nn.Module):
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
self.store_images_as_uint8 = store_images_as_uint8
if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
self._original_image = image.to(self.data_device)
self._gt_alpha_mask = gt_alpha_mask
if self._gt_alpha_mask is not None:
self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)
if not store_images_as_uint8:
self._original_image = self.convert_image(self._original_image)
self.image_width = self._original_image.shape[2]
self.image_height = self._original_image.shape[1]
self.zfar = 100.0
self.znear = 0.01
@ -56,6 +62,23 @@ class Camera(nn.Module):
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
def convert_image(self, image):
image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
gt_alpha_mask = self._gt_alpha_mask
if gt_alpha_mask is not None:
gt_alpha_mask = gt_alpha_mask / 255.0
image *= gt_alpha_mask.to(self.data_dtype)
return image
@property
def original_image(self):
if self.store_images_as_uint8:
return self.convert_image(self._original_image)
else:
return self._original_image
class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
self.image_width = width

View File

@ -11,7 +11,7 @@
from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.general_utils import PILtoTorch, get_data_dtype
from utils.graphics_utils import fov2focal
WARNED = False
@ -49,7 +49,9 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
data_dtype=get_data_dtype(args.data_dtype),
store_images_as_uint8=args.store_images_as_uint8)
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []

View File

@ -20,7 +20,7 @@ def inverse_sigmoid(x):
def PILtoTorch(pil_image, resolution):
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
resized_image = torch.from_numpy(np.array(resized_image_PIL))
if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1)
else:
@ -131,3 +131,12 @@ def safe_state(silent):
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
def get_data_dtype(dtype):
if dtype == "float32":
return torch.float32
elif dtype == "float64":
return torch.float64
elif dtype == "float16":
return torch.float16
return torch.float32