added support for setting floating point range

Users may want to reduce their memory consumption by using fp16.
However, in my tests, such attempts will result in lower quality renders.
Some data type conversions did not have any impact, so I removed them completely.
This commit is contained in:
Stefan Saraev 2024-04-24 18:37:06 +03:00 committed by Matei Barbu
parent 472689c0dc
commit 18eb6d6a0c
9 changed files with 53 additions and 29 deletions

View File

@ -194,6 +194,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
#### --percent_dense
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
#### --data_dtype
The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
</details>
<br>

View File

@ -53,6 +53,7 @@ class ModelParams(ParamGroup):
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.data_dtype = "float32"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

View File

@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
# after rasterization, we convert the resulting image to the target dtype
# The rasterizer expects parameters as float32, so the result is also float32.
rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,

View File

@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
background = torch.tensor(bg_color, dtype=dtype, device="cuda")
if not skip_train:
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
@ -62,5 +62,7 @@ if __name__ == "__main__":
# Initialize system state (RNG)
safe_state(args.quiet)
dtype = torch.float32
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)

View File

@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
):
super(Camera, self).__init__()
@ -28,6 +28,7 @@ class Camera(nn.Module):
self.FoVx = FoVx
self.FoVy = FoVy
self.image_name = image_name
self.data_dtype = data_dtype
try:
self.data_device = torch.device(data_device)
@ -36,12 +37,12 @@ class Camera(nn.Module):
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

View File

@ -23,11 +23,11 @@ from utils.general_utils import strip_symmetric, build_scaling_rotation
class GaussianModel:
def setup_functions(self):
def setup_functions(self, dtype):
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
symm = strip_symmetric(actual_covariance, dtype)
return symm
self.scaling_activation = torch.exp
@ -41,7 +41,7 @@ class GaussianModel:
self.rotation_activation = torch.nn.functional.normalize
def __init__(self, sh_degree : int):
def __init__(self, sh_degree : int, dtype=torch.float32):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree
self._xyz = torch.empty(0)
@ -56,7 +56,8 @@ class GaussianModel:
self.optimizer = None
self.percent_dense = 0
self.spatial_lr_scale = 0
self.setup_functions()
self.dtype = dtype
self.setup_functions(dtype)
def capture(self):
return (
@ -136,7 +137,7 @@ class GaussianModel:
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
rots[:, 0] = 1
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda"))
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
@ -144,7 +145,7 @@ class GaussianModel:
self._scaling = nn.Parameter(scales.requires_grad_(True))
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype)
def training_setup(self, training_args):
self.percent_dense = training_args.percent_dense
@ -246,12 +247,12 @@ class GaussianModel:
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True))
self.active_sh_degree = self.max_sh_degree

View File

@ -16,7 +16,7 @@ from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state
from utils.general_utils import get_data_dtype, safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
@ -216,6 +216,7 @@ if __name__ == "__main__":
# Start GUI server, configure and run training
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
# All done

View File

@ -11,7 +11,7 @@
from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.general_utils import PILtoTorch, get_data_dtype
from utils.graphics_utils import fov2focal
WARNED = False
@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale):
resolution = (int(orig_w / scale), int(orig_h / scale))
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
# resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype))
gt_image = resized_image_rgb[:3, ...]
loaded_mask = None
@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
data_dtype=get_data_dtype(args.data_dtype))
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []

View File

@ -61,8 +61,8 @@ def get_expon_lr_func(
return helper
def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
def strip_lowerdiag(L, dtype=torch.float32):
uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda")
uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
@ -72,8 +72,8 @@ def strip_lowerdiag(L):
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty
def strip_symmetric(sym):
return strip_lowerdiag(sym)
def strip_symmetric(sym, dtype=torch.float32):
return strip_lowerdiag(sym, dtype=dtype)
def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
@ -98,8 +98,8 @@ def build_rotation(r):
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R
def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
def build_scaling_rotation(s, r, dtype=torch.float32):
L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda")
R = build_rotation(r)
L[:,0,0] = s[:,0]
@ -131,3 +131,12 @@ def safe_state(silent):
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
def get_data_dtype(dtype):
if dtype == "float32":
return torch.float32
elif dtype == "float64":
return torch.float64
elif dtype == "float16":
return torch.float16
return torch.float32