mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 00:08:02 +00:00
fix exposure when loading from ply
This commit is contained in:
parent
21301643a4
commit
1152adc9d3
@ -21,7 +21,7 @@ from argparse import ArgumentParser
|
|||||||
from arguments import ModelParams, PipelineParams, get_combined_args
|
from arguments import ModelParams, PipelineParams, get_combined_args
|
||||||
from gaussian_renderer import GaussianModel
|
from gaussian_renderer import GaussianModel
|
||||||
|
|
||||||
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
|
def render_set(model_path, name, iteration, views, gaussians, pipeline, background, train_test_exp):
|
||||||
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
|
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
|
||||||
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
|
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
|
|||||||
makedirs(gts_path, exist_ok=True)
|
makedirs(gts_path, exist_ok=True)
|
||||||
|
|
||||||
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
|
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
|
||||||
rendering = render(view, gaussians, pipeline, background)["render"]
|
rendering = render(view, gaussians, pipeline, background, use_trained_exp=train_test_exp)["render"]
|
||||||
gt = view.original_image[0:3, :, :]
|
gt = view.original_image[0:3, :, :]
|
||||||
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
||||||
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
|
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
|
||||||
@ -43,10 +43,10 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
|
|||||||
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
if not skip_train:
|
if not skip_train:
|
||||||
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
|
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, dataset.train_test_exp)
|
||||||
|
|
||||||
if not skip_test:
|
if not skip_test:
|
||||||
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
|
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, dataset.train_test_exp)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Set up command line argument parser
|
# Set up command line argument parser
|
||||||
|
@ -78,7 +78,7 @@ class Scene:
|
|||||||
self.gaussians.load_ply(os.path.join(self.model_path,
|
self.gaussians.load_ply(os.path.join(self.model_path,
|
||||||
"point_cloud",
|
"point_cloud",
|
||||||
"iteration_" + str(self.loaded_iter),
|
"iteration_" + str(self.loaded_iter),
|
||||||
"point_cloud.ply"))
|
"point_cloud.ply"), args.train_test_exp)
|
||||||
else:
|
else:
|
||||||
self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)
|
self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import numpy as np
|
|||||||
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
|
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
from utils.system_utils import mkdir_p
|
from utils.system_utils import mkdir_p
|
||||||
from plyfile import PlyData, PlyElement
|
from plyfile import PlyData, PlyElement
|
||||||
from utils.sh_utils import RGB2SH
|
from utils.sh_utils import RGB2SH
|
||||||
@ -127,7 +128,10 @@ class GaussianModel:
|
|||||||
return self._exposure
|
return self._exposure
|
||||||
|
|
||||||
def get_exposure_from_name(self, image_name):
|
def get_exposure_from_name(self, image_name):
|
||||||
return self._exposure[self.exposure_mapping[image_name]]
|
if self.pretrained_exposures is None:
|
||||||
|
return self._exposure[self.exposure_mapping[image_name]]
|
||||||
|
else:
|
||||||
|
return self.pretrained_exposures[image_name]
|
||||||
|
|
||||||
def get_covariance(self, scaling_modifier = 1):
|
def get_covariance(self, scaling_modifier = 1):
|
||||||
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
|
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
|
||||||
@ -161,7 +165,7 @@ class GaussianModel:
|
|||||||
self._opacity = nn.Parameter(opacities.requires_grad_(True))
|
self._opacity = nn.Parameter(opacities.requires_grad_(True))
|
||||||
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
||||||
self.exposure_mapping = {cam_info.image_name: idx for idx, cam_info in enumerate(cam_infos)}
|
self.exposure_mapping = {cam_info.image_name: idx for idx, cam_info in enumerate(cam_infos)}
|
||||||
|
self.pretrained_exposures = None
|
||||||
exposure = torch.eye(3, 4, device="cuda")[None].repeat(len(cam_infos), 1, 1)
|
exposure = torch.eye(3, 4, device="cuda")[None].repeat(len(cam_infos), 1, 1)
|
||||||
self._exposure = nn.Parameter(exposure.requires_grad_(True))
|
self._exposure = nn.Parameter(exposure.requires_grad_(True))
|
||||||
|
|
||||||
@ -180,8 +184,8 @@ class GaussianModel:
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
||||||
|
if self.pretrained_exposures is None:
|
||||||
self.exposure_optimizer = torch.optim.Adam([self._exposure])
|
self.exposure_optimizer = torch.optim.Adam([self._exposure])
|
||||||
|
|
||||||
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
||||||
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
||||||
@ -195,8 +199,9 @@ class GaussianModel:
|
|||||||
|
|
||||||
def update_learning_rate(self, iteration):
|
def update_learning_rate(self, iteration):
|
||||||
''' Learning rate scheduling per step '''
|
''' Learning rate scheduling per step '''
|
||||||
for param_group in self.exposure_optimizer.param_groups:
|
if self.pretrained_exposures is None:
|
||||||
param_group['lr'] = self.exposure_scheduler_args(iteration)
|
for param_group in self.exposure_optimizer.param_groups:
|
||||||
|
param_group['lr'] = self.exposure_scheduler_args(iteration)
|
||||||
|
|
||||||
for param_group in self.optimizer.param_groups:
|
for param_group in self.optimizer.param_groups:
|
||||||
if param_group["name"] == "xyz":
|
if param_group["name"] == "xyz":
|
||||||
@ -242,8 +247,18 @@ class GaussianModel:
|
|||||||
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
||||||
self._opacity = optimizable_tensors["opacity"]
|
self._opacity = optimizable_tensors["opacity"]
|
||||||
|
|
||||||
def load_ply(self, path):
|
def load_ply(self, path, use_train_test_exp = False):
|
||||||
plydata = PlyData.read(path)
|
plydata = PlyData.read(path)
|
||||||
|
if use_train_test_exp:
|
||||||
|
exposure_file = os.path.join(os.path.dirname(path), os.pardir, os.pardir, "exposure.json")
|
||||||
|
if os.path.exists(exposure_file):
|
||||||
|
with open(exposure_file, "r") as f:
|
||||||
|
exposures = json.load(f)
|
||||||
|
self.pretrained_exposures = {image_name: torch.FloatTensor(exposures[image_name]).requires_grad_(False).cuda() for image_name in exposures}
|
||||||
|
print(f"Pretrained exposures loaded.")
|
||||||
|
else:
|
||||||
|
print(f"No exposure to be loaded at {exposure_file}")
|
||||||
|
self.pretrained_exposures = None
|
||||||
|
|
||||||
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
||||||
np.asarray(plydata.elements[0]["y"]),
|
np.asarray(plydata.elements[0]["y"]),
|
||||||
|
Loading…
Reference in New Issue
Block a user