mirror of https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-04-02 20:30:40 +00:00

adding post pruning based on depth maps projection

This commit is contained in:
parent 8a70a8cd6f
commit e46855ff0e
4 .gitignore (vendored)
@@ -5,4 +5,6 @@ build
diff_rasterization/diff_rast.egg-info
diff_rasterization/dist
tensorboard_3d
screenshots
depth_maps
projected_depth
|
arguments/__init__.py

@@ -47,6 +47,7 @@ class ParamGroup:
class ModelParams(ParamGroup):
    def __init__(self, parser, sentinel=False):
        self.sh_degree = 3
        self.depth_prune = False
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
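ParamGroup turns each attribute into a command-line argument (boolean attributes become store_true flags), so the new depth_prune field should surface as --depth_prune. A minimal hedged sketch of that wiring, not the actual arguments/__init__.py source:

    # Sketch: how the new field is expected to reach the CLI, assuming the
    # upstream ParamGroup convention of mapping bool attributes to flags.
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Training script parameters")
    parser.add_argument("--depth_prune", action="store_true",
                        help="run depth-map-based pruning after the final iteration")
    args = parser.parse_args(["--depth_prune"])
    assert args.depth_prune  # gates the pruning pass in train.py via dataset.depth_prune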
environment.yml

@@ -1,4 +1,4 @@
-name: gaussian_splatting
+name: depth-pruned-gaussian-splatting
channels:
  - pytorch
  - conda-forge
scene/__init__.py

@@ -12,11 +12,13 @@
import os
import random
import json
import torch
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
from utils.depth_map import generate_depth_map_from_tensor

class Scene:
@@ -90,4 +92,22 @@ class Scene:
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        return self.test_cameras[scale]

    def initialize_depth_maps(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_type = 'DPT_Large'  # or 'MiDaS_small' for faster processing
        depth_model = torch.hub.load('intel-isl/MiDaS', model_type, trust_repo=True)
        depth_model.to(device)
        depth_model.eval()

        transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
        if model_type in ['DPT_Large', 'DPT_Hybrid']:
            transform = transforms.dpt_transform
        else:
            transform = transforms.small_transform

        all_cameras = self.getTrainCameras().copy() + self.getTestCameras().copy()
        for cam in all_cameras:
            cam.depth_map = generate_depth_map_from_tensor(cam.original_image, depth_model, transform, device, cam.image_name)
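Running MiDaS over every camera repeats on each training run; a hedged sketch of a small disk cache (the helper name and cache_dir are illustrative, not part of the commit):

    # Hypothetical helper: persist per-image depth maps so repeat runs skip
    # MiDaS inference entirely.
    import os
    import torch

    def load_or_compute_depth(cam, depth_model, transform, device, cache_dir="depth_cache"):
        os.makedirs(cache_dir, exist_ok=True)
        path = os.path.join(cache_dir, f"{cam.image_name}.pt")
        if os.path.exists(path):
            return torch.load(path)  # reuse the cached [H, W] tensor
        depth = generate_depth_map_from_tensor(
            cam.original_image, depth_model, transform, device, cam.image_name)
        torch.save(depth, path)
        return depth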
scene/cameras.py

@@ -56,6 +56,42 @@ class Camera(nn.Module):
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]
    def get_depths_from_depth_map(self, pixel_coords):
        """
        Fetches depth values from the depth map at the given pixel coordinates
        using bilinear interpolation. Out-of-bounds coordinates are clamped to
        the image border.
        """
        x_coords = pixel_coords[:, 0]
        y_coords = pixel_coords[:, 1]

        H, W = self.depth_map.shape

        # Integer corner coordinates, clamped to the image bounds
        x0 = torch.clamp(x_coords.floor().long(), 0, W - 1)
        x1 = torch.clamp(x0 + 1, 0, W - 1)
        y0 = torch.clamp(y_coords.floor().long(), 0, H - 1)
        y1 = torch.clamp(y0 + 1, 0, H - 1)

        # Fractional parts give the interpolation weights
        x_frac = x_coords - x_coords.floor()
        y_frac = y_coords - y_coords.floor()

        device = pixel_coords.device
        depth_map = self.depth_map.to(device)

        Ia = depth_map[y0, x0]  # top-left corner
        Ib = depth_map[y1, x0]  # bottom-left corner
        Ic = depth_map[y0, x1]  # top-right corner
        Id = depth_map[y1, x1]  # bottom-right corner

        depths = Ia * (1 - x_frac) * (1 - y_frac) + \
                 Ib * (1 - x_frac) * y_frac + \
                 Ic * x_frac * (1 - y_frac) + \
                 Id * x_frac * y_frac

        return depths
class MiniCam:
    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        self.image_width = width
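For reference, the same bilinear lookup in get_depths_from_depth_map can be expressed with torch.nn.functional.grid_sample; a hedged equivalence sketch (the helper name is illustrative), using padding_mode='border' to match the clamp-to-edge behavior above:

    # Hypothetical alternative to the manual bilinear interpolation.
    # align_corners=True matches the (W - 1)/(H - 1) pixel normalization.
    import torch
    import torch.nn.functional as F

    def sample_depths(depth_map, pixel_coords):
        H, W = depth_map.shape
        grid = torch.empty_like(pixel_coords)
        grid[:, 0] = 2.0 * pixel_coords[:, 0] / (W - 1) - 1.0  # x to [-1, 1]
        grid[:, 1] = 2.0 * pixel_coords[:, 1] / (H - 1) - 1.0  # y to [-1, 1]
        out = F.grid_sample(depth_map.view(1, 1, H, W),
                            grid.view(1, -1, 1, 2),
                            mode='bilinear', padding_mode='border',
                            align_corners=True)
        return out.view(-1)  # [N]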
scene/gaussian_model.py

@@ -14,6 +14,8 @@ import numpy as np
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
from torch import nn
import os
import cv2
from scene.cameras import Camera
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH
@@ -56,6 +58,7 @@ class GaussianModel:
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self.average_depth = torch.empty(0)
        self.setup_functions()

    def capture(self):
@@ -104,6 +107,10 @@ class GaussianModel:
    def get_xyz(self):
        return self._xyz

    @property
    def num_gaussians(self):
        return self._xyz.shape[0]

    @property
    def get_features(self):
        features_dc = self._features_dc
@@ -289,7 +296,7 @@ class GaussianModel:
        return optimizable_tensors

    def prune_points(self, mask):
-       valid_points_mask = ~mask
+       valid_points_mask = ~mask  # mask was True for points to be pruned
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
@@ -404,4 +411,120 @@ class GaussianModel:

    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1
    def project_gaussians(self, camera: Camera):
        """
        Projects Gaussians into image space using the camera's full (view x
        projection) transform and returns their pixel coordinates.
        """
        positions = self.get_xyz
        full_proj_transform = camera.full_proj_transform  # [4, 4], row-vector convention

        # Homogeneous transform [x, y, z, 1] @ M, split into xyz and w parts;
        # w depends on the point's depth, so it is computed per point.
        positions_h = positions @ full_proj_transform[:3, :3] + full_proj_transform[3, :3]  # [N, 3]
        w = positions @ full_proj_transform[:3, 3] + full_proj_transform[3, 3]  # [N]
        positions_clip = positions_h / w.unsqueeze(1).clamp(min=1e-7)

        x_ndc = positions_clip[:, 0]
        y_ndc = positions_clip[:, 1]

        # NDC in [-1, 1] -> pixel coordinates
        x = (x_ndc + 1.0) * 0.5 * (camera.image_width - 1)
        y = (y_ndc + 1.0) * 0.5 * (camera.image_height - 1)

        pixel_coords = torch.stack([x, y], dim=1)  # [N, 2]

        return pixel_coords
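Points behind the camera or outside the frustum still land on clamped pixel coordinates and thus fetch border depths; a hedged sketch of a visibility mask that could gate the depth lookup (the helper name is illustrative, not in the commit):

    # Hypothetical visibility check: True where a Gaussian projects inside the
    # image and in front of the camera. Mirrors project_gaussians' conventions.
    def visible_mask(positions, camera):
        M = camera.full_proj_transform
        w = positions @ M[:3, 3] + M[3, 3]  # positive in front of the camera
        ndc = (positions @ M[:3, :3] + M[3, :3]) / w.unsqueeze(1).clamp(min=1e-7)
        in_image = (ndc[:, :2].abs() <= 1.0).all(dim=1)
        return (w > 0) & in_image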
    def compute_average_depths(self, cameras, threshold=0.8):
        """
        Computes, for each Gaussian, the average of the depth-map values it
        projects onto across the given cameras.
        """
        N = self.get_xyz.shape[0]
        device = self.get_xyz.device
        total_depth = torch.zeros(N, device=device)
        count = torch.zeros(N, device=device)

        for camera in cameras:
            pixel_coords = self.project_gaussians(camera)            # [N, 2]
            depths = camera.get_depths_from_depth_map(pixel_coords)  # [N]

            valid_mask = ~torch.isnan(depths)

            if valid_mask.any():
                # Saturate near readings: anything above the threshold counts as fully near (1.0)
                adjusted_depths = torch.where(depths > threshold, torch.tensor(1.0, device=device), depths)

                total_depth[valid_mask] += adjusted_depths[valid_mask]
                count[valid_mask] += 1

        # Gaussians never observed keep the default value 1.0 (near), so they are not pruned
        average_depth = torch.full((N,), 1.0, device=device)
        nonzero_mask = count > 0
        average_depth[nonzero_mask] = total_depth[nonzero_mask] / count[nonzero_mask]
        self.average_depth = average_depth
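A tiny worked check of the saturation step (values are illustrative; in the normalized maps 1.0 is nearest and 0.0 is furthest):

    import torch
    depths = torch.tensor([0.10, 0.50, 0.90])
    adjusted = torch.where(depths > 0.8, torch.tensor(1.0), depths)
    print(adjusted)  # tensor([0.1000, 0.5000, 1.0000]); near readings count as fully near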
    def filter_gaussians_by_depth(self, depth_threshold):
        """
        Prunes Gaussians whose average projected depth-map value is below the
        threshold, i.e. those that consistently land on far regions.
        """
        prune_depth_mask = ~torch.isnan(self.average_depth) & (self.average_depth < depth_threshold)

        if prune_depth_mask.sum() == 0:
            print("No Gaussians to prune.")
            return

        if prune_depth_mask.sum() == self.get_xyz.shape[0]:
            raise ValueError("All Gaussians would be pruned. Please adjust the depth threshold.")

        self.prune_points(prune_depth_mask)
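A hedged usage sketch mirroring the train.py wiring further down: average the depths over all cameras, then drop Gaussians whose average falls below 0.3:

    cams = scene.getTrainCameras().copy() + scene.getTestCameras().copy()
    gaussians.compute_average_depths(cams)    # default threshold=0.8
    gaussians.filter_gaussians_by_depth(0.3)  # prunes far-projecting Gaussians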
    def visualize_gaussians_on_image(self, camera: Camera, image_name: str):
        """
        Visualize Gaussians projected onto a camera's image, colored by average depth.
        """
        if not hasattr(self, 'average_depth') or self.average_depth is None:
            raise ValueError("average_depth is not computed. Please run compute_average_depths(cameras) first.")

        pixel_coords = self.project_gaussians(camera)
        pixel_coords = pixel_coords.detach().cpu().numpy()
        avg_depths = self.average_depth.detach().cpu().numpy()
        valid_avg_depths = avg_depths[~np.isnan(avg_depths)]
        if valid_avg_depths.size == 0:
            raise ValueError("No valid average depths found.")

        depth_min = valid_avg_depths.min()
        depth_max = valid_avg_depths.max()
        depth_range = depth_max - depth_min

        avg_depths_clean = np.copy(avg_depths)
        avg_depths_clean[np.isnan(avg_depths)] = depth_max

        normalized_depths = (avg_depths_clean - depth_min) / (depth_range + 1e-8)
        normalized_depths = np.clip(normalized_depths, 0, 1)

        # Far Gaussians (low depth values) are drawn bright, near ones dark
        colors = (1 - normalized_depths) * 255
        colors = colors.astype(np.uint8)

        image_tensor = camera.original_image.detach().cpu()  # Shape: [C, H, W]
        image_np = image_tensor.permute(1, 2, 0).numpy()
        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)

        if image_np.shape[2] == 1:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
        elif image_np.shape[2] == 3:
            image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        else:
            raise ValueError(f"Unexpected number of channels in image: {image_np.shape[2]}")

        image_height, image_width = image_np.shape[:2]
        image_with_projections = image_np.copy()

        x_coords = np.round(pixel_coords[:, 0]).astype(np.int32)
        y_coords = np.round(pixel_coords[:, 1]).astype(np.int32)

        for x, y, color in zip(x_coords, y_coords, colors):
            if 0 <= x < image_width and 0 <= y < image_height:
                cv2.circle(image_with_projections, (x, y), radius=2, color=(int(color), int(color), int(color)), thickness=-1)

        # Save
        os.makedirs('projected_depth', exist_ok=True)
        output_path = f'projected_depth/{image_name}_gaussian_depth.png'
        cv2.imwrite(output_path, image_with_projections)
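Grayscale dots can be hard to read against the RGB frame; a hedged variant using an OpenCV colormap (a sketch reusing the locals of visualize_gaussians_on_image, not part of the commit):

    # Hypothetical recoloring of the projected points with a perceptual colormap.
    import cv2
    import numpy as np

    colors_u8 = ((1 - normalized_depths) * 255).astype(np.uint8)
    colored = cv2.applyColorMap(colors_u8.reshape(-1, 1), cv2.COLORMAP_JET)  # [N, 1, 3] BGR
    for (x, y), bgr in zip(np.round(pixel_coords).astype(np.int32), colored[:, 0, :]):
        if 0 <= x < image_width and 0 <= y < image_height:
            cv2.circle(image_with_projections, (x, y), radius=2,
                       color=tuple(int(c) for c in bgr), thickness=-1)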
15 train.py
@@ -107,8 +107,21 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
            if (iteration in saving_iterations):
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                if iteration == opt.iterations and dataset.depth_prune:
                    print("Initializing depth maps")
                    scene.initialize_depth_maps()
                    print("Pruning gaussians by depth")
                    num_gaussians = gaussians.num_gaussians
                    all_cams = scene.getTrainCameras().copy() + scene.getTestCameras().copy()
                    gaussians.compute_average_depths(all_cams)
                    for cam in all_cams:
                        gaussians.visualize_gaussians_on_image(cam, f'before_prune_{cam.image_name}')
                    gaussians.filter_gaussians_by_depth(0.3)
                    print(f"Pruned {num_gaussians - gaussians.num_gaussians} gaussians")
                    scene.save(iteration)
                    return
                scene.save(iteration)

            # Densification
            if iteration < opt.densify_until_iter:
                # Keep track of max radii in image-space for pruning
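The pruning threshold is hard-coded to 0.3 above; a hedged sketch of exposing it through ModelParams (the attribute name is illustrative), relying on ParamGroup mapping float attributes to CLI options the same way depth_prune becomes a flag:

    # In ModelParams.__init__ (hypothetical attribute; would surface as
    # --depth_prune_threshold alongside --depth_prune):
    self.depth_prune_threshold = 0.3

    # In train.py, replacing the literal:
    gaussians.filter_gaussians_by_depth(dataset.depth_prune_threshold)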
51 utils/depth_map.py (new file)
@@ -0,0 +1,51 @@
import numpy as np
import torch
import cv2
import os

def generate_depth_map_from_tensor(image_tensor, model, transform, device, image_name):
    """
    Generate a depth map with values between 0 and 1 (where 0 is the furthest)
    from a PyTorch image tensor of shape [C, H, W].
    """
    image_np = image_tensor.cpu().numpy().transpose(1, 2, 0)
    img = np.clip(image_np * 255, 0, 255).astype(np.uint8)

    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = model(input_batch)
        prediction_resized = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=(img.shape[0], img.shape[1]),
            mode='bicubic',
            align_corners=False,
        ).squeeze(0).squeeze(0)

    depth_map = prediction_resized.cpu()

    # MiDaS predicts relative inverse depth (larger means closer), so min-max
    # normalization maps the furthest point to 0 and the closest to 1.
    depth_min = depth_map.min()
    depth_max = depth_map.max()
    depth_map_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)

    visualize_depth_map(depth_map_normalized, image_name)

    return depth_map_normalized  # [H, W]
def visualize_depth_map(depth_map_normalized, image_name, percentile=0.2):
    """
    Visualize and save the depth map, coloring the closest points green and
    the most distant points red.
    """
    depth_map_vis = (depth_map_normalized * 255).numpy().astype("uint8")

    depth_map_color = cv2.cvtColor(depth_map_vis, cv2.COLOR_GRAY2BGR)
    depth_array = depth_map_normalized.numpy()

    # Values near 1 are closest, values near 0 are furthest
    high_depth_mask = depth_array > 1 - percentile
    low_depth_mask = depth_array < percentile
    depth_map_color[high_depth_mask] = [0, 255, 0]  # green: closest
    depth_map_color[low_depth_mask] = [0, 0, 255]   # red: furthest

    os.makedirs('depth_maps', exist_ok=True)
    cv2.imwrite(f"depth_maps/{image_name}_depth_map.png", depth_map_color)
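A hedged smoke test for the new module (a random image stands in for a real camera frame; requires network access for torch.hub and writes the visualization to depth_maps/):

    import torch
    from utils.depth_map import generate_depth_map_from_tensor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True).to(device).eval()
    transform = torch.hub.load("intel-isl/MiDaS", "transforms").small_transform

    image = torch.rand(3, 384, 512)  # stand-in for camera.original_image in [0, 1]
    depth = generate_depth_map_from_tensor(image, model, transform, device, "smoke_test")
    assert depth.shape == (384, 512)
    assert 0.0 <= depth.min() <= depth.max() <= 1.0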