mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-21 15:57:45 +00:00
licensed branch
This commit is contained in:
parent
54c035f783
commit
f2d84c7ad5
@ -1 +1 @@
|
|||||||
Subproject commit d8856f60c5384cc1975439193bb627d77d917d77
|
Subproject commit 4c4a95365597a78b105792794db70f09a1ece938
|
@ -5,7 +5,6 @@ channels:
|
|||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- cudatoolkit=11.6
|
- cudatoolkit=11.6
|
||||||
- plyfile
|
|
||||||
- python=3.7.13
|
- python=3.7.13
|
||||||
- pip=22.3.1
|
- pip=22.3.1
|
||||||
- pytorch=1.12.1
|
- pytorch=1.12.1
|
||||||
@ -18,3 +17,4 @@ dependencies:
|
|||||||
- submodules/fused-ssim
|
- submodules/fused-ssim
|
||||||
- opencv-python
|
- opencv-python
|
||||||
- joblib
|
- joblib
|
||||||
|
- meshio
|
||||||
|
@ -19,7 +19,7 @@ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from plyfile import PlyData, PlyElement
|
import meshio
|
||||||
from utils.sh_utils import SH2RGB
|
from utils.sh_utils import SH2RGB
|
||||||
from scene.gaussian_model import BasicPointCloud
|
from scene.gaussian_model import BasicPointCloud
|
||||||
|
|
||||||
@ -118,29 +118,31 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_fold
|
|||||||
return cam_infos
|
return cam_infos
|
||||||
|
|
||||||
def fetchPly(path):
|
def fetchPly(path):
|
||||||
plydata = PlyData.read(path)
|
vertices = meshio.read(path)
|
||||||
vertices = plydata['vertex']
|
positions = vertices.points
|
||||||
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
|
colors = np.vstack(
|
||||||
colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
|
[
|
||||||
normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
|
vertices.point_data['red'].astype(np.uint8),
|
||||||
|
vertices.point_data['green'].astype(np.uint8),
|
||||||
|
vertices.point_data['blue'].astype(np.uint8)
|
||||||
|
]).T / 255.0
|
||||||
|
|
||||||
|
normals = np.vstack([vertices.point_data['nx'], vertices.point_data['ny'], vertices.point_data['nz']]).T
|
||||||
return BasicPointCloud(points=positions, colors=colors, normals=normals)
|
return BasicPointCloud(points=positions, colors=colors, normals=normals)
|
||||||
|
|
||||||
def storePly(path, xyz, rgb):
|
def storePly(path, xyz, rgb):
|
||||||
# Define the dtype for the structured array
|
|
||||||
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
|
|
||||||
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
|
|
||||||
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
|
|
||||||
|
|
||||||
normals = np.zeros_like(xyz)
|
normals = np.zeros_like(xyz)
|
||||||
|
point_data = {
|
||||||
|
"red": rgb[..., 0].astype(np.uint8),
|
||||||
|
"green": rgb[..., 1].astype(np.uint8),
|
||||||
|
"blue": rgb[..., 2].astype(np.uint8),
|
||||||
|
"nx": normals[..., 0].astype(np.float32),
|
||||||
|
"ny": normals[..., 1].astype(np.float32),
|
||||||
|
"nz": normals[..., 2].astype(np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
elements = np.empty(xyz.shape[0], dtype=dtype)
|
mesh = meshio.Mesh(points=xyz.astype(np.float32), point_data=point_data, cells=[])
|
||||||
attributes = np.concatenate((xyz, normals, rgb), axis=1)
|
meshio.write(path, mesh)
|
||||||
elements[:] = list(map(tuple, attributes))
|
|
||||||
|
|
||||||
# Create the PlyData object and write to file
|
|
||||||
vertex_element = PlyElement.describe(elements, 'vertex')
|
|
||||||
ply_data = PlyData([vertex_element])
|
|
||||||
ply_data.write(path)
|
|
||||||
|
|
||||||
def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
|
def readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8):
|
||||||
try:
|
try:
|
||||||
|
@ -16,7 +16,7 @@ from torch import nn
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from utils.system_utils import mkdir_p
|
from utils.system_utils import mkdir_p
|
||||||
from plyfile import PlyData, PlyElement
|
import meshio
|
||||||
from utils.sh_utils import RGB2SH
|
from utils.sh_utils import RGB2SH
|
||||||
from simple_knn._C import distCUDA2
|
from simple_knn._C import distCUDA2
|
||||||
from utils.graphics_utils import BasicPointCloud
|
from utils.graphics_utils import BasicPointCloud
|
||||||
@ -246,14 +246,15 @@ class GaussianModel:
|
|||||||
opacities = self._opacity.detach().cpu().numpy()
|
opacities = self._opacity.detach().cpu().numpy()
|
||||||
scale = self._scaling.detach().cpu().numpy()
|
scale = self._scaling.detach().cpu().numpy()
|
||||||
rotation = self._rotation.detach().cpu().numpy()
|
rotation = self._rotation.detach().cpu().numpy()
|
||||||
|
|
||||||
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
|
point_data = {}
|
||||||
|
attribs_no_pos = [attribute for attribute in self.construct_list_of_attributes() if attribute not in ["x", "y", "z"]]
|
||||||
elements = np.empty(xyz.shape[0], dtype=dtype_full)
|
values = np.concatenate((normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
||||||
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
for index, attribute in enumerate(attribs_no_pos):
|
||||||
elements[:] = list(map(tuple, attributes))
|
point_data[attribute] = values[..., index]
|
||||||
el = PlyElement.describe(elements, 'vertex')
|
|
||||||
PlyData([el]).write(path)
|
mesh = meshio.Mesh(points=xyz.astype(np.float32), point_data=point_data, cells=[])
|
||||||
|
meshio.write(path, mesh)
|
||||||
|
|
||||||
def reset_opacity(self):
|
def reset_opacity(self):
|
||||||
opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
|
opacities_new = self.inverse_opacity_activation(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
|
||||||
@ -261,7 +262,6 @@ class GaussianModel:
|
|||||||
self._opacity = optimizable_tensors["opacity"]
|
self._opacity = optimizable_tensors["opacity"]
|
||||||
|
|
||||||
def load_ply(self, path, use_train_test_exp = False):
|
def load_ply(self, path, use_train_test_exp = False):
|
||||||
plydata = PlyData.read(path)
|
|
||||||
if use_train_test_exp:
|
if use_train_test_exp:
|
||||||
exposure_file = os.path.join(os.path.dirname(path), os.pardir, os.pardir, "exposure.json")
|
exposure_file = os.path.join(os.path.dirname(path), os.pardir, os.pardir, "exposure.json")
|
||||||
if os.path.exists(exposure_file):
|
if os.path.exists(exposure_file):
|
||||||
@ -273,36 +273,36 @@ class GaussianModel:
|
|||||||
print(f"No exposure to be loaded at {exposure_file}")
|
print(f"No exposure to be loaded at {exposure_file}")
|
||||||
self.pretrained_exposures = None
|
self.pretrained_exposures = None
|
||||||
|
|
||||||
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
vertices = meshio.read(path)
|
||||||
np.asarray(plydata.elements[0]["y"]),
|
xyz = vertices.points
|
||||||
np.asarray(plydata.elements[0]["z"])), axis=1)
|
point_data = vertices.point_data
|
||||||
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
opacities = np.asarray(point_data["opacity"])[..., np.newaxis]
|
||||||
|
|
||||||
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
||||||
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
features_dc[:, 0, 0] = np.asarray(point_data["f_dc_0"])
|
||||||
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
|
features_dc[:, 1, 0] = np.asarray(point_data["f_dc_1"])
|
||||||
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
|
features_dc[:, 2, 0] = np.asarray(point_data["f_dc_2"])
|
||||||
|
|
||||||
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
|
extra_f_names = [p for p in point_data if p.startswith("f_rest_")]
|
||||||
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
|
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
|
||||||
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
|
|
||||||
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
||||||
for idx, attr_name in enumerate(extra_f_names):
|
for idx, attr_name in enumerate(extra_f_names):
|
||||||
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
features_extra[:, idx] = np.asarray(point_data[attr_name])
|
||||||
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
|
|
||||||
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
|
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
|
||||||
|
|
||||||
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
scale_names = [p for p in point_data if p.startswith("scale_")]
|
||||||
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
|
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
|
||||||
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
||||||
for idx, attr_name in enumerate(scale_names):
|
for idx, attr_name in enumerate(scale_names):
|
||||||
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
scales[:, idx] = np.asarray(point_data[attr_name])
|
||||||
|
|
||||||
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
|
rot_names = [p for p in point_data if p.startswith("rot")]
|
||||||
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
|
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
|
||||||
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
||||||
for idx, attr_name in enumerate(rot_names):
|
for idx, attr_name in enumerate(rot_names):
|
||||||
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
rots[:, idx] = np.asarray(point_data[attr_name])
|
||||||
|
|
||||||
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
|
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
|
||||||
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
||||||
|
Loading…
Reference in New Issue
Block a user