diff --git a/README.md b/README.md index 4cbd332..28f2dee 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,8 @@ python train.py -s Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. #### --percent_dense Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + #### --data_dtype + Data type (```float32```, ```float16``` or ```float64```) in which the ground-truth images are stored, and to which rendered images are converted, when computing the loss. Unrecognized values fall back to ```float32```. ```float32``` by default.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a55..3cad0b3 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -53,6 +53,7 @@ class ModelParams(ParamGroup): self._resolution = -1 self._white_background = False self.data_device = "cuda" + self.data_dtype = "float32" self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index f74e336..e8af831 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, rotations = rotations, cov3D_precomp = cov3D_precomp) + # after rasterization, we convert the resulting image to the target dtype + # The rasterizer expects parameters as float32, so the result is also float32. + rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype) + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. 
return {"render": rendered_image, diff --git a/render.py b/render.py index fc6b82d..70f85cb 100644 --- a/render.py +++ b/render.py @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) -def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32): with torch.no_grad(): - gaussians = GaussianModel(dataset.sh_degree) + gaussians = GaussianModel(dataset.sh_degree, dtype=dtype) scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + background = torch.tensor(bg_color, dtype=dtype, device="cuda") if not skip_train: render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) @@ -62,5 +62,7 @@ if __name__ == "__main__": # Initialize system state (RNG) safe_state(args.quiet) + + dtype = torch.float32 - render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index abf6e52..0609d0a 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), 
scale=1.0, data_device = "cuda" + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32 ): super(Camera, self).__init__() @@ -28,6 +28,7 @@ class Camera(nn.Module): self.FoVx = FoVx self.FoVy = FoVy self.image_name = image_name + self.data_dtype = data_dtype try: self.data_device = torch.device(data_device) @@ -36,12 +37,12 @@ class Camera(nn.Module): print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device) self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) + self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device) else: self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 632a1e8..f790558 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -23,11 +23,11 @@ from utils.general_utils import strip_symmetric, build_scaling_rotation class GaussianModel: - def setup_functions(self): + def setup_functions(self, dtype): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): - L = build_scaling_rotation(scaling_modifier * scaling, rotation) + L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype) actual_covariance = L @ L.transpose(1, 2) - symm = strip_symmetric(actual_covariance) + symm = strip_symmetric(actual_covariance, dtype) return symm self.scaling_activation = torch.exp @@ -41,7 +41,7 @@ class GaussianModel: self.rotation_activation = torch.nn.functional.normalize - def __init__(self, sh_degree : int): + def __init__(self, sh_degree : int, dtype=torch.float32): 
self.active_sh_degree = 0 self.max_sh_degree = sh_degree self._xyz = torch.empty(0) @@ -56,7 +56,8 @@ class GaussianModel: self.optimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 - self.setup_functions() + self.dtype = dtype + self.setup_functions(dtype) def capture(self): return ( @@ -136,7 +137,7 @@ class GaussianModel: rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 - opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) @@ -144,7 +145,7 @@ class GaussianModel: self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) - self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype) def training_setup(self, training_args): self.percent_dense = training_args.percent_dense @@ -246,12 +247,12 @@ class GaussianModel: for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) - self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) - self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) - self._scaling = nn.Parameter(torch.tensor(scales, 
dtype=torch.float, device="cuda").requires_grad_(True)) - self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True)) self.active_sh_degree = self.max_sh_degree diff --git a/train.py b/train.py index 5d819b3..7435218 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from utils.loss_utils import l1_loss, ssim from gaussian_renderer import render, network_gui import sys from scene import Scene, GaussianModel -from utils.general_utils import safe_state +from utils.general_utils import get_data_dtype, safe_state import uuid from tqdm import tqdm from utils.image_utils import psnr @@ -216,6 +216,7 @@ if __name__ == "__main__": # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..6e762a5 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -11,7 +11,7 @@ from scene.cameras import Camera import numpy as np -from utils.general_utils import PILtoTorch +from 
utils.general_utils import PILtoTorch, get_data_dtype from utils.graphics_utils import fov2focal WARNED = False @@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale): resolution = (int(orig_w / scale), int(orig_h / scale)) resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + # resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype)) gt_image = resized_image_rgb[:3, ...] loaded_mask = None @@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale): return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + data_dtype=get_data_dtype(args.data_dtype)) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c082..ed7f0a6 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -61,8 +61,8 @@ def get_expon_lr_func( return helper -def strip_lowerdiag(L): - uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") +def strip_lowerdiag(L, dtype=torch.float32): + uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda") uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] @@ -72,8 +72,8 @@ def strip_lowerdiag(L): uncertainty[:, 5] = L[:, 2, 2] return uncertainty -def strip_symmetric(sym): - return strip_lowerdiag(sym) +def strip_symmetric(sym, dtype=torch.float32): + return strip_lowerdiag(sym, dtype=dtype) def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) @@ -98,8 +98,8 @@ def build_rotation(r): R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R -def build_scaling_rotation(s, r): - L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") +def build_scaling_rotation(s, r, 
dtype=torch.float32): + L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda") R = build_rotation(r) L[:,0,0] = s[:,0] @@ -131,3 +131,12 @@ def safe_state(silent): np.random.seed(0) torch.manual_seed(0) torch.cuda.set_device(torch.device("cuda:0")) + +def get_data_dtype(dtype): + if dtype == "float32": + return torch.float32 + elif dtype == "float64": + return torch.float64 + elif dtype == "float16": + return torch.float16 + return torch.float32 \ No newline at end of file