licensed branch

This commit is contained in:
alanvinx 2024-10-30 14:58:17 +01:00
parent 54c035f783
commit f2d84c7ad5
4 changed files with 48 additions and 46 deletions

@ -1 +1 @@
Subproject commit d8856f60c5384cc1975439193bb627d77d917d77 Subproject commit 4c4a95365597a78b105792794db70f09a1ece938

View File

@ -5,7 +5,6 @@ channels:
- defaults - defaults
dependencies: dependencies:
- cudatoolkit=11.6 - cudatoolkit=11.6
- plyfile
- python=3.7.13 - python=3.7.13
- pip=22.3.1 - pip=22.3.1
- pytorch=1.12.1 - pytorch=1.12.1
@ -18,3 +17,4 @@ dependencies:
- submodules/fused-ssim - submodules/fused-ssim
- opencv-python - opencv-python
- joblib - joblib
- meshio

View File

@ -19,7 +19,7 @@ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np import numpy as np
import json import json
from pathlib import Path from pathlib import Path
from plyfile import PlyData, PlyElement import meshio
from utils.sh_utils import SH2RGB from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud from scene.gaussian_model import BasicPointCloud
@ -118,29 +118,31 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_fold
return cam_infos return cam_infos
def fetchPly(path): def fetchPly(path):
plydata = PlyData.read(path) vertices = meshio.read(path)
vertices = plydata['vertex'] positions = vertices.points
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T colors = np.vstack(
colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 [
normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T vertices.point_data['red'].astype(np.uint8),
vertices.point_data['green'].astype(np.uint8),
vertices.point_data['blue'].astype(np.uint8)
]).T / 255.0
normals = np.vstack([vertices.point_data['nx'], vertices.point_data['ny'], vertices.point_data['nz']]).T
return BasicPointCloud(points=positions, colors=colors, normals=normals) return BasicPointCloud(points=positions, colors=colors, normals=normals)
def storePly(path, xyz, rgb): def storePly(path, xyz, rgb):
# Define the dtype for the structured array
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
normals = np.zeros_like(xyz) normals = np.zeros_like(xyz)
point_data = {
"red": rgb[..., 0].astype(np.uint8),
"green": rgb[..., 1].astype(np.uint8),
"blue": rgb[..., 2].astype(np.uint8),
"nx": normals[..., 0].astype(np.float32),
"ny": normals[..., 1].astype(np.float32),
"nz": normals[..., 2].astype(np.float32),
}
elements = np.empty(xyz.shape[0], dtype=dtype) mesh = meshio.Mesh(points=xyz.astype(np.float32), point_data=point_data, cells=[])
attributes = np.concatenate((xyz, normals, rgb), axis=1) meshio.write(path, mesh)
elements[:] = list(map(tuple, attributes))
# Create the PlyData object and write to file
vertex_element = PlyElement.describe(elements, 'vertex')
ply_data = PlyData([vertex_element])
ply_data.write(path)
def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8): def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
try: try:

View File

@ -16,7 +16,7 @@ from torch import nn
import os import os
import json import json
from utils.system_utils import mkdir_p from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement import meshio
from utils.sh_utils import RGB2SH from utils.sh_utils import RGB2SH
from simple_knn._C import distCUDA2 from simple_knn._C import distCUDA2
from utils.graphics_utils import BasicPointCloud from utils.graphics_utils import BasicPointCloud
@ -246,14 +246,15 @@ class GaussianModel:
opacities = self._opacity.detach().cpu().numpy() opacities = self._opacity.detach().cpu().numpy()
scale = self._scaling.detach().cpu().numpy() scale = self._scaling.detach().cpu().numpy()
rotation = self._rotation.detach().cpu().numpy() rotation = self._rotation.detach().cpu().numpy()
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] point_data = {}
attribs_no_pos = [attribute for attribute in self.construct_list_of_attributes() if attribute not in ["x", "y", "z"]]
elements = np.empty(xyz.shape[0], dtype=dtype_full) values = np.concatenate((normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) for index, attribute in enumerate(attribs_no_pos):
elements[:] = list(map(tuple, attributes)) point_data[attribute] = values[..., index]
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path) mesh = meshio.Mesh(points=xyz.astype(np.float32), point_data=point_data, cells=[])
meshio.write(path, mesh)
def reset_opacity(self): def reset_opacity(self):
opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
@ -261,7 +262,6 @@ class GaussianModel:
self._opacity = optimizable_tensors["opacity"] self._opacity = optimizable_tensors["opacity"]
def load_ply(self, path, use_train_test_exp = False): def load_ply(self, path, use_train_test_exp = False):
plydata = PlyData.read(path)
if use_train_test_exp: if use_train_test_exp:
exposure_file = os.path.join(os.path.dirname(path), os.pardir, os.pardir, "exposure.json") exposure_file = os.path.join(os.path.dirname(path), os.pardir, os.pardir, "exposure.json")
if os.path.exists(exposure_file): if os.path.exists(exposure_file):
@ -273,36 +273,36 @@ class GaussianModel:
print(f"No exposure to be loaded at {exposure_file}") print(f"No exposure to be loaded at {exposure_file}")
self.pretrained_exposures = None self.pretrained_exposures = None
xyz = np.stack((np.asarray(plydata.elements[0]["x"]), vertices = meshio.read(path)
np.asarray(plydata.elements[0]["y"]), xyz = vertices.points
np.asarray(plydata.elements[0]["z"])), axis=1) point_data = vertices.point_data
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] opacities = np.asarray(point_data["opacity"])[..., np.newaxis]
features_dc = np.zeros((xyz.shape[0], 3, 1)) features_dc = np.zeros((xyz.shape[0], 3, 1))
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) features_dc[:, 0, 0] = np.asarray(point_data["f_dc_0"])
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) features_dc[:, 1, 0] = np.asarray(point_data["f_dc_1"])
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) features_dc[:, 2, 0] = np.asarray(point_data["f_dc_2"])
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] extra_f_names = [p for p in point_data if p.startswith("f_rest_")]
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
for idx, attr_name in enumerate(extra_f_names): for idx, attr_name in enumerate(extra_f_names):
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) features_extra[:, idx] = np.asarray(point_data[attr_name])
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] scale_names = [p for p in point_data if p.startswith("scale_")]
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
scales = np.zeros((xyz.shape[0], len(scale_names))) scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names): for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) scales[:, idx] = np.asarray(point_data[attr_name])
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] rot_names = [p for p in point_data if p.startswith("rot")]
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
rots = np.zeros((xyz.shape[0], len(rot_names))) rots = np.zeros((xyz.shape[0], len(rot_names)))
for idx, attr_name in enumerate(rot_names): for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) rots[:, idx] = np.asarray(point_data[attr_name])
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))