new release

Commit 21301643a4 by alanvinx, 2024-08-21 14:30:43 +02:00 (parent 472689c0dc)
21 changed files with 436 additions and 73 deletions

.gitmodules

@@ -3,7 +3,8 @@
url = https://gitlab.inria.fr/bkerbl/simple-knn.git
[submodule "submodules/diff-gaussian-rasterization"]
path = submodules/diff-gaussian-rasterization
-url = https://github.com/graphdeco-inria/diff-gaussian-rasterization
+url = https://github.com/graphdeco-inria/diff-gaussian-rasterization.git
+branch = dr_aa
[submodule "SIBR_viewers"]
path = SIBR_viewers
url = https://gitlab.inria.fr/sibr/sibr_core.git

README.md

@@ -30,10 +30,21 @@ Abstract: *Radiance Field methods have recently revolutionized novel-view synthe
</section>
## Funding and Acknowledgments
This research was funded by the ERC Advanced grant FUNGRAPH No 788065. The authors are grateful to Adobe for generous donations, the OPAL infrastructure from Université Côte d'Azur and for the HPC resources from GENCI–IDRIS (Grant 2022-AD011013409). The authors thank the anonymous reviewers for their valuable feedback, P. Hedman and A. Tewari for proofreading earlier drafts, and T. Müller, A. Yu and S. Fridovich-Keil for helping with the comparisons.
+## NEW FEATURES!
+We have limited resources for maintaining and updating the code. However, we have added a few new features since the original release, inspired by some of the excellent work many other researchers have been doing on 3DGS. We will add further features as our resources allow.
+Update of August 2024:
+We have added/corrected the following features: [depth regularization](#depth-regularization) for training, [anti-aliasing](#anti-aliasing) and [exposure compensation](#exposure-compensation). We have also enhanced the SIBR real-time viewer by fixing bugs and adding features in the [Top view](#sibr-top-view) that allow visualization of the input and user cameras. Please note that it is currently not possible to use depth regularization together with the training-speed acceleration, since they use different rasterizer versions.
+Update of Spring 2024:
+Orange Labs has kindly added [OpenXR support](#openxr-support) for VR viewing.
## Step-by-step Tutorial
Jonathan Stephens made a fantastic step-by-step tutorial for setting up Gaussian Splatting on your machine, along with instructions for creating usable datasets from videos. If the instructions below are too dry for you, go ahead and check it out [here](https://www.youtube.com/watch?v=UXtuigy_wYc).
@ -65,9 +76,8 @@ The codebase has 4 main components:
The components have different requirements w.r.t. both hardware and software. They have been tested on Windows 10 and Ubuntu Linux 22.04. Instructions for setting up and running each of them are found in the sections below. The components have different requirements w.r.t. both hardware and software. They have been tested on Windows 10 and Ubuntu Linux 22.04. Instructions for setting up and running each of them are found in the sections below.
## New features [Please check regularly!]
We will be adding several new features soon. In the meantime Orange has kindly added [OpenXR support](#openXR-support) for VR viewing. Please come back soon, we will be adding other features, building among others on recent 3DGS followup papers.
## Optimizer ## Optimizer
@@ -482,11 +492,62 @@ python convert.py -s <location> --skip_matching [--resize] #If not resizing, Ima
</details>
<br>
+### Depth regularization
+To have better reconstructed scenes, we use depth maps as priors during optimization for each input image. It works best on untextured parts (e.g., roads) and can remove floaters. Several papers have used similar ideas to improve various aspects of 3DGS (e.g. [DepthRegularizedGS](https://robot0321.github.io/DepthRegGS/index.html), [SparseGS](https://formycat.github.io/SparseGS-Real-Time-360-Sparse-View-Synthesis-using-Gaussian-Splatting/), [DNGaussian](https://fictionarry.github.io/DNGaussian/)). The depth regularization we integrated is the one used in our [Hierarchical 3DGS](https://repo-sam.inria.fr/fungraph/hierarchical-3d-gaussians/) paper, but applied to the original 3DGS; for some scenes (e.g., the DeepBlending scenes) it improves quality significantly, for others it either makes a small difference or can even be worse. For detailed statistics, please see: [Stats for depth regularization](results.md).
+Two preprocessing steps are required to enable depth regularization when training a scene:
+1. A depth map should be generated for each input image; to this effect we suggest using [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2?tab=readme-ov-file#usage).
+2. Generate a `depth_params.json` file using:
+```
+python utils/make_depth_scale.py --base_dir <path to colmap> --depths_dir <path to generated depths>
+```
+A new parameter, `-d <path to depth maps>`, should be set at training time if you want to use depth regularization.
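As a rough illustration of how these pieces fit together (the names `depth_l1_term`, `w_init` and `w_final` are illustrative, not the repository's API): the monocular inverse depth is mapped to COLMAP scale with the per-image `scale`/`offset` stored in `depth_params.json`, and a masked L1 term with a decaying weight is added to the training loss, along the lines of:

```
import math
import torch

def depth_l1_term(rendered_invdepth, mono_invdepth, scale, offset, mask,
                  iteration, w_init=1.0, w_final=0.01, max_steps=30_000):
    # Align the monocular inverse depth to the COLMAP inverse depth (per image).
    aligned = mono_invdepth * scale + offset
    # The weight decays log-linearly from w_init to w_final over training,
    # roughly mirroring the exponential schedules used for other parameters.
    t = min(iteration / max_steps, 1.0)
    weight = math.exp((1.0 - t) * math.log(w_init) + t * math.log(w_final))
    # Masked L1 between the rendered inverse depth and the aligned prior.
    return weight * torch.abs((rendered_invdepth - aligned) * mask).mean()
```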
+### Exposure compensation
+To compensate for exposure changes across the different input images, we optimize an affine transformation for each image, just as in [Hierarchical 3DGS](https://repo-sam.inria.fr/fungraph/hierarchical-3d-gaussians/). Add the following parameters to enable it:
+```
+--exposure_lr_init 0.001 --exposure_lr_final 0.0001 --exposure_lr_delay_steps 5000 --exposure_lr_delay_mult 0.001 --train_test_exp
+```
+Again, other excellent papers have used similar ideas, e.g. [NeRF-W](https://nerf-w.github.io/), [URF](https://urban-radiance-fields.github.io/).
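Concretely, each training image gets a learnable 3×4 affine color transform that is applied to the rendered image before computing the loss (see the `gaussian_renderer` change further down in this commit). A minimal sketch, with illustrative names:

```
import torch

def apply_exposure(rendered, exposure):
    # rendered: (3, H, W) image; exposure: learnable (3, 4) affine color transform,
    # initialized to the identity (torch.eye(3, 4)) for every training image.
    mixed = torch.matmul(rendered.permute(1, 2, 0), exposure[:3, :3]).permute(2, 0, 1)
    return (mixed + exposure[:3, 3, None, None]).clamp(0, 1)
```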
+### Anti-aliasing
+We added the EWA filter from [Mip-Splatting](https://niujinshuchong.github.io/mip-splatting/) to our codebase to remove aliasing. Anti-aliasing is enabled by default. To disable it, do the following:
+1. Comment out `#define DGR_FIX_AA` in `submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h`.
+2. Re-install the rasterizer in your environment:
+```
+pip uninstall diff-gaussian-rasterization
+cd submodules/diff-gaussian-rasterization
+rm -r build
+pip install .
+```
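For intuition, the Mip-Splatting-style correction dilates each Gaussian's 2D screen-space covariance with a small isotropic low-pass kernel and scales its opacity so the splat's total energy is preserved. The actual fix lives in the CUDA rasterizer (`dr_aa` branch), so the snippet below is only a conceptual sketch with assumed values (e.g. the 0.3 kernel), not the repository's implementation:

```
import torch

def dilate_and_compensate(cov2d, opacity, kernel=0.3):
    # cov2d: (2, 2) screen-space covariance of one Gaussian; opacity: scalar tensor.
    det_before = torch.det(cov2d)
    cov2d_aa = cov2d + kernel * torch.eye(2)   # convolve with an isotropic low-pass filter
    det_after = torch.det(cov2d_aa)
    compensation = torch.sqrt(torch.clamp(det_before / det_after, min=0.0))
    return cov2d_aa, opacity * compensation    # keep the splat's energy roughly constant
```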
+### SIBR: Top view
+> `Views > Top view`
+The `Top view` renders the SfM point cloud in another view with the corresponding input cameras and the user camera of the `Point view`. This allows you to visualize, for example, how far the viewer is from the input cameras.
+It is a 3D view, so the user can navigate through it just as in the `Point view` (modes available: FPS, trackball, orbit).
+<!-- _gif showing the top view, showing it is realtime_ -->
+<!-- ![topViewOpen_1.gif](../assets/topViewOpen_1_1709560483017_0.gif) -->
+![top view open](assets/top_view_open.gif)
+Options are available to customize this view: meshes can be enabled/disabled and their scales can be modified.
+<!-- _gif showing different options_ -->
+<!-- ![topViewOptions.gif](../assets/topViewOptions_1709560615266_0.gif) -->
+![top view options](assets/top_view_options.gif)
+A useful additional functionality is to move to the position of an input image and progressively fade out to the SfM point view at that position (e.g., to verify camera alignment). Views from input cameras can be displayed in the `Top view` (*note that `--images-path` must be set on the command line*). You can snap the `Top view` camera to the input camera closest to the user camera of the `Point view` by clicking `Top view settings > Cameras > Snap to closest`.
+<!-- _gif showing for a snapped camera the ground truth image with alpha_ -->
+<!-- ![topViewImageAlpha.gif](../assets/topViewImageAlpha_1709560852268_0.gif) -->
+![top view image alpha](assets/top_view_image_alpha.gif)
### OpenXR support
OpenXR is supported in the branch gaussian_code_release_openxr
Within that branch, you can find documentation for VR support [here](https://gitlab.inria.fr/sibr/sibr_core/-/tree/gaussian_code_release_openxr?ref_type=heads).
## FAQ
- *Where do I get data sets, e.g., those referenced in ```full_eval.py```?* The MipNeRF360 data set is provided by the authors of the original paper on the project site. Note that two of the data sets cannot be openly shared and require you to consult the authors directly. For Tanks&Temples and Deep Blending, please use the download links provided at the top of the page. Alternatively, you may access the cloned data (status: August 2023!) from [HuggingFace](https://huggingface.co/camenduru/gaussian-splatting)

@@ -1 +1 @@
-Subproject commit 4ae964a267cd7a844d9766563cf9d0b500131a22
+Subproject commit 212f9fbb59d88db0022cd1326a8458759bf310ad

arguments/__init__.py

@@ -50,8 +50,10 @@ class ModelParams(ParamGroup):
self._source_path = ""
self._model_path = ""
self._images = "images"
+self._depths = ""
self._resolution = -1
self._white_background = False
+self.train_test_exp = False
self.data_device = "cuda"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)
@@ -76,9 +78,13 @@ class OptimizationParams(ParamGroup):
self.position_lr_delay_mult = 0.01
self.position_lr_max_steps = 30_000
self.feature_lr = 0.0025
-self.opacity_lr = 0.05
+self.opacity_lr = 0.025
self.scaling_lr = 0.005
self.rotation_lr = 0.001
+self.exposure_lr_init = 0.01
+self.exposure_lr_final = 0.001
+self.exposure_lr_delay_steps = 0
+self.exposure_lr_delay_mult = 0.0
self.percent_dense = 0.01
self.lambda_dssim = 0.2
self.densification_interval = 100
@@ -86,6 +92,8 @@ class OptimizationParams(ParamGroup):
self.densify_from_iter = 500
self.densify_until_iter = 15_000
self.densify_grad_threshold = 0.0002
+self.depth_l1_weight_init = 1.0
+self.depth_l1_weight_final = 0.01
self.random_background = False
super().__init__(parser, "Optimization Parameters")

New binary assets added in this commit:
- binary file, name not shown (86 KiB)
- assets/all_results_PSNR.png (87 KiB)
- assets/all_results_SSIM.png (81 KiB)
- binary file, name not shown (25 MiB)
- assets/top_view_open.gif (46 MiB)
- assets/top_view_options.gif (2.1 MiB)
full_eval.py

@@ -11,6 +11,7 @@
import os
from argparse import ArgumentParser
+import time
mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
@@ -38,18 +39,30 @@ if not args.skip_training or not args.skip_rendering:
if not args.skip_training:
common_args = " --quiet --eval --test_iterations -1 "
+start_time = time.time()
for scene in mipnerf360_outdoor_scenes:
source = args.mipnerf360 + "/" + scene
os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
for scene in mipnerf360_indoor_scenes:
source = args.mipnerf360 + "/" + scene
os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
+m360_timing = (time.time() - start_time)/60.0
+start_time = time.time()
for scene in tanks_and_temples_scenes:
source = args.tanksandtemples + "/" + scene
os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+tandt_timing = (time.time() - start_time)/60.0
+start_time = time.time()
for scene in deep_blending_scenes:
source = args.deepblending + "/" + scene
os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+db_timing = (time.time() - start_time)/60.0
+with open("timing.txt", 'w') as file:
+file.write(f"m360: {m360_timing} minutes \n tandt: {tandt_timing} minutes \n db: {db_timing} minutes\n")
if not args.skip_rendering:
all_sources = []

gaussian_renderer/__init__.py

@@ -15,7 +15,7 @@ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianR
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh
-def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, use_trained_exp=False):
"""
Render the scene.
@@ -59,6 +59,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
cov3D_precomp = pc.get_covariance(scaling_modifier)
else:
@@ -82,7 +83,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
colors_precomp = override_color
# Rasterize visible Gaussians to image, obtain their radii (on screen).
-rendered_image, radii = rasterizer(
+rendered_image, radii, depth_image = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
@@ -91,10 +92,21 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
+# Apply exposure to rendered image (training only)
+if use_trained_exp:
+exposure = pc.get_exposure_from_name(viewpoint_camera.image_name)
+rendered_image = torch.matmul(rendered_image.permute(1, 2, 0), exposure[:3, :3]).permute(2, 0, 1) + exposure[:3, 3, None, None]
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
-return {"render": rendered_image,
-"viewspace_points": screenspace_points,
-"visibility_filter" : radii > 0,
-"radii": radii}
+rendered_image = rendered_image.clamp(0, 1)
+out = {
+"render": rendered_image,
+"viewspace_points": screenspace_points,
+"visibility_filter" : (radii > 0).nonzero(),
+"radii": radii,
+"depth" : depth_image
+}
+return out

results.md

@@ -0,0 +1,14 @@
# Evaluations
We evaluated the impact of the added features on the MipNeRF360, Tanks&Temples and Deep Blending datasets.
## PSNR
![all results PSNR](assets/all_results_PSNR.png)
***DR**: depth regularization, **AA**: antialiasing, **EXPCOMP**: exposure compensation.*
## SSIM
![all results SSIM](assets/all_results_SSIM.png)
***DR**: depth regularization, **AA**: antialiasing, **EXPCOMP**: exposure compensation.*
## LPIPS
![all results LPIPS](assets/all_results_LPIPS.png)
*Lower is better. **DR**: depth regularization, **AA**: antialiasing, **EXPCOMP**: exposure compensation.*

scene/__init__.py

@@ -41,7 +41,7 @@ class Scene:
self.test_cameras = {}
if os.path.exists(os.path.join(args.source_path, "sparse")):
-scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
+scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
@@ -70,9 +70,9 @@ class Scene:
for resolution_scale in resolution_scales:
print("Loading Training Cameras")
-self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
+self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, False)
print("Loading Test Cameras")
-self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
+self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, True)
if self.loaded_iter:
self.gaussians.load_ply(os.path.join(self.model_path,
@@ -80,14 +80,21 @@ class Scene:
"iteration_" + str(self.loaded_iter),
"point_cloud.ply"))
else:
-self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)
def save(self, iteration):
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+exposure_dict = {
+image_name: self.gaussians.get_exposure_from_name(image_name).detach().cpu().numpy().tolist()
+for image_name in self.gaussians.exposure_mapping
+}
+with open(os.path.join(self.model_path, "exposure.json"), "w") as f:
+json.dump(exposure_dict, f, indent=2)
def getTrainCameras(self, scale=1.0):
return self.train_cameras[scale]
def getTestCameras(self, scale=1.0):
return self.test_cameras[scale]

scene/cameras.py

@@ -13,11 +13,14 @@ import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
+from utils.general_utils import PILtoTorch
+import cv2
class Camera(nn.Module):
-def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
+def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, image, invdepthmap,
image_name, uid,
-trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
+trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda",
+train_test_exp = False, is_test_dataset = False, is_test_view = False
):
super(Camera, self).__init__()
@@ -36,14 +39,43 @@
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
-self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
+resized_image_rgb = PILtoTorch(image, resolution)
+gt_image = resized_image_rgb[:3, ...]
+self.alpha_mask = None
+if resized_image_rgb.shape[0] == 4:
+self.alpha_mask = resized_image_rgb[3:4, ...].to(self.data_device)
+else:
+self.alpha_mask = torch.ones_like(resized_image_rgb[0:1, ...].to(self.data_device))
+if train_test_exp and is_test_view:
+if is_test_dataset:
+self.alpha_mask[..., :self.alpha_mask.shape[-1] // 2] = 0
+else:
+self.alpha_mask[..., self.alpha_mask.shape[-1] // 2:] = 0
+self.original_image = gt_image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
-if gt_alpha_mask is not None:
-self.original_image *= gt_alpha_mask.to(self.data_device)
-else:
-self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
+self.invdepthmap = None
+self.depth_reliable = False
+if invdepthmap is not None and depth_params is not None and depth_params["scale"] > 0:
+invdepthmapScaled = invdepthmap * depth_params["scale"] + depth_params["offset"]
+invdepthmapScaled = cv2.resize(invdepthmapScaled, resolution)
+invdepthmapScaled[invdepthmapScaled < 0] = 0
+if invdepthmapScaled.ndim != 2:
+invdepthmapScaled = invdepthmapScaled[..., 0]
+self.invdepthmap = torch.from_numpy(invdepthmapScaled[None]).to(self.data_device)
+if self.alpha_mask is not None:
+self.depth_mask = self.alpha_mask.clone()
+else:
+self.depth_mask = torch.ones_like(self.invdepthmap > 0)
+if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]:
+self.depth_mask *= 0
+else:
+self.depth_reliable = True
self.zfar = 100.0
self.znear = 0.01
@@ -55,7 +87,7 @@
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
self.image_width = width

scene/dataset_readers.py

@@ -29,11 +29,13 @@ class CameraInfo(NamedTuple):
T: np.array
FovY: np.array
FovX: np.array
-image: np.array
+depth_params: dict
image_path: str
image_name: str
+depth_path: str
width: int
height: int
+is_test: bool
class SceneInfo(NamedTuple):
point_cloud: BasicPointCloud
@@ -65,7 +67,7 @@ def getNerfppNorm(cam_info):
return {"translate": translate, "radius": radius}
-def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
+def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, depths_folder, test_cam_names_list):
cam_infos = []
for idx, key in enumerate(cam_extrinsics):
sys.stdout.write('\r')
@@ -94,13 +96,23 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
else:
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
-image_path = os.path.join(images_folder, os.path.basename(extr.name))
-image_name = os.path.basename(image_path).split(".")[0]
-image = Image.open(image_path)
+n_remove = len(extr.name.split('.')[-1]) + 1
+depth_params = None
+if depths_params is not None:
+try:
+depth_params = depths_params[extr.name[:-n_remove]]
+except:
+print("\n", key, "not found in depths_params")
-cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-image_path=image_path, image_name=image_name, width=width, height=height)
+image_path = os.path.join(images_folder, extr.name)
+image_name = extr.name
+depth_path = os.path.join(depths_folder, f"{extr.name[:-n_remove]}.png") if depths_folder != "" else ""
+cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, depth_params=depth_params,
+image_path=image_path, image_name=image_name, depth_path=depth_path,
+width=width, height=height, is_test=image_name in test_cam_names_list)
cam_infos.append(cam_info)
sys.stdout.write('\n')
return cam_infos
@@ -129,7 +141,7 @@ def storePly(path, xyz, rgb):
ply_data = PlyData([vertex_element])
ply_data.write(path)
-def readColmapSceneInfo(path, images, eval, llffhold=8):
+def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
try:
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
@@ -141,16 +153,51 @@ def readColmapSceneInfo(path, images, eval, llffhold=8):
cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
-reading_dir = "images" if images == None else images
-cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
-cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
+depth_params_file = os.path.join(path, "sparse/0", "depth_params.json")
+## if depth_params_file isnt there AND depths file is here -> throw error
+depths_params = None
+if depths != "":
+try:
+with open(depth_params_file, "r") as f:
+depths_params = json.load(f)
+all_scales = np.array([depths_params[key]["scale"] for key in depths_params])
+if (all_scales > 0).sum():
+med_scale = np.median(all_scales[all_scales > 0])
+else:
+med_scale = 0
+for key in depths_params:
+depths_params[key]["med_scale"] = med_scale
+except FileNotFoundError:
+print(f"Error: depth_params.json file not found at path '{depth_params_file}'.")
+sys.exit(1)
+except Exception as e:
+print(f"An unexpected error occurred when trying to open depth_params.json file: {e}")
+sys.exit(1)
if eval:
-train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
-test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
+if "360" in path:
+llffhold = 8
+if llffhold:
+print("------------LLFF HOLD-------------")
+cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
+cam_names = sorted(cam_names)
+test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
+else:
+with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
+test_cam_names_list = [line.strip() for line in file]
else:
-train_cam_infos = cam_infos
-test_cam_infos = []
+test_cam_names_list = []
+reading_dir = "images" if images == None else images
+cam_infos_unsorted = readColmapCameras(
+cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, depths_params=depths_params,
+images_folder=os.path.join(path, reading_dir),
+depths_folder=os.path.join(path, depths) if depths != "" else "", test_cam_names_list=test_cam_names_list)
+cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
+train_cam_infos = [c for c in cam_infos if train_test_exp or not c.is_test]
+test_cam_infos = [c for c in cam_infos if c.is_test]
nerf_normalization = getNerfppNorm(train_cam_infos)
@@ -213,7 +260,7 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
FovY = fovy
FovX = fovx
-cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
return cam_infos

scene/gaussian_model.py

@@ -41,7 +41,7 @@ class GaussianModel:
self.rotation_activation = torch.nn.functional.normalize
-def __init__(self, sh_degree : int):
+def __init__(self, sh_degree):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree
self._xyz = torch.empty(0)
@@ -110,10 +110,25 @@ class GaussianModel:
features_rest = self._features_rest
return torch.cat((features_dc, features_rest), dim=1)
+@property
+def get_features_dc(self):
+return self._features_dc
+@property
+def get_features_rest(self):
+return self._features_rest
@property
def get_opacity(self):
return self.opacity_activation(self._opacity)
+@property
+def get_exposure(self):
+return self._exposure
+def get_exposure_from_name(self, image_name):
+return self._exposure[self.exposure_mapping[image_name]]
def get_covariance(self, scaling_modifier = 1):
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
@@ -121,7 +136,7 @@ class GaussianModel:
if self.active_sh_degree < self.max_sh_degree:
self.active_sh_degree += 1
-def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
+def create_from_pcd(self, pcd : BasicPointCloud, cam_infos : int, spatial_lr_scale : float):
self.spatial_lr_scale = spatial_lr_scale
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
@@ -136,7 +151,7 @@ class GaussianModel:
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
rots[:, 0] = 1
-opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+opacities = self.inverse_opacity_activation(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
@@ -145,6 +160,10 @@ class GaussianModel:
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+self.exposure_mapping = {cam_info.image_name: idx for idx, cam_info in enumerate(cam_infos)}
+exposure = torch.eye(3, 4, device="cuda")[None].repeat(len(cam_infos), 1, 1)
+self._exposure = nn.Parameter(exposure.requires_grad_(True))
def training_setup(self, training_args):
self.percent_dense = training_args.percent_dense
@@ -161,13 +180,24 @@ class GaussianModel:
]
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
+self.exposure_optimizer = torch.optim.Adam([self._exposure])
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
lr_delay_mult=training_args.position_lr_delay_mult,
max_steps=training_args.position_lr_max_steps)
+self.exposure_scheduler_args = get_expon_lr_func(training_args.exposure_lr_init, training_args.exposure_lr_final,
+lr_delay_steps=training_args.exposure_lr_delay_steps,
+lr_delay_mult=training_args.exposure_lr_delay_mult,
+max_steps=training_args.iterations)
def update_learning_rate(self, iteration):
''' Learning rate scheduling per step '''
+for param_group in self.exposure_optimizer.param_groups:
+param_group['lr'] = self.exposure_scheduler_args(iteration)
for param_group in self.optimizer.param_groups:
if param_group["name"] == "xyz":
lr = self.xyz_scheduler_args(iteration)
@@ -208,7 +238,7 @@ class GaussianModel:
PlyData([el]).write(path)
def reset_opacity(self):
-opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
+opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
self._opacity = optimizable_tensors["opacity"]

@@ -1 +1 @@
-Subproject commit 59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d
+Subproject commit eb015708bb0ae0468367b60ec8d809a5f1ec34fe

train.py

@@ -16,7 +16,7 @@ from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
-from utils.general_utils import safe_state
+from utils.general_utils import safe_state, get_expon_lr_func
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
@@ -44,11 +44,16 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)
-viewpoint_stack = None
+depth_l1_weight = get_expon_lr_func(opt.depth_l1_weight_init, opt.depth_l1_weight_final, max_steps=opt.iterations)
+viewpoint_stack = scene.getTrainCameras().copy()
+viewpoint_indices = list(range(len(viewpoint_stack)))
ema_loss_for_log = 0.0
+ema_Ll1depth_for_log = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):
if network_gui.conn == None:
network_gui.try_connect()
while network_gui.conn != None:
@@ -75,7 +80,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
-viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
+viewpoint_indices = list(range(len(viewpoint_stack)))
+rand_idx = randint(0, len(viewpoint_indices) - 1)
+viewpoint_cam = viewpoint_stack.pop(rand_idx)
+vind = viewpoint_indices.pop(rand_idx)
# Render
if (iteration - 1) == debug_from:
@@ -83,13 +91,29 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
bg = torch.rand((3), device="cuda") if opt.random_background else background
-render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+render_pkg = render(viewpoint_cam, gaussians, pipe, bg, use_trained_exp=dataset.train_test_exp)
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
# Loss
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
-loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+ssim_value = ssim(image, gt_image)
+loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
+# Depth regularization
+Ll1depth_pure = 0.0
+if depth_l1_weight(iteration) > 0 and viewpoint_cam.depth_reliable:
+invDepth = render_pkg["depth"]
+mono_invdepth = viewpoint_cam.invdepthmap.cuda()
+depth_mask = viewpoint_cam.depth_mask.cuda()
+Ll1depth_pure = torch.abs((invDepth - mono_invdepth) * depth_mask).mean()
+Ll1depth = depth_l1_weight(iteration) * Ll1depth_pure
+loss += Ll1depth
+Ll1depth = Ll1depth.item()
+else:
+Ll1depth = 0
loss.backward()
iter_end.record()
@@ -97,14 +121,16 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
with torch.no_grad():
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ema_Ll1depth_for_log = 0.4 * Ll1depth + 0.6 * ema_Ll1depth_for_log
if iteration % 10 == 0:
-progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "Depth Loss": f"{ema_Ll1depth_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()
# Log and save
-training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background), dataset.train_test_exp)
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)
@@ -124,6 +150,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
# Optimizer step
if iteration < opt.iterations:
+gaussians.exposure_optimizer.step()
+gaussians.exposure_optimizer.zero_grad(set_to_none = True)
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)
@@ -153,7 +181,7 @@ def prepare_output_and_logger(args):
print("Tensorboard not available: not logging progress")
return tb_writer
-def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
+def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
if tb_writer:
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
@@ -172,6 +200,9 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
for idx, viewpoint in enumerate(config['cameras']):
image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
+if train_test_exp:
+image = image[..., image.shape[-1] // 2:]
+gt_image = gt_image[..., gt_image.shape[-1] // 2:]
if tb_writer and (idx < 5):
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
if iteration == testing_iterations[0]:
@@ -203,6 +234,7 @@ if __name__ == "__main__":
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
parser.add_argument("--quiet", action="store_true")
+parser.add_argument('--disable_viewer', action='store_true', default=False)
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
args = parser.parse_args(sys.argv[1:])
@@ -214,7 +246,8 @@ if __name__ == "__main__":
safe_state(args.quiet)
# Start GUI server, configure and run training
-network_gui.init(args.ip, args.port)
+if not args.disable_viewer:
+network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)

utils/camera_utils.py

@@ -11,14 +11,31 @@
from scene.cameras import Camera
import numpy as np
-from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal
+from PIL import Image
+import cv2
WARNED = False
-def loadCam(args, id, cam_info, resolution_scale):
-orig_w, orig_h = cam_info.image.size
+def loadCam(args, id, cam_info, resolution_scale, is_test_dataset):
+image = Image.open(cam_info.image_path)
+if cam_info.depth_path != "":
+try:
+invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / float(2**16)
+except FileNotFoundError:
+print(f"Error: The depth file at path '{cam_info.depth_path}' was not found.")
+raise
+except IOError:
+print(f"Error: Unable to open the image file '{cam_info.depth_path}'. It may be corrupted or an unsupported format.")
+raise
+except Exception as e:
+print(f"An unexpected error occurred when trying to read depth at {cam_info.depth_path}: {e}")
+raise
+else:
+invdepthmap = None
+orig_w, orig_h = image.size
if args.resolution in [1, 2, 4, 8]:
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
else:  # should be a type that converts to float
@@ -34,28 +51,22 @@ def loadCam(args, id, cam_info, resolution_scale):
global_down = 1
else:
global_down = orig_w / args.resolution
scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
-resized_image_rgb = PILtoTorch(cam_info.image, resolution)
-gt_image = resized_image_rgb[:3, ...]
-loaded_mask = None
-if resized_image_rgb.shape[1] == 4:
-loaded_mask = resized_image_rgb[3:4, ...]
-return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
-FoVx=cam_info.FovX, FoVy=cam_info.FovY,
-image=gt_image, gt_alpha_mask=loaded_mask,
-image_name=cam_info.image_name, uid=id, data_device=args.data_device)
-def cameraList_from_camInfos(cam_infos, resolution_scale, args):
+return Camera(resolution, colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
+FoVx=cam_info.FovX, FoVy=cam_info.FovY, depth_params=cam_info.depth_params,
+image=image, invdepthmap=invdepthmap,
+image_name=cam_info.image_name, uid=id, data_device=args.data_device,
+train_test_exp=args.train_test_exp, is_test_dataset=is_test_dataset, is_test_view=cam_info.is_test)
+def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_test_dataset):
camera_list = []
for id, c in enumerate(cam_infos):
-camera_list.append(loadCam(args, id, c, resolution_scale))
+camera_list.append(loadCam(args, id, c, resolution_scale, is_test_dataset))
return camera_list
@@ -79,4 +90,4 @@ def camera_to_JSON(id, camera : Camera):
'fy' : fov2focal(camera.FovY, camera.height),
'fx' : fov2focal(camera.FovX, camera.width)
}
return camera_entry

utils/make_depth_scale.py

@@ -0,0 +1,94 @@
import numpy as np
import argparse
import cv2
from joblib import delayed, Parallel
import json
from read_write_model import *
def get_scales(key, cameras, images, points3d_ordered, args):
image_meta = images[key]
cam_intrinsic = cameras[image_meta.camera_id]
pts_idx = image_meta.point3D_ids  # use the already-fetched image record rather than the module-level dict
mask = pts_idx >= 0
mask *= pts_idx < len(points3d_ordered)
pts_idx = pts_idx[mask]
valid_xys = image_meta.xys[mask]
if len(pts_idx) > 0:
pts = points3d_ordered[pts_idx]
else:
pts = np.array([0, 0, 0])
R = qvec2rotmat(image_meta.qvec)
pts = np.dot(pts, R.T) + image_meta.tvec
invcolmapdepth = 1. / pts[..., 2]
n_remove = len(image_meta.name.split('.')[-1]) + 1
invmonodepthmap = cv2.imread(f"{args.depths_dir}/{image_meta.name[:-n_remove]}.png", cv2.IMREAD_UNCHANGED)
if invmonodepthmap is None:
return None
if invmonodepthmap.ndim != 2:
invmonodepthmap = invmonodepthmap[..., 0]
invmonodepthmap = invmonodepthmap.astype(np.float32) / (2**16)
s = invmonodepthmap.shape[0] / cam_intrinsic.height
maps = (valid_xys * s).astype(np.float32)
valid = (
(maps[..., 0] >= 0) *
(maps[..., 1] >= 0) *
(maps[..., 0] < cam_intrinsic.width * s) *
(maps[..., 1] < cam_intrinsic.height * s) * (invcolmapdepth > 0))
if valid.sum() > 10 and (invcolmapdepth.max() - invcolmapdepth.min()) > 1e-3:
maps = maps[valid, :]
invcolmapdepth = invcolmapdepth[valid]
invmonodepth = cv2.remap(invmonodepthmap, maps[..., 0], maps[..., 1], interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)[..., 0]
## Median / dev
t_colmap = np.median(invcolmapdepth)
s_colmap = np.mean(np.abs(invcolmapdepth - t_colmap))
t_mono = np.median(invmonodepth)
s_mono = np.mean(np.abs(invmonodepth - t_mono))
scale = s_colmap / s_mono
offset = t_colmap - t_mono * scale
else:
scale = 0
offset = 0
return {"image_name": image_meta.name[:-n_remove], "scale": scale, "offset": offset}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', default="../data/big_gaussians/standalone_chunks/campus")
parser.add_argument('--depths_dir', default="../data/big_gaussians/standalone_chunks/campus/depths_any")
parser.add_argument('--model_type', default="bin")
args = parser.parse_args()
cam_intrinsics, images_metas, points3d = read_model(os.path.join(args.base_dir, "sparse", "0"), ext=f".{args.model_type}")
pts_indices = np.array([points3d[key].id for key in points3d])
pts_xyzs = np.array([points3d[key].xyz for key in points3d])
points3d_ordered = np.zeros([pts_indices.max()+1, 3])
points3d_ordered[pts_indices] = pts_xyzs
# depth_param_list = [get_scales(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas]
depth_param_list = Parallel(n_jobs=-1, backend="threading")(
delayed(get_scales)(key, cam_intrinsics, images_metas, points3d_ordered, args) for key in images_metas
)
depth_params = {
depth_param["image_name"]: {"scale": depth_param["scale"], "offset": depth_param["offset"]}
for depth_param in depth_param_list if depth_param != None
}
with open(f"{args.base_dir}/sparse/0/depth_params.json", "w") as f:
json.dump(depth_params, f, indent=2)
print(0)