added NeRF synthetic datasets training using depth regularization

alanvinx 2024-09-17 10:15:46 +02:00
parent f689f41bd0
commit 9375aa7664
5 changed files with 49 additions and 34 deletions


@@ -40,10 +40,10 @@ This research was funded by the ERC Advanced grant FUNGRAPH No 788065. The autho
 We have limited resources for maintaining and updating the code. However, we have added a few new features since the original release that are inspired by some of the excellent work many other researchers have been doing on 3DGS. We will be adding other features within the ability of our resources.
 Update of August 2024:
-We have added/corrected the following features: [Depth regularization](#depth-regularization) for training, [anti aliasing](#anti-aliasing) and [exposure compensation](#exposure-compensation). We have enhanced the SIBR real time viewer by correcting bugs and adding features in the [Top View](#sibr:-top-view) that allows visualization of input and user cameras. Please note that it is currently not possible to use depth regularization with the training speed acceleration since they use different rasterizer versions.
+We have added/corrected the following features: [Depth regularization](#depth-regularization) for training, [anti-aliasing](#anti-aliasing) and [exposure compensation](#exposure-compensation). We have enhanced the SIBR real time viewer by correcting bugs and adding features in the [Top View](#sibr-top-view) that allows visualization of input and user cameras. Please note that it is currently not possible to use depth regularization with the training speed acceleration since they use different rasterizer versions.
 Update of Spring 2024:
-Orange Labs has kindly added [OpenXR support](#openXR-support) for VR viewing.
+Orange Labs has kindly added [OpenXR support](#openxr-support) for VR viewing.
 ## Step-by-step Tutorial
@@ -497,7 +497,9 @@ python convert.py -s <location> --skip_matching [--resize] #If not resizing, Ima
 Two preprocessing steps are required to enable depth regularization when training a scene:
 To have better reconstructed scenes we use depth maps as priors during optimization with each input images. It works best on untextured parts ex: roads and can remove floaters. Several papers have used similar ideas to improve various aspects of 3DGS; (e.g. [DepthRegularizedGS](https://robot0321.github.io/DepthRegGS/index.html), [SparseGS](https://formycat.github.io/SparseGS-Real-Time-360-Sparse-View-Synthesis-using-Gaussian-Splatting/), [DNGaussian](https://fictionarry.github.io/DNGaussian/)). The depth regularization we integrated is that used in our [Hierarchical 3DGS](https://repo-sam.inria.fr/fungraph/hierarchical-3d-gaussians/) paper, but applied to the original 3DGS; for some scenes (e.g., the DeepBlending scenes) it improves quality significantly; for others it either makes a small difference or can even be worse. For details statistics please see here: [Stats for depth regularization](results.md).
-1. Depth maps should be generated for each input images, to this effect we suggest using [Depth anything v2](https://github.com/DepthAnything/Depth-Anything-V2?tab=readme-ov-file#usage).
+When training on a synthetic dataset, depth maps can be produced and they do not require further processing to be used in our method. For real world datasets please do the following:
+1. Get depth maps for each input images, to this effect we suggest using [Depth anything v2](https://github.com/DepthAnything/Depth-Anything-V2?tab=readme-ov-file#usage).
 2. Generate a `depth_params.json` file using:
 ```
 python utils/make_depth_scale.py --base_dir <path to colmap> --depths_dir <path to generated depths>
@@ -512,9 +514,10 @@ To compensate for exposure changes in the different input images we optimize an
 ```
 Again, other excellent papers have used similar ideas e.g. [NeRF-W](https://nerf-w.github.io/), [URF](https://urban-radiance-fields.github.io/).
-### Anti aliasing
+### Anti-aliasing
 We added the EWA Filter from [Mip Splatting](https://niujinshuchong.github.io/mip-splatting/) in our codebase to remove aliasing. It is disabled by default but you can enable it by adding `--antialiasing` when training on a scene using `train.py` or rendering using `render.py`. Antialiasing can be toggled in the SIBR viewer, it is disabled by default but you should enable it when viewing a scene trained using `--antialiasing`.
 ![aa](/assets/aa_onoff.gif)
+*this scene was trained using `--antialiasing`*.
 ### SIBR: Top view
 > `Views > Top view`


@@ -44,7 +44,7 @@ class Scene:
             scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
         elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
             print("Found transforms_train.json file, assuming Blender data set!")
-            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
+            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.depths, args.eval)
         else:
             assert False, "Could not recognize scene type!"
@@ -70,9 +70,9 @@ class Scene:
         for resolution_scale in resolution_scales:
             print("Loading Training Cameras")
-            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, False)
+            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, False)
             print("Loading Test Cameras")
-            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, True)
+            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, True)
         if self.loaded_iter:
             self.gaussians.load_ply(os.path.join(self.model_path,


@@ -59,24 +59,24 @@ class Camera(nn.Module):
         self.invdepthmap = None
         self.depth_reliable = False
-        if invdepthmap is not None and depth_params is not None and depth_params["scale"] > 0:
-            invdepthmapScaled = invdepthmap * depth_params["scale"] + depth_params["offset"]
-            invdepthmapScaled = cv2.resize(invdepthmapScaled, resolution)
-            invdepthmapScaled[invdepthmapScaled < 0] = 0
-            if invdepthmapScaled.ndim != 2:
-                invdepthmapScaled = invdepthmapScaled[..., 0]
-            self.invdepthmap = torch.from_numpy(invdepthmapScaled[None]).to(self.data_device)
-            if self.alpha_mask is not None:
-                self.depth_mask = self.alpha_mask.clone()
-            else:
-                self.depth_mask = torch.ones_like(self.invdepthmap > 0)
-            if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]:
-                self.depth_mask *= 0
-            else:
-                self.depth_reliable = True
+        if invdepthmap is not None:
+            self.depth_mask = torch.ones_like(self.alpha_mask)
+            self.invdepthmap = cv2.resize(invdepthmap, resolution)
+            self.invdepthmap[self.invdepthmap < 0] = 0
+            self.depth_reliable = True
+            if depth_params is not None:
+                if depth_params["scale"] < 0.2 * depth_params["med_scale"] or depth_params["scale"] > 5 * depth_params["med_scale"]:
+                    self.depth_reliable = False
+                    self.depth_mask *= 0
+                if depth_params["scale"] > 0:
+                    self.invdepthmap = self.invdepthmap * depth_params["scale"] + depth_params["offset"]
+            if self.invdepthmap.ndim != 2:
+                self.invdepthmap = self.invdepthmap[..., 0]
+            self.invdepthmap = torch.from_numpy(self.invdepthmap[None]).to(self.data_device)
         self.zfar = 100.0
         self.znear = 0.01


@@ -43,6 +43,7 @@ class SceneInfo(NamedTuple):
     test_cameras: list
     nerf_normalization: dict
     ply_path: str
+    is_nerf_synthetic: bool
 def getNerfppNorm(cam_info):
     def get_center_and_diag(cam_centers):
@@ -220,10 +221,11 @@ def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
                            train_cameras=train_cam_infos,
                            test_cameras=test_cam_infos,
                            nerf_normalization=nerf_normalization,
-                           ply_path=ply_path)
+                           ply_path=ply_path,
+                           is_nerf_synthetic=False)
     return scene_info
-def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
+def readCamerasFromTransforms(path, transformsfile, depths_folder, white_background, is_test, extension=".png"):
     cam_infos = []
     with open(os.path.join(path, transformsfile)) as json_file:
@@ -260,16 +262,21 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
             FovY = fovy
             FovX = fovx
+            depth_path = os.path.join(depths_folder, f"{image_name}_depth_0002.png") if depths_folder != "" else ""
             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
-                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
+                            image_path=image_path, image_name=image_name,
+                            width=image.size[0], height=image.size[1], depth_path=depth_path, depth_params=None, is_test=is_test))
     return cam_infos
-def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
+def readNerfSyntheticInfo(path, white_background, depths, eval, extension=".png"):
+    depths_folder=os.path.join(path, depths) if depths != "" else ""
     print("Reading Training Transforms")
-    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
+    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", depths_folder, white_background, False, extension)
     print("Reading Test Transforms")
-    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
+    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", depths_folder, white_background, True, extension)
     if not eval:
         train_cam_infos.extend(test_cam_infos)
@@ -298,7 +305,8 @@ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
                            train_cameras=train_cam_infos,
                            test_cameras=test_cam_infos,
                            nerf_normalization=nerf_normalization,
-                           ply_path=ply_path)
+                           ply_path=ply_path,
+                           is_nerf_synthetic=True)
     return scene_info
 sceneLoadTypeCallbacks = {
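
Review note: the depth-path handling added to the Blender reader relies on two conventions that are easy to miss in the hunks above. An empty string means "no depth map", and the file name is built from the image name plus the hard-coded suffix `_depth_0002.png`. The sketch below restates both as small functions. The function names, the `/data/lego` path, and the `r_0` image name are illustrative assumptions; only the two expressions themselves come from the diff.

```python
import os


def resolve_depths_folder(path, depths):
    """Hypothetical helper: same expression as in readNerfSyntheticInfo."""
    return os.path.join(path, depths) if depths != "" else ""


def synthetic_depth_path(depths_folder, image_name):
    """Hypothetical helper: same expression as in readCamerasFromTransforms."""
    if depths_folder == "":
        return ""  # empty string is the "no depth map" sentinel tested in loadCam
    return os.path.join(depths_folder, f"{image_name}_depth_0002.png")


# With a depths sub-folder configured (paths are examples only).
folder = resolve_depths_folder("/data/lego", "depths")
with_depth = synthetic_depth_path(folder, "r_0")

# Without one, both helpers propagate the empty-string sentinel.
without_depth = synthetic_depth_path(resolve_depths_folder("/data/lego", ""), "r_0")
```

Since the suffix is fixed, depth maps named differently would need renaming, or the suffix would need to become a parameter.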


@@ -17,12 +17,16 @@ import cv2
 WARNED = False
-def loadCam(args, id, cam_info, resolution_scale, is_test_dataset):
+def loadCam(args, id, cam_info, resolution_scale, is_nerf_synthetic, is_test_dataset):
     image = Image.open(cam_info.image_path)
     if cam_info.depth_path != "":
         try:
-            invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / float(2**16)
+            if is_nerf_synthetic:
+                invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / 512
+            else:
+                invdepthmap = cv2.imread(cam_info.depth_path, -1).astype(np.float32) / float(2**16)
         except FileNotFoundError:
             print(f"Error: The depth file at path '{cam_info.depth_path}' was not found.")
             raise
@@ -62,11 +66,11 @@ def loadCam(args, id, cam_info, resolution_scale, is_test_dataset):
                   image_name=cam_info.image_name, uid=id, data_device=args.data_device,
                   train_test_exp=args.train_test_exp, is_test_dataset=is_test_dataset, is_test_view=cam_info.is_test)
-def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_test_dataset):
+def cameraList_from_camInfos(cam_infos, resolution_scale, args, is_nerf_synthetic, is_test_dataset):
     camera_list = []
     for id, c in enumerate(cam_infos):
-        camera_list.append(loadCam(args, id, c, resolution_scale, is_test_dataset))
+        camera_list.append(loadCam(args, id, c, resolution_scale, is_nerf_synthetic, is_test_dataset))
     return camera_list
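
Review note: the only difference between the two decoding branches in `loadCam` is the divisor (512 for NeRF synthetic depth maps, 2**16 otherwise). The sketch below isolates that choice in a hypothetical `decode_invdepth` helper that takes the array already returned by `cv2.imread(path, -1)`. It also makes one assumption explicit: `cv2.imread` returns `None` for a file it cannot read instead of raising, so the `except FileNotFoundError` branch visible in the hunk may not fire for a missing file. Whether handlers outside the visible context cover that case cannot be told from this diff. The constant names and file names are illustrative.

```python
import numpy as np

NERF_SYNTHETIC_DIVISOR = 512.0      # divisor this commit uses for NeRF synthetic depth PNGs
DEFAULT_DIVISOR = float(2**16)      # divisor used for all other depth PNGs


def decode_invdepth(raw, path, is_nerf_synthetic):
    """Hypothetical helper: `raw` is whatever cv2.imread(path, -1) returned."""
    if raw is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"The depth file at path '{path}' was not found.")
    divisor = NERF_SYNTHETIC_DIVISOR if is_nerf_synthetic else DEFAULT_DIVISOR
    return raw.astype(np.float32) / divisor


# Stand-in for a decoded 16-bit PNG (values chosen for illustration).
raw = np.array([[0, 512, 65535]], dtype=np.uint16)
synthetic = decode_invdepth(raw, "example_depth_0002.png", True)
real = decode_invdepth(raw, "example.png", False)
```

Checking for `None` before calling `.astype` turns a missing file into the error the existing message describes, rather than an `AttributeError` on `NoneType`.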