mirror of https://github.com/graphdeco-inria/gaussian-splatting, synced 2025-06-26 18:18:11 +00:00

Commit 48ceb9419b (parent 54c035f783): add aug_gs code

augment.py | 208 (new file)
@@ -0,0 +1,208 @@
import argparse
import numpy as np
from tqdm import tqdm
import cv2
import os

from utils.colmap_utils import *
from utils.bundle_utils import cluster_cameras
from utils.aug_utils import *


def augment(colmap_path, image_path, augment_path, camera_order, visibility_aware_culling, compare_center_patch):
    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
    sorted_keys = cluster_cameras(colmap_path, camera_order)

    points3d = []
    points3d_rgb = []
    for key in sorted(colmap_points3D.keys()):
        points3d.append(colmap_points3D[key].xyz)
        points3d_rgb.append(colmap_points3D[key].rgb)
    points3d = np.array(points3d)
    points3d_rgb = np.array(points3d_rgb)

    image_sample = cv2.imread(os.path.join(image_path, colmap_images[sorted_keys[0]].name))
    intrinsics_camera = compute_intrinsics(colmap_cameras, image_sample.shape[1], image_sample.shape[0])
    rotations_image, translations_image = compute_extrinsics(colmap_images)

    count = 0
    roots = {}
    pbar = tqdm(range(len(sorted_keys)))  # tqdm needs an iterable, not a bare length
    for view_idx in pbar:
        view = sorted_keys[view_idx]
        view_root, augmented_count = image_quadtree_augmentation(
            view,
            image_path,
            colmap_cameras,
            colmap_images,
            colmap_points3D,
            points3d,
            points3d_rgb,
            intrinsics_camera,
            rotations_image,
            translations_image,
            visibility_aware_culling=visibility_aware_culling,
        )
        count += augmented_count
        pbar.set_description(f"{count} points augmented")
        roots[view] = view_root

    for view1_idx in tqdm(range(len(sorted_keys))):
        for view2_idx in [view1_idx + 6,
                          view1_idx + 5,
                          view1_idx + 4,
                          view1_idx + 3,
                          view1_idx + 2,
                          view1_idx + 1,
                          view1_idx - 1,
                          view1_idx - 2,
                          view1_idx - 3,
                          view1_idx - 4,
                          view1_idx - 5,
                          view1_idx - 6]:
            if view2_idx > len(sorted_keys) - 1:
                view2_idx = view2_idx - len(sorted_keys)
            view1 = sorted_keys[view1_idx]
            view2 = sorted_keys[view2_idx]
            view1_root = roots[view1]
            view2_root = roots[view2]

            image_view2 = cv2.imread(os.path.join(image_path, colmap_images[view2].name))

            view1_sample_points_world, view1_sample_points_rgb = transform_sample_3d(view1,
                                                                                     view1_root,
                                                                                     colmap_images,
                                                                                     colmap_cameras,
                                                                                     intrinsics_camera,
                                                                                     rotations_image,
                                                                                     translations_image)
            view1_sample_points_view2, view1_sample_points_view2_depth = project_3d_to_2d(view1_sample_points_world,
                                                                                          intrinsics_camera[colmap_images[view2].camera_id],
                                                                                          np.concatenate((np.array(rotations_image[view2]),
                                                                                                          np.array(translations_image[view2]).reshape(3, 1)),
                                                                                                         axis=1))
            points3d_view2_pixcoord, points3d_view2_depth = project_3d_to_2d(points3d,
                                                                             intrinsics_camera[colmap_images[view2].camera_id],
                                                                             np.concatenate((np.array(rotations_image[view2]),
                                                                                             np.array(translations_image[view2]).reshape(3, 1)),
                                                                                            axis=1))

            matching_log = []
            for i in range(view1_sample_points_world.shape[0]):
                x, y = view1_sample_points_view2[i]
                corresponding_node_type = None
                error = None
                if (view1_sample_points_view2_depth[i] < 0) | \
                   (view1_sample_points_view2[i, 0] < 0) | \
                   (view1_sample_points_view2[i, 0] >= image_view2.shape[1]) | \
                   (view1_sample_points_view2[i, 1] < 0) | \
                   (view1_sample_points_view2[i, 1] >= image_view2.shape[0]) | \
                   np.isnan(view1_sample_points_view2[i]).any(axis=0):
                    corresponding_node_type = "culled"
                    matching_log.append([view2, corresponding_node_type, error])
                    continue
                else:
                    view2_corresponding_node = find_leaf_node(view2_root, x, y)
                    if view2_corresponding_node is None:
                        corresponding_node_type = "missing"
                    else:
                        if view2_corresponding_node.unoccupied:
                            if view2_corresponding_node.depth_interpolated:
                                error = np.linalg.norm(view1_sample_points_view2_depth[i] - view2_corresponding_node.sampled_point_depth)
                                if error < 0.2 * view2_corresponding_node.sampled_point_depth:
                                    if compare_center_patch:
                                        try:
                                            view1_sample_point_patch = image_view2[int(view1_sample_points_view2[i, 1])-1:
                                                                                   int(view1_sample_points_view2[i, 1])+2,
                                                                                   int(view1_sample_points_view2[i, 0])-1:
                                                                                   int(view1_sample_points_view2[i, 0])+2]
                                            view2_corresponding_node_patch = image_view2[int(view2_corresponding_node.sampled_point_uv[1])-1:
                                                                                         int(view2_corresponding_node.sampled_point_uv[1])+2,
                                                                                         int(view2_corresponding_node.sampled_point_uv[0])-1:
                                                                                         int(view2_corresponding_node.sampled_point_uv[0])+2]
                                            if compare_local_texture(view1_sample_point_patch, view2_corresponding_node_patch) > 0.5:
                                                corresponding_node_type = "sampledrejected"
                                            else:
                                                corresponding_node_type = "sampled"
                                        except IndexError:
                                            corresponding_node_type = "sampledrejected"
                                    else:
                                        corresponding_node_type = "sampled"
                                else:
                                    corresponding_node_type = "sampledrejected"
                            else:
                                corresponding_node_type = "depthrejected"
                        else:
                            corresponding_3d_depth = np.array(view2_corresponding_node.points3d_depths)
                            error = np.linalg.norm(view1_sample_points_view2_depth[i] - corresponding_3d_depth)

                            if np.min(error) < 0.2 * corresponding_3d_depth[np.argmin(error)]:
                                if compare_center_patch:
                                    try:
                                        point_3d_coord = points3d_view2_pixcoord[view2_corresponding_node.points3d_indices[np.argmin(error)]]
                                        point_3d_patch = image_view2[int(point_3d_coord[1])-1:
                                                                     int(point_3d_coord[1])+2,
                                                                     int(point_3d_coord[0])-1:
                                                                     int(point_3d_coord[0])+2]
                                        view1_sample_point_patch = image_view2[int(view1_sample_points_view2[i, 1])-1:
                                                                               int(view1_sample_points_view2[i, 1])+2,
                                                                               int(view1_sample_points_view2[i, 0])-1:
                                                                               int(view1_sample_points_view2[i, 0])+2]
                                        if compare_local_texture(view1_sample_point_patch, point_3d_patch) > 0.5:
                                            corresponding_node_type = "rejectedoccupied3d"
                                        else:
                                            corresponding_node_type = "occupied3d"
                                    except IndexError:
                                        corresponding_node_type = "rejectedoccupied3d"
                                else:
                                    corresponding_node_type = "occupied3d"
                            else:
                                corresponding_node_type = "rejectedoccupied3d"
                matching_log.append([view2, corresponding_node_type, error])

    sampled_points_total = []
    sampled_points_rgb_total = []
    sampled_points_uv_total = []
    sampled_points_neighbors_uv_total = []
    for view in sorted_keys:
        view_root = roots[view]
        leaf_nodes = []
        gather_leaf_nodes(view_root, leaf_nodes)
        for node in leaf_nodes:
            if node.unoccupied:
                if node.depth_interpolated:
                    if node.inference_count > 0:
                        if node.inference_count - node.rejection_count >= 1:
                            sampled_points_total.append([node.sampled_point_world])
                            sampled_points_rgb_total.append([node.sampled_point_rgb])
                            sampled_points_uv_total.append([node.sampled_point_uv])
                            sampled_points_neighbors_uv_total.append([node.sampled_point_neighbours_uv])
    print("total_Sampled_points: ", len(sampled_points_total))
    xyz = np.concatenate(sampled_points_total, axis=0)
    rgb = np.concatenate(sampled_points_rgb_total, axis=0)

    last_index = write_points3D_colmap_binary(colmap_points3D, xyz, rgb, augment_path)
    print("last_index: ", last_index)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--colmap_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--augment_path", type=str, required=True)
    parser.add_argument("--camera_order", type=str, required=True, default="covisibility")
    parser.add_argument("--visibility_aware_culling", action="store_true", default=False)
    parser.add_argument("--compare_center_patch", action="store_true", default=False)
    args = parser.parse_args()
    print("args.colmap_path", args.colmap_path)
    print("args.image_path", args.image_path)
    print("args.augment_path", args.augment_path)
    print("args.camera_order", args.camera_order)
    print("args.visibility_aware_culling", args.visibility_aware_culling)
    print("args.compare_center_patch", args.compare_center_patch)
    augment(args.colmap_path, args.image_path, args.augment_path, args.camera_order, args.visibility_aware_culling, args.compare_center_patch)
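A quick way to sanity-check the output of augment.py is to read the written binary back with the same COLMAP helpers the script builds on. This is only a sketch; the path below is hypothetical and stands in for whatever was passed as --augment_path.

from colmap.scripts.python.read_write_model import read_points3D_binary

# Hypothetical output path: the file written by write_points3D_colmap_binary above.
points3D_aug = read_points3D_binary("data/scene/sparse/0/points3D.bin")
print(len(points3D_aug), "points after augmentation")  # original + augmented count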
train.py | 115
@@ -12,7 +12,9 @@
 import os
 import torch
 from random import randint
-from utils.loss_utils import l1_loss, ssim
+from utils.loss_utils import l1_loss, ssim, InvDepthSmoothnessLoss, laplacian_pyramid_loss
+from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster
+from utils.aug_utils import *
 from gaussian_renderer import render, network_gui
 import sys
 from scene import Scene, GaussianModel
@@ -40,7 +42,20 @@ try:
 except:
     SPARSE_ADAM_AVAILABLE = False

-def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
+def training(dataset,
+             opt,
+             pipe,
+             testing_iterations,
+             saving_iterations,
+             checkpoint_iterations,
+             checkpoint,
+             debug_from,
+             camera_order,
+             bundle_training,
+             enable_ds_lap,
+             lambda_ds,
+             lambda_lap,
+             ):

     if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":
         sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
@@ -68,6 +83,11 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
     ema_loss_for_log = 0.0
     ema_Ll1depth_for_log = 0.0

+    if bundle_training:
+        sorted_keys = cluster_cameras(dataset.source_path, camera_order)
+        start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20)
+        n_interval = 0
+
     progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
     first_iter += 1
     for iteration in range(first_iter, opt.iterations + 1):
@@ -94,13 +114,20 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         if iteration % 1000 == 0:
             gaussians.oneupSHdegree()

         # Pick a random Camera
-        if not viewpoint_stack:
-            viewpoint_stack = scene.getTrainCameras().copy()
-            viewpoint_indices = list(range(len(viewpoint_stack)))
-        rand_idx = randint(0, len(viewpoint_indices) - 1)
-        viewpoint_cam = viewpoint_stack.pop(rand_idx)
-        vind = viewpoint_indices.pop(rand_idx)
+        if bundle_training and iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and cluster_sizes[n_interval] < 160 and iteration % 100 > 80:
+            densification_viewpoint_stack = scene.getTrainCameras().copy()
+            selected_idx = adaptive_cluster(start_indices[n_interval], sorted_keys, cluster_sizes[n_interval])
+            if selected_idx >= len(densification_viewpoint_stack):
+                selected_idx = selected_idx % len(densification_viewpoint_stack)
+            viewpoint_cam = densification_viewpoint_stack[selected_idx]
+        else:
+            if not viewpoint_stack:
+                viewpoint_stack = scene.getTrainCameras().copy()
+                viewpoint_indices = list(range(len(viewpoint_stack)))
+            rand_idx = randint(0, len(viewpoint_indices) - 1)
+            viewpoint_cam = viewpoint_stack.pop(rand_idx)
+            vind = viewpoint_indices.pop(rand_idx)

         # Render
         if (iteration - 1) == debug_from:
@@ -123,7 +150,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         else:
             ssim_value = ssim(image, gt_image)

-        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
+        if enable_ds_lap:
+            ds_loss = InvDepthSmoothnessLoss()(render_pkg["depth"], image)
+            lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
+        else:
+            ds_loss = 0.0
+            lap_loss = 0.0
+
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + lambda_ds * ds_loss + lambda_lap * lap_loss

         # Depth regularization
         Ll1depth_pure = 0.0
@@ -155,7 +189,24 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
                     progress_bar.close()

             # Log and save
-            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)
+            training_report(tb_writer,
+                            iteration,
+                            Ll1,
+                            loss,
+                            l1_loss,
+                            iter_start.elapsed_time(iter_end),
+                            testing_iterations,
+                            scene,
+                            render,
+                            (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp),
+                            dataset.train_test_exp,
+                            1.0 - ssim_value,
+                            ds_loss,
+                            lap_loss,
+                            lambda_ds,
+                            lambda_lap
+                            )
+
             if (iteration in saving_iterations):
                 print("\n[ITER {}] Saving Gaussians".format(iteration))
                 scene.save(iteration)
@@ -169,6 +220,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                     size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                     gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
+                    n_interval += 1

                 if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                     gaussians.reset_opacity()
@@ -211,10 +263,30 @@ def prepare_output_and_logger(args):
         print("Tensorboard not available: not logging progress")
     return tb_writer

-def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
+def training_report(tb_writer,
+                    iteration,
+                    Ll1,
+                    loss,
+                    l1_loss,
+                    elapsed,
+                    testing_iterations,
+                    scene : Scene,
+                    renderFunc,
+                    renderArgs,
+                    train_test_exp,
+                    ssim_loss,
+                    ds_loss,
+                    lap_loss,
+                    lambda_ds,
+                    lambda_lap):
     if tb_writer:
         tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
         tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ssim_loss', ssim_loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ds_loss', float(ds_loss), iteration)  # float() also covers the 0.0 case when enable_ds_lap is off
+        tb_writer.add_scalar('train_loss_patches/lap_loss', float(lap_loss), iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_ds', lambda_ds, iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_lap', lambda_lap, iteration)
         tb_writer.add_scalar('iter_time', elapsed, iteration)

     # Report test and samples of training set
@@ -267,6 +339,11 @@ if __name__ == "__main__":
     parser.add_argument('--disable_viewer', action='store_true', default=False)
     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
     parser.add_argument("--start_checkpoint", type=str, default = None)
+    parser.add_argument("--bundle_training", action='store_true', default=False)
+    parser.add_argument("--camera_order", type=str, default='covisibility')
+    parser.add_argument("--enable_ds_lap", action='store_true', default=False)
+    parser.add_argument("--lambda_ds", type=float, default=0.0)
+    parser.add_argument("--lambda_lap", type=float, default=0.0)
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)
@@ -279,7 +356,19 @@ if __name__ == "__main__":
     if not args.disable_viewer:
         network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
+    training(lp.extract(args),
+             op.extract(args),
+             pp.extract(args),
+             args.test_iterations,
+             args.save_iterations,
+             args.checkpoint_iterations,
+             args.start_checkpoint,
+             args.debug_from,
+             args.camera_order,       # keep the argument order aligned with the training() signature
+             args.bundle_training,
+             args.enable_ds_lap,
+             args.lambda_ds,
+             args.lambda_lap)

     # All done
     print("\nTraining complete.")
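The diff above wires two extra regularizers into the photometric loss. Below is a minimal, self-contained sketch of how those terms combine, with random tensors standing in for the rendered image, ground truth, and inverse depth; the lambda values are illustrative, not taken from the commit.

import torch
from utils.loss_utils import l1_loss, ssim, InvDepthSmoothnessLoss, laplacian_pyramid_loss

image = torch.rand(3, 128, 128)      # stand-in for the rendered image
gt_image = torch.rand(3, 128, 128)   # stand-in for the ground-truth image
inv_depth = torch.rand(1, 128, 128)  # stand-in for the rendered inverse depth

Ll1 = l1_loss(image, gt_image)
ssim_value = ssim(image, gt_image)
ds_loss = InvDepthSmoothnessLoss()(inv_depth, image)
lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))

lambda_dssim, lambda_ds, lambda_lap = 0.2, 0.2, 0.4  # illustrative weights
loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim_value) \
       + lambda_ds * ds_loss + lambda_lap * lap_loss
print(loss.item())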
utils/aug_utils.py | 387 (new file)
@@ -0,0 +1,387 @@
import os
import numpy as np
from sklearn.neighbors import BallTree
from colmap.scripts.python.read_write_model import *
import cv2
from tqdm import tqdm
from sklearn.decomposition import PCA
import rtree
from shapely.geometry import Point, box
from collections import defaultdict
from utils.colmap_utils import compute_extrinsics, get_colmap_data
from matplotlib import pyplot as plt


class Node:
    def __init__(self, x0, y0, width, height):
        self.x0 = x0
        self.y0 = y0
        self.width = width
        self.height = height
        self.children = []
        self.unoccupied = True
        self.sampled_point_uv = None
        self.sampled_point_rgb = None
        self.sampled_point_depth = None
        self.depth_interpolated = None
        self.sampled_point_neighbours_indices = None
        self.inference_count = 0
        self.rejection_count = 0
        self.sampled_point_world = None
        self.matching_log = {}
        self.points3d_indices = []
        self.points3d_depths = []
        self.points3d_rgb = []
        self.sampled_point_neighbours_uv = None

    def get_error(self, img):
        # Calculate the standard deviation of the region as an error metric
        region = img[self.y0:self.y0+self.height, self.x0:self.x0+self.width]
        return np.std(region)


def recursive_subdivide(node, threshold, min_pixel_size, img):
    if node.get_error(img) <= threshold:
        return

    w_1 = node.width // 2
    w_2 = node.width - w_1
    h_1 = node.height // 2
    h_2 = node.height - h_1

    if w_1 <= min_pixel_size or h_1 <= min_pixel_size:
        return

    # Create four children nodes
    x1 = Node(node.x0, node.y0, w_1, h_1)  # top left
    recursive_subdivide(x1, threshold, min_pixel_size, img)

    x2 = Node(node.x0, node.y0 + h_1, w_1, h_2)  # bottom left
    recursive_subdivide(x2, threshold, min_pixel_size, img)

    x3 = Node(node.x0 + w_1, node.y0, w_2, h_1)  # top right
    recursive_subdivide(x3, threshold, min_pixel_size, img)

    x4 = Node(node.x0 + w_1, node.y0 + h_1, w_2, h_2)  # bottom right
    recursive_subdivide(x4, threshold, min_pixel_size, img)

    node.children = [x1, x2, x3, x4]


def quadtree_decomposition(img, threshold, min_pixel_size):
    root = Node(0, 0, img.shape[1], img.shape[0])
    recursive_subdivide(root, threshold, min_pixel_size, img)
    return root


def gather_leaf_nodes(node, leaf_nodes):
    if not node.children:
        leaf_nodes.append(node)
    else:
        for child in node.children:
            gather_leaf_nodes(child, leaf_nodes)


def find_leaf_node(root, pixel_x, pixel_y):
    if not (root.x0 <= pixel_x < root.x0 + root.width and
            root.y0 <= pixel_y < root.y0 + root.height):
        return None  # the pixel lies outside the root node's bounds

    current = root
    while current.children:
        for child in current.children:
            if (child.x0 <= pixel_x < child.x0 + child.width and
                    child.y0 <= pixel_y < child.y0 + child.height):
                current = child
                break
        else:
            # no suitable child node was found
            return current

    return current


def draw_quadtree(node, ax):
    if not node.children:
        rect = plt.Rectangle((node.x0, node.y0), node.width, node.height, fill=False, edgecolor='red')
        ax.add_patch(rect)
    else:
        for child in node.children:
            draw_quadtree(child, ax)


def pixel_to_3d(point, depth, intrinsics):
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]

    x, y = point
    z = depth
    x_3d = (x - cx) * z / fx
    y_3d = (y - cy) * z / fy
    return np.array([x_3d, y_3d, z])


def project_3d_to_2d(points3d, intrinsics, extrinsics):
    """
    Project 3D points to 2D using camera intrinsics and extrinsics.

    Args:
    - points3d: 3D points as a numpy array of shape (N, 3).
    - intrinsics: Camera intrinsics matrix (3x3).
    - extrinsics: Camera extrinsics matrix (3x4).

    Returns:
    - 2D points in pixel coordinates, shape (N, 2), and their depths, shape (N, 1).
    """
    point_3d_homogeneous = np.hstack((points3d, np.ones((points3d.shape[0], 1))))  # Convert to homogeneous coordinates
    point_camera = extrinsics @ point_3d_homogeneous.T

    point_image_homogeneous = intrinsics @ point_camera
    point_2d = point_image_homogeneous[:2] / point_image_homogeneous[2]

    return point_2d.T, point_image_homogeneous[2:].T


def check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices):
    rect_index = rtree.index.Index()
    rectangles = [((node.x0, node.y0, node.x0 + node.width, node.y0 + node.height), i) for i, node in enumerate(leaf_nodes)]
    for rect, i in rectangles:
        rect_index.insert(i, rect)
    rectangle_indices = []
    for i, (x, y) in enumerate(points3d_pix):
        if near_culled_indices[i]:
            continue
        point = Point(x, y)
        matches = list(rect_index.intersection((x, y, x, y)))
        for match in matches:
            if box(*rectangles[match][0]).contains(point):
                rectangle_indices.append([match, i])

    for index, point3d_idx in rectangle_indices:
        leaf_nodes[index].unoccupied = False
        leaf_nodes[index].points3d_indices.append(point3d_idx)
        leaf_nodes[index].points3d_depths.append(points3d_depth[point3d_idx])
        leaf_nodes[index].points3d_rgb.append(points3d_rgb[point3d_idx])
    return leaf_nodes


def compute_normal_vector(points3d_cameraframe, points3d_depth, indices):
    points3d_cameraframe = np.concatenate((points3d_cameraframe[:, :2]*points3d_depth.reshape(-1, 1), points3d_depth.reshape(-1, 1)), axis=1)
    p1 = points3d_cameraframe[indices[:, 0]]
    p2 = points3d_cameraframe[indices[:, 1]]
    p3 = points3d_cameraframe[indices[:, 2]]
    v1 = p2 - p1
    v2 = p3 - p1

    normal = np.cross(v1, v2)
    a, b, c = normal[:, 0], normal[:, 1], normal[:, 2]
    d = -a * p1[:, 0] - b * p1[:, 1] - c * p1[:, 2]

    return a, b, c, d


def compute_depth(sample_points_cameraframe, a, b, c, d, cosine_threshold=0.01):
    direction_vectors = sample_points_cameraframe
    t = -d.reshape(-1, 1) / np.sum(np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1) * direction_vectors, axis=1).reshape(-1, 1)
    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
    cosine_similarity = np.abs(np.sum(normal_vectors * direction_vectors, axis=1)).reshape(-1, 1) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1) / np.linalg.norm(direction_vectors, axis=1).reshape(-1, 1)
    rejected_indices = cosine_similarity < cosine_threshold
    depth = t.reshape(-1, 1)*direction_vectors[:, 2:]
    return depth, rejected_indices.reshape(-1)


def find_perpendicular_triangle_indices(a, b, c, cosine_threshold=0.01):
    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
    camera_normal = np.array([0, 0, 1]).reshape(1, 3)
    cosine_similarity = np.dot(normal_vectors, camera_normal.T) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1)
    # chained comparisons do not apply element-wise to numpy arrays, so test |cos| explicitly
    cosine_culled_indices = np.abs(cosine_similarity) < cosine_threshold
    return cosine_culled_indices


def find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera, rotations_image, translations_image, depth_cutoff=2.):
    sampled_points = []
    for node in leaf_nodes:
        if node.unoccupied:
            sampled_points.append(node.sampled_point_uv)
    sampled_points = np.array(sampled_points)

    # Later we should store the 3D points' original indices.
    # This is because the 3D points cannot be masked during NN search.
    # near_culled_indices marks the 3D points that are outside the camera frustum or have negative depth.
    original_indices = np.arange(points3d_pix.shape[0])
    original_indices = original_indices[~near_culled_indices]

    # Only in-frustum 3D points are used for depth interpolation.
    points3d_pix = points3d_pix[~near_culled_indices]
    points3d_depth = points3d_depth[~near_culled_indices]
    # Construct a BallTree for nearest neighbor search.
    tree = BallTree(points3d_pix, leaf_size=40)
    # Query the nearest 3 neighbors for each sampled point.
    distances, indices = tree.query(sampled_points, k=3)
    # Invert the camera intrinsic matrix for generating the camera rays.
    inverse_intrinsics = np.linalg.inv(intrinsics_camera)
    # Generate the camera rays in the camera coordinate system.
    sampled_points_homogeneous = np.concatenate((sampled_points, np.ones((sampled_points.shape[0], 1))), axis=1)
    sampled_points_cameraframe = (inverse_intrinsics @ sampled_points_homogeneous.T).T
    # Transform the 3D points to the camera coordinate system with precomputed depths.
    points3d_pix_homogeneous = np.concatenate((points3d_pix, np.ones((points3d_pix.shape[0], 1))), axis=1)
    points3d_cameraframe = (inverse_intrinsics @ points3d_pix_homogeneous.T).T
    # Compute the normal vector and the distance of the triangle formed by the nearest 3 neighbors.
    a, b, c, d = compute_normal_vector(points3d_cameraframe, points3d_depth, indices)
    # Compute the depth of the sampled points from the normal vector.
    sampled_points_depth, cosine_culled_indices = compute_depth(sampled_points_cameraframe, a, b, c, d)
    # Reject points whose depth falls below the cutoff, are NaN, or whose supporting plane is nearly parallel to the camera ray.
    depth_rejected_indices = (sampled_points_depth < depth_cutoff).reshape(-1) | np.isnan(sampled_points_cameraframe).any(axis=1) | cosine_culled_indices.reshape(-1)
    # Transform the sampled points to the world coordinate system.
    sampled_points_world = transform_camera_to_world(np.concatenate((sampled_points_cameraframe[:, :2]*sampled_points_depth.reshape(-1, 1), sampled_points_depth.reshape(-1, 1)), axis=1), rotations_image, translations_image)
    sampled_points_world = sampled_points_world[:, :3]

    augmented_count = 0
    depth_index = 0
    for node in leaf_nodes:
        if node.unoccupied:
            if depth_rejected_indices[depth_index]:
                node.depth_interpolated = False
            else:
                node.sampled_point_depth = sampled_points_depth[depth_index]
                node.sampled_point_world = sampled_points_world[depth_index]
                node.depth_interpolated = True
                node.sampled_point_neighbours_indices = original_indices[indices[depth_index]]
                node.sampled_point_neighbours_uv = points3d_pix[indices[depth_index]]
                augmented_count += 1
            depth_index += 1

    return leaf_nodes, augmented_count


def transform_camera_to_world(sampled_points_cameraframe, rotations_image, translations_image):
    extrinsics_image = np.concatenate((np.array(rotations_image), np.array(translations_image).reshape(3, 1)), axis=1)
    extrinsics_4x4 = np.concatenate((extrinsics_image, np.array([0, 0, 0, 1]).reshape(1, 4)), axis=0)
    sampled_points_cameraframe_homogeneous = np.concatenate((sampled_points_cameraframe, np.ones((sampled_points_cameraframe.shape[0], 1))), axis=1)
    sampled_points_worldframe = np.linalg.inv(extrinsics_4x4) @ sampled_points_cameraframe_homogeneous.T

    return sampled_points_worldframe.T


def pixelwise_rgb_diff(point_a, point_b, threshold=0.3):
    return np.linalg.norm(point_a - point_b) / 255. > threshold


def image_quadtree_augmentation(image_key,
                                image_dir,
                                colmap_cameras,
                                colmap_images,
                                colmap_points3D,
                                points3d, points3d_rgb,
                                intrinsics_camera,
                                rotations_image,
                                translations_image,
                                quadtree_std_threshold=7,
                                quadtree_min_pixel_size=5,
                                visibility_aware_culling=False):
    image_name = colmap_images[image_key].name
    image_path = os.path.join(image_dir, image_name)
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image_grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    quadtree_root = quadtree_decomposition(image_grayscale, quadtree_std_threshold, quadtree_min_pixel_size)

    leaf_nodes = []
    gather_leaf_nodes(quadtree_root, leaf_nodes)
    # Project 3D points onto the image plane.
    points3d_pix, points3d_depth = project_3d_to_2d(points3d, intrinsics_camera[colmap_images[image_key].camera_id], np.concatenate((np.array(rotations_image[image_key]), np.array(translations_image[image_key]).reshape(3, 1)), axis=1))
    # Cull the points that are outside the camera frustum or have negative depth.
    near_culled_indices = (points3d_pix[:, 0] < 0) | (points3d_pix[:, 0] >= image.shape[1]) | (points3d_pix[:, 1] < 0) | (points3d_pix[:, 1] >= image.shape[0]) | (points3d_depth.reshape(-1) < 0) | np.isnan(points3d_pix).any(axis=1)
    points3d_pix_rgb_diff = np.zeros(points3d_pix.shape[0], dtype=np.bool_)
    if visibility_aware_culling:
        for i, point in enumerate(points3d_pix):
            if near_culled_indices[i]:
                continue
            point_a = image[int(point[1]), int(point[0])]
            point_b = points3d_rgb[i]
            points3d_pix_rgb_diff[i] = pixelwise_rgb_diff(point_a, point_b)
        near_culled_indices = near_culled_indices | points3d_pix_rgb_diff.reshape(-1)
    # Check every node in the quadtree that contains projected 3D points.
    leaf_nodes = check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices)
    # Sample points from the leaf nodes that are not occupied by projected 3D points.
    sampled_points = []
    sampled_points_rgb = []
    for node in leaf_nodes:
        if node.unoccupied:
            node.sampled_point_uv = np.array([node.x0, node.y0]) + np.random.sample(2) * np.array([node.width, node.height])
            node.sampled_point_rgb = image[int(node.sampled_point_uv[1]), int(node.sampled_point_uv[0])]
            sampled_points.append(node.sampled_point_uv)
            sampled_points_rgb.append(node.sampled_point_rgb)
    sampled_points = np.array(sampled_points)
    sampled_points_rgb = np.array(sampled_points_rgb)
    # Interpolate the depth of the sampled points from the nearest 3D points.
    leaf_nodes, augmented_count = find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera[colmap_images[image_key].camera_id], rotations_image[image_key], translations_image[image_key])

    return quadtree_root, augmented_count


def transform_sample_3d(image_key, root, colmap_images, colmap_cameras, intrinsics_camera, rotations_image, translations_image):
    # Gather leaf nodes and transform the sampled 2D points to 3D world coordinates.
    leaf_nodes = []
    gather_leaf_nodes(root, leaf_nodes)
    sample_points_imageframe = []
    sample_points_depth = []
    sample_points_rgb = []
    for node in leaf_nodes:
        if node.unoccupied:
            if node.depth_interpolated:
                sample_points_imageframe.append(node.sampled_point_uv)
                sample_points_depth.append(node.sampled_point_depth)
                sample_points_rgb.append(node.sampled_point_rgb)
    sample_points_imageframe = np.array(sample_points_imageframe)
    sample_points_depth = np.array(sample_points_depth)
    sample_points_rgb = np.array(sample_points_rgb)
    sample_points_cameraframe = (np.linalg.inv(intrinsics_camera[colmap_images[image_key].camera_id]) @ np.concatenate((sample_points_imageframe, np.ones((sample_points_imageframe.shape[0], 1))), axis=1).T).T
    sample_points_cameraframe = np.concatenate((sample_points_cameraframe[:, :2]*sample_points_depth.reshape(-1, 1), sample_points_depth.reshape(-1, 1)), axis=1)
    sample_points_worldframe = transform_camera_to_world(sample_points_cameraframe, rotations_image[image_key], translations_image[image_key])
    return sample_points_worldframe[:, :-1], sample_points_rgb


def write_points3D_colmap_binary(points3D, xyz, rgb, file_path):
    with open(file_path, "wb") as fid:
        write_next_bytes(fid, len(points3D) + len(xyz), "Q")
        for j, (_, pt) in enumerate(points3D.items()):
            write_next_bytes(fid, pt.id, "Q")
            write_next_bytes(fid, pt.xyz.tolist(), "ddd")
            write_next_bytes(fid, pt.rgb.tolist(), "BBB")
            write_next_bytes(fid, pt.error, "d")
            track_length = pt.image_ids.shape[0]
            write_next_bytes(fid, track_length, "Q")
            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
                write_next_bytes(fid, [image_id, point2D_id], "ii")
        id = points3D[max(points3D.keys())].id + 1
        print("starts from id=", id)
        for i in range(len(xyz)):
            write_next_bytes(fid, id + i, "Q")
            write_next_bytes(fid, xyz[i].tolist(), "ddd")
            write_next_bytes(fid, rgb[i].tolist(), "BBB")
            write_next_bytes(fid, 0.0, "d")
            write_next_bytes(fid, 1, "Q")
            write_next_bytes(fid, [0, 0], "ii")
    return i + id


def compare_local_texture(patch1, patch2):
    """
    Compare two patches using a Gaussian kernel.
    If the patch size is not 3x3, apply zero padding.
    """
    # Define a 3x3 Gaussian kernel (sigma=1.0)
    gaussian_kernel = np.array([
        [0.077847, 0.123317, 0.077847],
        [0.123317, 0.195346, 0.123317],
        [0.077847, 0.123317, 0.077847]
    ])

    # Check the patch size and apply zero padding if it is not 3x3.
    def pad_to_3x3(patch):
        if patch.shape[:2] != (3, 3):
            padded = np.zeros((3, 3, 3) if len(patch.shape) == 3 else (3, 3))
            h, w = patch.shape[:2]
            y_start = (3-h)//2
            x_start = (3-w)//2
            padded[y_start:y_start+h, x_start:x_start+w] = patch
            return padded
        return patch

    # Cast to float so the squared differences below do not overflow on uint8 images.
    patch1 = pad_to_3x3(patch1).astype(np.float64)
    patch2 = pad_to_3x3(patch2).astype(np.float64)

    if len(patch1.shape) == 3:  # RGB image
        # Apply weights to each channel and calculate the difference.
        weighted_diff = np.sum([
            np.sum(gaussian_kernel * (patch1[:, :, c] - patch2[:, :, c])**2)
            for c in range(3)
        ])
    else:  # Grayscale image
        weighted_diff = np.sum(gaussian_kernel * (patch1 - patch2)**2)

    return np.sqrt(weighted_diff)/255.
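To see the quadtree machinery above in isolation, here is a small sketch on a synthetic grayscale image. It assumes the repository root is on the Python path and the module's dependencies (rtree, shapely, the bundled COLMAP scripts) are installed; the threshold values mirror the defaults of image_quadtree_augmentation.

import numpy as np
from utils.aug_utils import quadtree_decomposition, gather_leaf_nodes, find_leaf_node

# Synthetic 256x256 image: flat background with one textured square.
img = np.full((256, 256), 128, dtype=np.uint8)
img[64:192, 64:192] = (np.random.rand(128, 128) * 255).astype(np.uint8)

root = quadtree_decomposition(img, threshold=7, min_pixel_size=5)

leaf_nodes = []
gather_leaf_nodes(root, leaf_nodes)
print(len(leaf_nodes), "leaf cells")  # the textured square splits into many small cells

node = find_leaf_node(root, pixel_x=100, pixel_y=100)
print(node.x0, node.y0, node.width, node.height)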
utils/bundle_utils.py | 156 (new file)
@@ -0,0 +1,156 @@
import random
import os
from colmap.scripts.python.read_write_model import *
import numpy as np
from collections import defaultdict
from utils.colmap_utils import compute_extrinsics, compute_intrinsics, get_colmap_data
import cv2
from sklearn.decomposition import PCA


def build_covisibility_matrix(images, points3D):
    n_images = len(images)
    covisibility_matrix = np.zeros((n_images, n_images))

    id_to_idx = {img_id: idx for idx, img_id in enumerate(images.keys())}
    idx_to_id = {idx: img_id for img_id, idx in id_to_idx.items()}

    for point3D in points3D.values():
        image_ids = point3D.image_ids
        for i in range(len(image_ids)):
            for j in range(i+1, len(image_ids)):
                id1, id2 = image_ids[i], image_ids[j]
                idx1, idx2 = id_to_idx[id1], id_to_idx[id2]
                covisibility_matrix[idx1, idx2] += 1
                covisibility_matrix[idx2, idx1] += 1

    return covisibility_matrix, id_to_idx, idx_to_id


def create_covisibility_graph(covisibility_matrix, idx_to_id):
    graph = defaultdict(dict)
    for i in range(len(covisibility_matrix)):
        for j in range(len(covisibility_matrix)):
            if i != j and covisibility_matrix[i, j] > 0:
                id1 = idx_to_id[i]
                id2 = idx_to_id[j]
                graph[id1][id2] = covisibility_matrix[i, j]

    return graph


def create_sequence_from_covisibility_graph(covisibility_graph, min_covisibility=20):
    if not covisibility_graph:
        print("No covisibility graph found")
        return []

    start_node = max(covisibility_graph.keys(),
                     key=lambda k: sum(1 for v in covisibility_graph[k].values() if v >= min_covisibility))

    visited = set([start_node])
    sequence = [start_node]
    current = start_node

    while len(sequence) < len(covisibility_graph):
        next_node = None
        max_covisibility = -1

        for neighbor, covisibility in covisibility_graph[current].items():
            if neighbor not in visited and covisibility > max_covisibility and covisibility >= min_covisibility:
                next_node = neighbor
                max_covisibility = covisibility

        if next_node is None:
            for node in covisibility_graph:
                if node not in visited:
                    for seq_node in sequence:
                        covisibility = covisibility_graph[node].get(seq_node, 0)
                        if covisibility > max_covisibility and covisibility >= min_covisibility:
                            max_covisibility = covisibility
                            next_node = node

        if next_node is None:
            next_node = min(set(covisibility_graph.keys()) - visited)

        current = next_node
        visited.add(current)
        sequence.append(current)

    return sequence


def cluster_cameras(model_path, camera_order):
    colmap_path = os.path.join(model_path, 'sparse/0')
    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
    if camera_order == 'covisibility':
        covisibility_matrix, id_to_idx, idx_to_id = build_covisibility_matrix(colmap_images, colmap_points3D)
        covisibility_graph = create_covisibility_graph(covisibility_matrix, idx_to_id)
        covisibility_sequence = create_sequence_from_covisibility_graph(covisibility_graph)

        image_id_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
        image_id_name_sorted = sorted(image_id_name, key=lambda x: x[1])
        test_id = []
        train_id = []
        count = 0
        train_only_idx = []
        for i, (id, name) in enumerate(image_id_name_sorted):
            if i % 8 == 0:
                test_id.append(id)
            else:
                train_id.append(id)
                train_only_idx.append(count)
            count += 1

        rotations_image, translations_image = compute_extrinsics(colmap_images)

        train_only_visibility_idx = []
        for id in covisibility_sequence:
            if id in train_id:
                train_only_visibility_idx.append(id)
        train_only_visibility_idx = np.array(train_only_visibility_idx)
        sorted_keys = train_only_visibility_idx  # assign to sorted_keys instead of sorted_indices

    elif camera_order == 'PCA':
        image_idx_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
        image_idx_name_sorted = sorted(image_idx_name, key=lambda x: x[1])
        test_idx = []
        train_idx = []
        for i, (idx, name) in enumerate(image_idx_name_sorted):
            if i % 8 == 0:
                test_idx.append(idx)
            else:
                train_idx.append(idx)

        rotations_image, translations_image = compute_extrinsics(colmap_images)

        cam_center = []
        key = []
        for idx in train_idx:
            cam_center.append((-rotations_image[idx].T @ translations_image[idx].reshape(3, 1)))
            key.append(idx)

        cam_center = np.array(cam_center)[:, :, 0]
        pca = PCA(n_components=2)
        cam_center_2d = pca.fit_transform(cam_center)

        center_cam_center = np.mean(cam_center_2d, axis=0)
        centered_cam_center = cam_center_2d - center_cam_center
        angles = np.arctan2(centered_cam_center[:, 1], centered_cam_center[:, 0])
        sorted_indices = np.argsort(angles)
        sorted_cam_centers = cam_center_2d[sorted_indices]
        sorted_keys = np.array(key)[sorted_indices]

    return sorted_keys


def bundle_start_index_generator(sorted_keys, initial_interval):
    start_indices = []
    cluster_sizes = []
    for i in range(200):
        start_indices.append((initial_interval*i) % len(sorted_keys))
        cluster_sizes.append(initial_interval)

    return start_indices, cluster_sizes


def adaptive_cluster(start_idx, sorted_keys, cluster_size=40, offset=0):
    idx = start_idx
    indices = [sorted_keys[index % len(sorted_keys)] for index in range(idx, idx + cluster_size)]
    random_index = random.choice(indices)
    return random_index
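The greedy ordering in create_sequence_from_covisibility_graph can be exercised on its own with a toy graph; the image ids and co-visible point counts below are made up, and the module's dependencies are assumed to be installed.

from utils.bundle_utils import create_sequence_from_covisibility_graph

# Toy covisibility graph: image id -> {neighbour id: number of co-visible 3D points}.
covisibility_graph = {
    1: {2: 120, 3: 40},
    2: {1: 120, 3: 90},
    3: {1: 40, 2: 90, 4: 25},
    4: {3: 25},
}
sequence = create_sequence_from_covisibility_graph(covisibility_graph, min_covisibility=20)
print(sequence)  # starts at the best-connected view and follows the strongest remaining link, e.g. [3, 2, 1, 4]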
utils/colmap_utils.py | 78 (new file)
@@ -0,0 +1,78 @@
from colmap.scripts.python.read_write_model import *


def get_colmap_data(dataset_path):
    """
    Load COLMAP data from the given dataset path.

    Args:
        dataset_path (str): Path to the dataset directory containing images, sparse/0, and dense/0 folders.

    Returns:
        images: colmap image infos
        points3D: colmap 3D points
        cameras: colmap camera infos
    """
    images = read_images_binary(os.path.join(dataset_path, 'images.bin'))
    points3D = read_points3D_binary(os.path.join(dataset_path, 'points3D.bin'))
    cameras = read_cameras_binary(os.path.join(dataset_path, 'cameras.bin'))
    return images, points3D, cameras


def quaternion_rotation_matrix(qw, qx, qy, qz):
    """
    Convert a quaternion to a rotation matrix.
    colmap uses wxyz order for quaternions.
    """
    # First row of the rotation matrix
    r00 = 2 * (qw * qw + qx * qx) - 1
    r01 = 2 * (qx * qy - qw * qz)
    r02 = 2 * (qw * qy + qx * qz)

    # Second row of the rotation matrix
    r10 = 2 * (qx * qy + qw * qz)
    r11 = 2 * (qw * qw + qy * qy) - 1
    r12 = 2 * (qy * qz - qw * qx)

    # Third row of the rotation matrix
    r20 = 2 * (qz * qx - qw * qy)
    r21 = 2 * (qz * qy + qw * qx)
    r22 = 2 * (qw * qw + qz * qz) - 1

    # 3x3 rotation matrix
    rot_matrix = np.array([[r00, r01, r02],
                           [r10, r11, r12],
                           [r20, r21, r22]])

    return rot_matrix


def compute_intrinsic_matrix(fx, fy, cx, cy, image_width, image_height):
    sx = cx/(image_width/2)
    sy = cy/(image_height/2)
    intrinsic_matrix = np.array([[fx/sx, 0, cx/sx], [0, fy/sy, cy/sy], [0, 0, 1]])
    return intrinsic_matrix


def compute_intrinsics(colmap_cameras, image_width, image_height):
    intrinsics = {}
    for cam_key in colmap_cameras.keys():
        intrinsic_parameters = colmap_cameras[cam_key].params
        assert colmap_cameras[cam_key].model == 'PINHOLE'
        intrinsic = compute_intrinsic_matrix(intrinsic_parameters[0],
                                             intrinsic_parameters[1],
                                             intrinsic_parameters[2],
                                             intrinsic_parameters[3],
                                             image_width,
                                             image_height)
        intrinsics[cam_key] = intrinsic
    return intrinsics


def compute_extrinsics(colmap_images):
    rotations = {}
    translations = {}
    for image_key in colmap_images.keys():
        rotation = quaternion_rotation_matrix(colmap_images[image_key].qvec[0],
                                              colmap_images[image_key].qvec[1],
                                              colmap_images[image_key].qvec[2],
                                              colmap_images[image_key].qvec[3])
        translation = colmap_images[image_key].tvec
        rotations[image_key] = rotation
        translations[image_key] = translation
    return rotations, translations
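A quick sanity check of the helpers above, using the identity quaternion and a pinhole camera whose principal point sits at the image centre (all numbers illustrative):

import numpy as np
from utils.colmap_utils import quaternion_rotation_matrix, compute_intrinsic_matrix

# The identity quaternion (w=1, x=y=z=0) must map to the identity rotation.
R = quaternion_rotation_matrix(1.0, 0.0, 0.0, 0.0)
assert np.allclose(R, np.eye(3))

# With cx = W/2 and cy = H/2 the sx, sy rescaling is a no-op, so K is the
# familiar [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
K = compute_intrinsic_matrix(fx=1000.0, fy=1000.0, cx=640.0, cy=360.0,
                             image_width=1280, image_height=720)
print(K)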
utils/loss_utils.py
@@ -10,6 +10,7 @@
 #

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from math import exp
@@ -17,6 +18,7 @@ try:
     from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward
 except:
     pass
+import math

 C1 = 0.01 ** 2
 C2 = 0.03 ** 2
@@ -89,3 +91,133 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
 def fast_ssim(img1, img2):
     ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2)
     return ssim_map.mean()
+
+def build_gaussian_kernel(kernel_size=5, sigma=1.0, channels=3):
+    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+    x_coord = torch.arange(kernel_size)
+    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
+    y_grid = x_grid.t()
+    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+    mean = (kernel_size - 1)/2.
+    variance = sigma**2.
+
+    # Calculate the 2-dimensional gaussian kernel
+    gaussian_kernel = (1./(2.*math.pi*variance)) * \
+                      torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / \
+                      (2*variance))
+
+    # Make sure sum of values in gaussian kernel equals 1.
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2D depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    return gaussian_kernel
+
+def gaussian_blur(x, kernel):
+    """Apply gaussian blur to input tensor"""
+    padding = (kernel.shape[-1] - 1) // 2
+    return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
+
+def create_laplacian_pyramid(image, max_levels=4):
+    """Create Laplacian pyramid for an image"""
+    pyramids = []
+    current = image
+    kernel = build_gaussian_kernel().to(image.device)
+
+    for _ in range(max_levels):
+        # Blur and downsample
+        blurred = gaussian_blur(current, kernel)
+        down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+        # Upsample and subtract
+        up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+        laplace = current - up
+
+        pyramids.append(laplace)
+        current = down
+
+    pyramids.append(current)  # Add the final residual
+    return pyramids
+
+def laplacian_pyramid_loss(pred, target, max_levels=4, weights=None):
+    """Compute Laplacian Pyramid Loss between predicted and target images"""
+    if weights is None:
+        weights = [1.0] * (max_levels + 1)
+
+    pred_pyramids = create_laplacian_pyramid(pred, max_levels)
+    target_pyramids = create_laplacian_pyramid(target, max_levels)
+
+    loss = 0
+    for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+        loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+    return loss
+
+class LaplacianPyramidLoss(torch.nn.Module):
+    def __init__(self, max_levels=4, channels=3, kernel_size=5, sigma=1.0):
+        super().__init__()
+        self.max_levels = max_levels
+        self.kernel = build_gaussian_kernel(kernel_size, sigma, channels)
+
+    def forward(self, pred, target, weights=None):
+        if weights is None:
+            weights = [1.0] * (self.max_levels + 1)
+
+        # Move kernel to the same device as input
+        kernel = self.kernel.to(pred.device)
+
+        pred_pyramids = self.create_laplacian_pyramid(pred, kernel)
+        target_pyramids = self.create_laplacian_pyramid(target, kernel)
+
+        loss = 0
+        for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+            loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+        return loss
+
+    @staticmethod
+    def create_laplacian_pyramid(image, kernel, max_levels=4):
+        pyramids = []
+        current = image
+
+        for _ in range(max_levels):
+            # Apply Gaussian blur before downsampling to prevent aliasing
+            blurred = gaussian_blur(current, kernel)
+            down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+            # Upsample and subtract from the original image
+            up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+            laplace = current - gaussian_blur(up, kernel)  # Apply blur to upsampled image
+
+            pyramids.append(laplace)
+            current = down
+
+        pyramids.append(current)  # Add the final residual
+        return pyramids
+
+class InvDepthSmoothnessLoss(nn.Module):
+    def __init__(self, alpha=10):
+        super(InvDepthSmoothnessLoss, self).__init__()
+        self.alpha = alpha  # hyperparameter controlling the strength of the edge weighting
+
+    def forward(self, inv_depth, image):
+        # Gradients of the inverse depth map
+        dx_inv_depth = torch.abs(inv_depth[:, :, :-1] - inv_depth[:, :, 1:])
+        dy_inv_depth = torch.abs(inv_depth[:, :-1, :] - inv_depth[:, 1:, :])

+        # Gradients of the image
+        dx_image = torch.mean(torch.abs(image[:, :, :-1] - image[:, :, 1:]), 1, keepdim=True)
+        dy_image = torch.mean(torch.abs(image[:, :-1, :] - image[:, 1:, :]), 1, keepdim=True)
+
+        # Edge-aware weights derived from the image gradients
+        weight_x = torch.exp(-self.alpha * dx_image)
+        weight_y = torch.exp(-self.alpha * dy_image)
+
+        # Smoothness loss
+        smoothness_x = dx_inv_depth * weight_x
+        smoothness_y = dy_inv_depth * weight_y
+
+        return torch.mean(smoothness_x) + torch.mean(smoothness_y)
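Two properties of the additions above are easy to verify (a sketch; assumes the repo's utils package is importable): the function-level pyramid stores exactly what is needed to rebuild the input, and the pyramid loss of an image against itself is zero.

import torch
import torch.nn.functional as F
from utils.loss_utils import create_laplacian_pyramid, laplacian_pyramid_loss

img = torch.rand(1, 3, 64, 64)
pyr = create_laplacian_pyramid(img, max_levels=3)  # 3 band-pass levels + low-resolution residual

# Each level stores current - upsample(downsample(blur(current))), so walking back up
# the pyramid reconstructs the original image.
recon = pyr[-1]
for lap in reversed(pyr[:-1]):
    recon = lap + F.interpolate(recon, size=lap.shape[2:], mode='bilinear', align_corners=False)
print(torch.allclose(recon, img, atol=1e-6))    # True

print(laplacian_pyramid_loss(img, img).item())  # 0.0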