Mirror of https://github.com/graphdeco-inria/gaussian-splatting
Added support for setting the floating-point precision of stored images
Users may want to reduce memory consumption by storing images in fp16. However, in my tests, doing so resulted in lower-quality renders. Some data-type conversions had no impact at all, so I removed them completely.
Parent: 472689c0dc
Commit: 18eb6d6a0c
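As a rough sketch of what the new option does end to end (illustrative only; the helper name `get_data_dtype` comes from the diff below, the dummy tensors are stand-ins for real data): the `--data_dtype` string is mapped to a torch dtype, ground-truth images are stored in that dtype, and the rendered image is cast to match it before the loss is computed, while the rasterizer itself keeps working in float32.

```python
import torch

# Illustrative re-statement of the get_data_dtype() helper added in this commit.
def get_data_dtype(dtype: str) -> torch.dtype:
    return {"float16": torch.float16, "float64": torch.float64}.get(dtype, torch.float32)

# Ground-truth images are stored in the requested dtype to save memory ...
gt_image = torch.rand(3, 4, 4).to(get_data_dtype("float16"))

# ... while the rasterizer output stays float32 and is cast to match
# the stored image before the loss is computed.
rendered_image = torch.rand(3, 4, 4).to(gt_image.dtype)
loss = torch.abs(rendered_image - gt_image).mean()   # an L1-style comparison
```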
@@ -194,6 +194,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
 Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
 #### --percent_dense
 Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
+#### --data_dtype
+The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
 
 </details>
 <br>
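A hypothetical example of the new flag in use, following the training command shown in the hunk header above: `python train.py -s <path to COLMAP or NeRF Synthetic dataset> --data_dtype float16` (bearing in mind the quality caveat in the commit description).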
@@ -53,6 +53,7 @@ class ModelParams(ParamGroup):
         self._resolution = -1
         self._white_background = False
         self.data_device = "cuda"
+        self.data_dtype = "float32"
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)
 
@@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
         rotations = rotations,
         cov3D_precomp = cov3D_precomp)
 
+    # after rasterization, we convert the resulting image to the target dtype
+    # The rasterizer expects parameters as float32, so the result is also float32.
+    rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)
+
     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
     # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
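One side note on the cast above (my reading, not stated in the commit): without it, PyTorch's type promotion would silently carry out the comparison against a float16 ground truth in float32, so the explicit `.to(...)` is what keeps the loss computation in the requested precision. A quick sketch:

```python
import torch

gt = torch.rand(3, 8, 8, dtype=torch.float16)   # image stored as fp16
render = torch.rand(3, 8, 8)                    # rasterizer output is float32
print((render - gt).dtype)                      # torch.float32 -- promotion wins
print((render.to(gt.dtype) - gt).dtype)         # torch.float16 with the explicit cast
```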
render.py (10 lines changed)
@@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
         torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
         torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
 
-def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
+def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
     with torch.no_grad():
-        gaussians = GaussianModel(dataset.sh_degree)
+        gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
         scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
 
         bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
-        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+        background = torch.tensor(bg_color, dtype=dtype, device="cuda")
 
         if not skip_train:
             render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
@@ -62,5 +62,7 @@ if __name__ == "__main__":
 
     # Initialize system state (RNG)
     safe_state(args.quiet)
 
-    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
+    dtype = torch.float32
+
+    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)
@@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
                  ):
         super(Camera, self).__init__()
 
@@ -28,6 +28,7 @@ class Camera(nn.Module):
         self.FoVx = FoVx
         self.FoVy = FoVy
         self.image_name = image_name
+        self.data_dtype = data_dtype
 
         try:
             self.data_device = torch.device(data_device)
@@ -36,12 +37,12 @@ class Camera(nn.Module):
             print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
             self.data_device = torch.device("cuda")
 
-        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
+        self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
 
         if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.to(self.data_device)
+            self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
         else:
             self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
 
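Rough intuition for the memory this saves on the image side (back-of-the-envelope figures, not from the commit): a 3x1080x1920 ground-truth image occupies about half the bytes once stored as float16.

```python
import torch

img = torch.rand(3, 1080, 1920)                       # float32 ground-truth image
print(img.element_size() * img.nelement() / 1e6)      # ~24.9 MB
half = img.to(torch.float16)
print(half.element_size() * half.nelement() / 1e6)    # ~12.4 MB
```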
@@ -23,11 +23,11 @@ from utils.general_utils import strip_symmetric, build_scaling_rotation
 
 class GaussianModel:
 
-    def setup_functions(self):
+    def setup_functions(self, dtype):
         def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
-            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype)
             actual_covariance = L @ L.transpose(1, 2)
-            symm = strip_symmetric(actual_covariance)
+            symm = strip_symmetric(actual_covariance, dtype)
             return symm
 
         self.scaling_activation = torch.exp
@@ -41,7 +41,7 @@ class GaussianModel:
         self.rotation_activation = torch.nn.functional.normalize
 
 
-    def __init__(self, sh_degree : int):
+    def __init__(self, sh_degree : int, dtype=torch.float32):
         self.active_sh_degree = 0
         self.max_sh_degree = sh_degree
         self._xyz = torch.empty(0)
@@ -56,7 +56,8 @@ class GaussianModel:
         self.optimizer = None
         self.percent_dense = 0
         self.spatial_lr_scale = 0
-        self.setup_functions()
+        self.dtype = dtype
+        self.setup_functions(dtype)
 
     def capture(self):
         return (
@@ -136,7 +137,7 @@ class GaussianModel:
         rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
         rots[:, 0] = 1
 
-        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda"))
 
         self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
         self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
@@ -144,7 +145,7 @@ class GaussianModel:
         self._scaling = nn.Parameter(scales.requires_grad_(True))
         self._rotation = nn.Parameter(rots.requires_grad_(True))
         self._opacity = nn.Parameter(opacities.requires_grad_(True))
-        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype)
 
     def training_setup(self, training_args):
         self.percent_dense = training_args.percent_dense
@@ -246,12 +247,12 @@ class GaussianModel:
         for idx, attr_name in enumerate(rot_names):
             rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
 
-        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True))
 
         self.active_sh_degree = self.max_sh_degree
 
train.py (3 lines changed)
@@ -16,7 +16,7 @@ from utils.loss_utils import l1_loss, ssim
 from gaussian_renderer import render, network_gui
 import sys
 from scene import Scene, GaussianModel
-from utils.general_utils import safe_state
+from utils.general_utils import get_data_dtype, safe_state
 import uuid
 from tqdm import tqdm
 from utils.image_utils import psnr
@@ -216,6 +216,7 @@ if __name__ == "__main__":
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
 
     training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
 
     # All done
@@ -11,7 +11,7 @@
 
 from scene.cameras import Camera
 import numpy as np
-from utils.general_utils import PILtoTorch
+from utils.general_utils import PILtoTorch, get_data_dtype
 from utils.graphics_utils import fov2focal
 
 WARNED = False
@@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale):
         resolution = (int(orig_w / scale), int(orig_h / scale))
 
     resized_image_rgb = PILtoTorch(cam_info.image, resolution)
 
+    # resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype))
+
     gt_image = resized_image_rgb[:3, ...]
     loaded_mask = None
@@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device,
+                  data_dtype=get_data_dtype(args.data_dtype))
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []
@@ -61,8 +61,8 @@ def get_expon_lr_func(
 
     return helper
 
-def strip_lowerdiag(L):
-    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+def strip_lowerdiag(L, dtype=torch.float32):
+    uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda")
 
     uncertainty[:, 0] = L[:, 0, 0]
     uncertainty[:, 1] = L[:, 0, 1]
@@ -72,8 +72,8 @@ def strip_lowerdiag(L):
     uncertainty[:, 5] = L[:, 2, 2]
     return uncertainty
 
-def strip_symmetric(sym):
-    return strip_lowerdiag(sym)
+def strip_symmetric(sym, dtype=torch.float32):
+    return strip_lowerdiag(sym, dtype=dtype)
 
 def build_rotation(r):
     norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
@@ -98,8 +98,8 @@ def build_rotation(r):
     R[:, 2, 2] = 1 - 2 * (x*x + y*y)
     return R
 
-def build_scaling_rotation(s, r):
-    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+def build_scaling_rotation(s, r, dtype=torch.float32):
+    L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda")
     R = build_rotation(r)
 
     L[:,0,0] = s[:,0]
@@ -131,3 +131,12 @@ def safe_state(silent):
     np.random.seed(0)
     torch.manual_seed(0)
     torch.cuda.set_device(torch.device("cuda:0"))
+
+def get_data_dtype(dtype):
+    if dtype == "float32":
+        return torch.float32
+    elif dtype == "float64":
+        return torch.float64
+    elif dtype == "float16":
+        return torch.float16
+    return torch.float32
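A minimal usage sketch of the new helper (illustrative; note that unrecognized strings silently fall back to float32):

```python
import torch
from utils.general_utils import get_data_dtype

assert get_data_dtype("float16") is torch.float16
assert get_data_dtype("float64") is torch.float64
assert get_data_dtype("bfloat16") is torch.float32   # anything unrecognized falls back to float32
```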