diff --git a/README.md b/README.md index 28f2dee..14278cd 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,8 @@ python train.py -s Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. #### --data_dtype The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default. + #### --store_images_as_uint8 + Flag that controls how images are stored in memory. If set, images are kept as uint8 and converted to the target data type (see ```--data_dtype```) on demand each time they are accessed. Not set by default.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 3cad0b3..fdfe3fc 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -54,6 +54,7 @@ class ModelParams(ParamGroup): self._white_background = False self.data_device = "cuda" self.data_dtype = "float32" + self.store_images_as_uint8 = False self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index e8af831..3efab6e 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -94,7 +94,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, # after rasterization, we convert the resulting image to the target dtype # The rasterizer expects parameters as float32, so the result is also float32. - rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype) + rendered_image = rendered_image.to(viewpoint_camera.data_dtype) # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. 
diff --git a/render.py b/render.py index 70f85cb..fc54831 100644 --- a/render.py +++ b/render.py @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) -def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32): +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): with torch.no_grad(): - gaussians = GaussianModel(dataset.sh_degree, dtype=dtype) + gaussians = GaussianModel(dataset.sh_degree) scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=dtype, device="cuda") + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") if not skip_train: render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) @@ -63,6 +63,4 @@ if __name__ == "__main__": # Initialize system state (RNG) safe_state(args.quiet) - dtype = torch.float32 - - render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype) \ No newline at end of file + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index 0609d0a..5264a04 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,8 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", 
data_dtype=torch.float32 + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32, + store_images_as_uint8=True, ): super(Camera, self).__init__() @@ -37,14 +38,18 @@ class Camera(nn.Module): print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device) - self.image_width = self.original_image.shape[2] - self.image_height = self.original_image.shape[1] + self.store_images_as_uint8 = store_images_as_uint8 - if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device) - else: - self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + self._original_image = image.to(self.data_device) + self._gt_alpha_mask = gt_alpha_mask + if self._gt_alpha_mask is not None: + self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device) + + if not store_images_as_uint8: + self._original_image = self.convert_image(self._original_image) + + self.image_width = self._original_image.shape[2] + self.image_height = self._original_image.shape[1] self.zfar = 100.0 self.znear = 0.01 @@ -57,6 +62,23 @@ class Camera(nn.Module): self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] + def convert_image(self, image): + image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype) + gt_alpha_mask = self._gt_alpha_mask + + if gt_alpha_mask is not None: + gt_alpha_mask = gt_alpha_mask / 255.0 + image *= gt_alpha_mask.to(self.data_dtype) + + return image + + @property + def original_image(self): + if self.store_images_as_uint8: + return self.convert_image(self._original_image) + else: + return self._original_image + class MiniCam: def __init__(self, width, height, fovy, fovx, znear, zfar, 
world_view_transform, full_proj_transform): self.image_width = width diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 6e762a5..400e7b6 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -52,7 +52,8 @@ def loadCam(args, id, cam_info, resolution_scale): FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, image_name=cam_info.image_name, uid=id, data_device=args.data_device, - data_dtype=get_data_dtype(args.data_dtype)) + data_dtype=get_data_dtype(args.data_dtype), + store_images_as_uint8=args.store_images_as_uint8) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index ed7f0a6..4b0c53a 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -20,7 +20,7 @@ def inverse_sigmoid(x): def PILtoTorch(pil_image, resolution): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL))# / 255.0 if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: