Add post-training pruning based on depth map projection

luizgbraga 2024-10-16 12:54:38 +02:00
parent 8a70a8cd6f
commit e46855ff0e
8 changed files with 252 additions and 6 deletions

.gitignore

@@ -5,4 +5,6 @@ build
 diff_rasterization/diff_rast.egg-info
 diff_rasterization/dist
 tensorboard_3d
-screenshots
+screenshots
+depth_maps
+projected_depth

arguments/__init__.py

@@ -47,6 +47,7 @@ class ParamGroup:
 class ModelParams(ParamGroup):
     def __init__(self, parser, sentinel=False):
         self.sh_degree = 3
+        self.depth_prune = False
         self._source_path = ""
         self._model_path = ""
         self._images = "images"
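
ParamGroup reflects each attribute into a matching argparse option (booleans that default to False become store_true flags in the stock codebase), so the new setting should be reachable from the command line without further wiring, e.g.:

python train.py -s <dataset> --depth_prune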

environment.yml

@@ -1,4 +1,4 @@
-name: gaussian_splatting
+name: depth-pruned-gaussian-splatting
 channels:
   - pytorch
   - conda-forge

scene/__init__.py

@@ -12,11 +12,13 @@
 import os
 import random
 import json
+import torch
 from utils.system_utils import searchForMaxIteration
 from scene.dataset_readers import sceneLoadTypeCallbacks
 from scene.gaussian_model import GaussianModel
 from arguments import ModelParams
 from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+from utils.depth_map import generate_depth_map_from_tensor

 class Scene:
@@ -90,4 +92,22 @@ class Scene:
         return self.train_cameras[scale]

     def getTestCameras(self, scale=1.0):
-        return self.test_cameras[scale]
+        return self.test_cameras[scale]
+
+    def initialize_depth_maps(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model_type = 'DPT_Large'  # or 'MiDaS_small' for faster processing
+        depth_model = torch.hub.load('intel-isl/MiDaS', model_type, trust_repo=True)
+        depth_model.to(device)
+        depth_model.eval()
+        if model_type in ['DPT_Large', 'DPT_Hybrid']:
+            transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
+            transform = transforms.dpt_transform
+        else:
+            transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
+            transform = transforms.small_transform
+        all_cameras = self.getTrainCameras().copy() + self.getTestCameras().copy()
+        for cam in all_cameras:
+            cam.depth_map = generate_depth_map_from_tensor(cam.original_image, depth_model, transform, device, cam.image_name)
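
Running MiDaS over every training and test camera is the slow part of this step. A minimal caching sketch, assuming the helper name and the reuse of the ignored depth_maps/ directory (neither is part of the commit):

import os
import torch
from utils.depth_map import generate_depth_map_from_tensor

def load_or_generate_depth_map(cam, depth_model, transform, device, cache_dir='depth_maps'):
    # Run MiDaS once per image; later runs reload the saved tensor instead.
    path = os.path.join(cache_dir, f'{cam.image_name}_depth.pt')
    if os.path.exists(path):
        return torch.load(path)
    depth = generate_depth_map_from_tensor(cam.original_image, depth_model, transform, device, cam.image_name)
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(depth, path)
    return depth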

scene/cameras.py

@@ -56,6 +56,42 @@ class Camera(nn.Module):
         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         self.camera_center = self.world_view_transform.inverse()[3, :3]

+    def get_depths_from_depth_map(self, pixel_coords):
+        """
+        Fetches depth values from the depth map at the given pixel coordinates using bilinear interpolation.
+        Out-of-bounds coordinates are clamped to the image border rather than marked invalid.
+        """
+        x_coords = pixel_coords[:, 0]
+        y_coords = pixel_coords[:, 1]
+        H, W = self.depth_map.shape
+
+        # Integer coordinates of the four surrounding texels
+        x0 = torch.clamp(x_coords.floor().long(), 0, W - 1)
+        x1 = torch.clamp(x0 + 1, 0, W - 1)
+        y0 = torch.clamp(y_coords.floor().long(), 0, H - 1)
+        y1 = torch.clamp(y0 + 1, 0, H - 1)
+
+        # Fractional parts, used as the interpolation weights
+        x_frac = x_coords - x_coords.floor()
+        y_frac = y_coords - y_coords.floor()
+
+        device = pixel_coords.device
+        depth_map = self.depth_map.to(device)
+        Ia = depth_map[y0, x0]
+        Ib = depth_map[y1, x0]
+        Ic = depth_map[y0, x1]
+        Id = depth_map[y1, x1]
+
+        depths = Ia * (1 - x_frac) * (1 - y_frac) + \
+                 Ib * (1 - x_frac) * y_frac + \
+                 Ic * x_frac * (1 - y_frac) + \
+                 Id * x_frac * y_frac
+        return depths
+
 class MiniCam:
     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
         self.image_width = width
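
As a quick sanity check of the interpolation above (a standalone sketch of mine, not part of the commit), the same math extracted into a free function should return the mean of the four texels when queried at their common center:

import torch

def bilinear_depth(depth_map, pixel_coords):
    # Identical math to Camera.get_depths_from_depth_map, testable without a Camera.
    x, y = pixel_coords[:, 0], pixel_coords[:, 1]
    H, W = depth_map.shape
    x0 = torch.clamp(x.floor().long(), 0, W - 1)
    x1 = torch.clamp(x0 + 1, 0, W - 1)
    y0 = torch.clamp(y.floor().long(), 0, H - 1)
    y1 = torch.clamp(y0 + 1, 0, H - 1)
    fx, fy = x - x.floor(), y - y.floor()
    return (depth_map[y0, x0] * (1 - fx) * (1 - fy) + depth_map[y1, x0] * (1 - fx) * fy
            + depth_map[y0, x1] * fx * (1 - fy) + depth_map[y1, x1] * fx * fy)

dm = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(bilinear_depth(dm, torch.tensor([[0.5, 0.5]])))  # tensor([1.5000]), the mean of all four values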

scene/gaussian_model.py

@@ -14,6 +14,8 @@ import numpy as np
 from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
 from torch import nn
 import os
+import cv2
+from scene.cameras import Camera
 from utils.system_utils import mkdir_p
 from plyfile import PlyData, PlyElement
 from utils.sh_utils import RGB2SH
@@ -56,6 +58,7 @@ class GaussianModel:
         self.optimizer = None
         self.percent_dense = 0
         self.spatial_lr_scale = 0
+        self.average_depth = torch.empty(0)
         self.setup_functions()

     def capture(self):
@@ -104,6 +107,10 @@
     def get_xyz(self):
         return self._xyz

+    @property
+    def num_gaussians(self):
+        return self._xyz.shape[0]
+
     @property
     def get_features(self):
         features_dc = self._features_dc
@@ -289,7 +296,7 @@
         return optimizable_tensors

     def prune_points(self, mask):
-        valid_points_mask = ~mask
+        valid_points_mask = ~mask  # mask is True for the points to be pruned
         optimizable_tensors = self._prune_optimizer(valid_points_mask)

         self._xyz = optimizable_tensors["xyz"]
@@ -404,4 +411,120 @@ class GaussianModel:
     def add_densification_stats(self, viewspace_point_tensor, update_filter):
         self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
-        self.denom[update_filter] += 1
+        self.denom[update_filter] += 1
+
+    def project_gaussians(self, camera: Camera):
+        """
+        Projects Gaussians into image space using the camera's full projection transform.
+        """
+        positions = self.get_xyz
+        # full_proj_transform is stored transposed (row-vector convention), so clip = [p, 1] @ M
+        full_proj_transform = camera.full_proj_transform  # [4, 4]
+        positions_h = positions @ full_proj_transform[:3, :3] + full_proj_transform[3, :3]  # clip-space xyz, [N, 3]
+        w = positions @ full_proj_transform[:3, 3] + full_proj_transform[3, 3]  # per-point clip-space w, [N]
+        positions_clip = positions_h / w.unsqueeze(1)  # perspective divide to NDC (points behind the camera get w <= 0)
+        x_ndc = positions_clip[:, 0]
+        y_ndc = positions_clip[:, 1]
+        x = (x_ndc + 1.0) * 0.5 * (camera.image_width - 1)
+        y = (y_ndc + 1.0) * 0.5 * (camera.image_height - 1)
+        pixel_coords = torch.stack([x, y], dim=1)  # [N, 2]
+        return pixel_coords
+
+    def compute_average_depths(self, cameras, threshold=0.8):
+        """
+        Computes the average normalized inverse depth of each Gaussian over its
+        projections into the given cameras' depth maps.
+        """
+        N = self.get_xyz.shape[0]
+        device = self.get_xyz.device
+        total_depth = torch.zeros(N, device=device)
+        count = torch.zeros(N, device=device)
+        for camera in cameras:
+            pixel_coords = self.project_gaussians(camera)  # [N, 2]
+            depths = camera.get_depths_from_depth_map(pixel_coords)  # [N]
+            valid_mask = ~torch.isnan(depths)
+            if valid_mask.any():
+                # Saturate near-camera values so only far regions pull the average down
+                adjusted_depths = torch.where(depths > threshold, torch.tensor(1.0, device=device), depths)
+                total_depth[valid_mask] += adjusted_depths[valid_mask]
+                count[valid_mask] += 1
+        average_depth = torch.full((N,), 1.0, device=device)
+        nonzero_mask = count > 0
+        average_depth[nonzero_mask] = total_depth[nonzero_mask] / count[nonzero_mask]
+        self.average_depth = average_depth
+
+    def filter_gaussians_by_depth(self, depth_threshold):
+        """
+        Prunes Gaussians whose average projected depth falls below the threshold,
+        i.e. Gaussians that the depth maps consistently place far from the cameras.
+        """
+        prune_depth_mask = ~torch.isnan(self.average_depth) & (self.average_depth < depth_threshold)
+        if prune_depth_mask.sum() == 0:
+            print("No Gaussians to prune.")
+            return
+        if prune_depth_mask.sum() == self.get_xyz.shape[0]:
+            raise ValueError("All Gaussians would be pruned. Please adjust the depth threshold.")
+        self.prune_points(prune_depth_mask)
+
+    def visualize_gaussians_on_image(self, camera: Camera, image_name: str):
+        """
+        Visualizes Gaussians projected onto a camera's image, shaded by average depth.
+        """
+        if not hasattr(self, 'average_depth') or self.average_depth is None:
+            raise ValueError("average_depth is not computed. Please run compute_average_depths(cameras) first.")
+        pixel_coords = self.project_gaussians(camera)
+        pixel_coords = pixel_coords.detach().cpu().numpy()
+        avg_depths = self.average_depth.detach().cpu().numpy()
+        valid_avg_depths = avg_depths[~np.isnan(avg_depths)]
+        if valid_avg_depths.size == 0:
+            raise ValueError("No valid average depths found.")
+        depth_min = valid_avg_depths.min()
+        depth_max = valid_avg_depths.max()
+        depth_range = depth_max - depth_min
+        avg_depths_clean = np.copy(avg_depths)
+        avg_depths_clean[np.isnan(avg_depths)] = depth_max
+        normalized_depths = (avg_depths_clean - depth_min) / (depth_range + 1e-8)
+        normalized_depths = np.clip(normalized_depths, 0, 1)
+        colors = (1 - normalized_depths) * 255  # brighter dots = lower average depth (farther away)
+        colors = colors.astype(np.uint8)
+        image_tensor = camera.original_image.detach().cpu()  # [C, H, W]
+        image_np = image_tensor.permute(1, 2, 0).numpy()
+        image_np = np.clip(image_np * 255, 0, 255).astype(np.uint8)
+        if image_np.shape[2] == 1:
+            image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
+        elif image_np.shape[2] == 3:
+            image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+        else:
+            raise ValueError(f"Unexpected number of channels in image: {image_np.shape[2]}")
+        image_height, image_width = image_np.shape[:2]
+        image_with_projections = image_np.copy()
+        x_coords = np.round(pixel_coords[:, 0]).astype(np.int32)
+        y_coords = np.round(pixel_coords[:, 1]).astype(np.int32)
+        for x, y, color in zip(x_coords, y_coords, colors):
+            if 0 <= x < image_width and 0 <= y < image_height:
+                cv2.circle(image_with_projections, (x, y), radius=2, color=(int(color), int(color), int(color)), thickness=-1)
+        # Save the annotated image
+        os.makedirs('projected_depth', exist_ok=True)
+        output_path = f'projected_depth/{image_name}_gaussian_depth.png'
+        cv2.imwrite(output_path, image_with_projections)
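
The cutoff later passed to filter_gaussians_by_depth is hard to pick blind. A small helper sketch (my addition, not part of the commit) reports what each candidate threshold would prune, using the same mask that filter_gaussians_by_depth builds internally:

import torch

def pruning_report(gaussians, thresholds=(0.1, 0.2, 0.3, 0.4, 0.5)):
    # Assumes compute_average_depths(cameras) has already populated average_depth.
    for t in thresholds:
        mask = ~torch.isnan(gaussians.average_depth) & (gaussians.average_depth < t)
        print(f"threshold {t}: would prune {int(mask.sum())} of {gaussians.num_gaussians} Gaussians")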

train.py

@@ -107,8 +107,21 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
             training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
             if (iteration in saving_iterations):
                 print("\n[ITER {}] Saving Gaussians".format(iteration))
+                if iteration == opt.iterations and dataset.depth_prune:
+                    print("Initializing depth maps")
+                    scene.initialize_depth_maps()
+                    print("Pruning gaussians by depth")
+                    num_gaussians = gaussians.num_gaussians
+                    all_cams = scene.getTrainCameras().copy() + scene.getTestCameras().copy()
+                    gaussians.compute_average_depths(all_cams)
+                    for cam in all_cams:
+                        gaussians.visualize_gaussians_on_image(cam, f'before_prune_{cam.image_name}')
+                    gaussians.filter_gaussians_by_depth(0.3)
+                    print(f"Pruned {num_gaussians - gaussians.num_gaussians} gaussians")
+                    scene.save(iteration)
+                    return
                 scene.save(iteration)

             # Densification
             if iteration < opt.densify_until_iter:
                 # Keep track of max radii in image-space for pruning
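
The block above saves only the before_prune_* overlays. A symmetric check after pruning (a sketch of mine, not in the commit) would make the effect visible, replacing the single filter_gaussians_by_depth(0.3) call; since prune_points does not re-index average_depth, it has to be recomputed before visualizing again:

gaussians.filter_gaussians_by_depth(0.3)
gaussians.compute_average_depths(all_cams)  # average_depth still has its pre-prune length; recompute it
for cam in all_cams:
    gaussians.visualize_gaussians_on_image(cam, f'after_prune_{cam.image_name}')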

utils/depth_map.py

@@ -0,0 +1,51 @@
+import numpy as np
+import torch
+import cv2
+import os
+
+
+def generate_depth_map_from_tensor(image_tensor, model, transform, device, image_name):
+    """
+    Generate a depth map with values between 0 and 1 (where 0 is the farthest) from
+    a PyTorch tensor image. MiDaS predicts relative inverse depth, so larger values mean closer.
+    """
+    image_np = image_tensor.cpu().numpy().transpose(1, 2, 0)
+    img = np.clip(image_np * 255, 0, 255).astype(np.uint8)
+    input_batch = transform(img).to(device)
+    with torch.no_grad():
+        prediction = model(input_batch)
+        prediction_resized = torch.nn.functional.interpolate(
+            prediction.unsqueeze(1),
+            size=(img.shape[0], img.shape[1]),
+            mode='bicubic',
+            align_corners=False,
+        ).squeeze(0).squeeze(0)
+    depth_map = prediction_resized.cpu()
+    depth_min = depth_map.min()
+    depth_max = depth_map.max()
+    depth_map_normalized = (depth_map - depth_min) / (depth_max - depth_min + 1e-8)
+    visualize_depth_map(depth_map_normalized, image_name)
+    return depth_map_normalized  # [H, W]
+
+
+def visualize_depth_map(depth_map_normalized, image_name, percentile=0.2):
+    """
+    Visualize and save the depth map, coloring the most distant points red and
+    the closest points green.
+    """
+    depth_map_vis = (depth_map_normalized * 255).numpy().astype("uint8")
+    depth_map_color = cv2.cvtColor(depth_map_vis, cv2.COLOR_GRAY2BGR)
+    depth_array = depth_map_normalized.numpy()
+    high_depth_mask = depth_array > 1 - percentile  # closest points (high inverse depth)
+    low_depth_mask = depth_array < percentile       # farthest points (low inverse depth)
+    depth_map_color[high_depth_mask] = [0, 255, 0]  # green (BGR)
+    depth_map_color[low_depth_mask] = [0, 0, 255]   # red (BGR)
+    os.makedirs('depth_maps', exist_ok=True)
+    cv2.imwrite(f"depth_maps/{image_name}_depth_map.png", depth_map_color)