From 18eb6d6a0c363cf1023fd7b74ffd0dd92bdc35cd Mon Sep 17 00:00:00 2001 From: Stefan Saraev Date: Wed, 24 Apr 2024 18:37:06 +0300 Subject: [PATCH] added support for setting the floating point data type Users may want to reduce their memory consumption by using fp16. However, in my tests, such attempts resulted in lower-quality renders. Some data type conversions did not have any impact, so I removed them completely. --- README.md | 2 ++ arguments/__init__.py | 1 + gaussian_renderer/__init__.py | 4 ++++ render.py | 10 ++++++---- scene/cameras.py | 7 ++++--- scene/gaussian_model.py | 27 ++++++++++++++------------- train.py | 3 ++- utils/camera_utils.py | 7 +++++-- utils/general_utils.py | 21 +++++++++++++++------ 9 files changed, 53 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 4cbd332..28f2dee 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,8 @@ python train.py -s Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. #### --percent_dense Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + #### --data_dtype + The data type (float32, float16, float64) in which images are stored when computing the loss. ```float32``` by default.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a55..3cad0b3 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -53,6 +53,7 @@ class ModelParams(ParamGroup): self._resolution = -1 self._white_background = False self.data_device = "cuda" + self.data_dtype = "float32" self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index f74e336..e8af831 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, rotations = rotations, cov3D_precomp = cov3D_precomp) + # after rasterization, we convert the resulting image to the target dtype + # The rasterizer expects parameters as float32, so the result is also float32. + rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype) + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. 
return {"render": rendered_image, diff --git a/render.py b/render.py index fc6b82d..70f85cb 100644 --- a/render.py +++ b/render.py @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) -def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32): with torch.no_grad(): - gaussians = GaussianModel(dataset.sh_degree) + gaussians = GaussianModel(dataset.sh_degree, dtype=dtype) scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + background = torch.tensor(bg_color, dtype=dtype, device="cuda") if not skip_train: render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) @@ -62,5 +62,7 @@ if __name__ == "__main__": # Initialize system state (RNG) safe_state(args.quiet) + + dtype = torch.float32 - render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index abf6e52..0609d0a 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), 
scale=1.0, data_device = "cuda" + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32 ): super(Camera, self).__init__() @@ -28,6 +28,7 @@ class Camera(nn.Module): self.FoVx = FoVx self.FoVy = FoVy self.image_name = image_name + self.data_dtype = data_dtype try: self.data_device = torch.device(data_device) @@ -36,12 +37,12 @@ class Camera(nn.Module): print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device) self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) + self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device) else: self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 632a1e8..f790558 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -23,11 +23,11 @@ from utils.general_utils import strip_symmetric, build_scaling_rotation class GaussianModel: - def setup_functions(self): + def setup_functions(self, dtype): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): - L = build_scaling_rotation(scaling_modifier * scaling, rotation) + L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype) actual_covariance = L @ L.transpose(1, 2) - symm = strip_symmetric(actual_covariance) + symm = strip_symmetric(actual_covariance, dtype) return symm self.scaling_activation = torch.exp @@ -41,7 +41,7 @@ class GaussianModel: self.rotation_activation = torch.nn.functional.normalize - def __init__(self, sh_degree : int): + def __init__(self, sh_degree : int, dtype=torch.float32): 
self.active_sh_degree = 0 self.max_sh_degree = sh_degree self._xyz = torch.empty(0) @@ -56,7 +56,8 @@ class GaussianModel: self.optimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 - self.setup_functions() + self.dtype = dtype + self.setup_functions(dtype) def capture(self): return ( @@ -136,7 +137,7 @@ class GaussianModel: rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 - opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) @@ -144,7 +145,7 @@ class GaussianModel: self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) - self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype) def training_setup(self, training_args): self.percent_dense = training_args.percent_dense @@ -246,12 +247,12 @@ class GaussianModel: for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) - self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) - self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) - self._scaling = nn.Parameter(torch.tensor(scales, 
dtype=torch.float, device="cuda").requires_grad_(True)) - self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True)) self.active_sh_degree = self.max_sh_degree diff --git a/train.py b/train.py index 5d819b3..7435218 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from utils.loss_utils import l1_loss, ssim from gaussian_renderer import render, network_gui import sys from scene import Scene, GaussianModel -from utils.general_utils import safe_state +from utils.general_utils import get_data_dtype, safe_state import uuid from tqdm import tqdm from utils.image_utils import psnr @@ -216,6 +216,7 @@ if __name__ == "__main__": # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..6e762a5 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -11,7 +11,7 @@ from scene.cameras import Camera import numpy as np -from utils.general_utils import PILtoTorch +from 
utils.general_utils import PILtoTorch, get_data_dtype from utils.graphics_utils import fov2focal WARNED = False @@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale): resolution = (int(orig_w / scale), int(orig_h / scale)) resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + # resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype)) gt_image = resized_image_rgb[:3, ...] loaded_mask = None @@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale): return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + data_dtype=get_data_dtype(args.data_dtype)) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..ed7f0a6 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -61,8 +61,8 @@ def get_expon_lr_func( return helper -def strip_lowerdiag(L): - uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") +def strip_lowerdiag(L, dtype=torch.float32): + uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda") uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] @@ -72,8 +72,8 @@ def strip_lowerdiag(L): uncertainty[:, 5] = L[:, 2, 2] return uncertainty -def strip_symmetric(sym): - return strip_lowerdiag(sym) +def strip_symmetric(sym, dtype=torch.float32): + return strip_lowerdiag(sym, dtype=dtype) def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) @@ -98,8 +98,8 @@ def build_rotation(r): R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R -def build_scaling_rotation(s, r): - L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") +def build_scaling_rotation(s, r, 
dtype=torch.float32): + L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda") R = build_rotation(r) L[:,0,0] = s[:,0] @@ -131,3 +131,12 @@ def safe_state(silent): np.random.seed(0) torch.manual_seed(0) torch.cuda.set_device(torch.device("cuda:0")) + +def get_data_dtype(dtype): + if dtype == "float32": + return torch.float32 + elif dtype == "float64": + return torch.float64 + elif dtype == "float16": + return torch.float16 + return torch.float32 \ No newline at end of file