add aug_gs code

sonbosung 2025-03-12 14:49:35 +09:00
parent 54c035f783
commit 48ceb9419b
6 changed files with 1063 additions and 13 deletions

augment.py (new file, 208 lines added)

@@ -0,0 +1,208 @@
import argparse
import numpy as np
from tqdm import tqdm
import cv2
import os
from utils.colmap_utils import *
from utils.bundle_utils import cluster_cameras
from utils.aug_utils import *
def augment(colmap_path, image_path, augment_path, camera_order, visibility_aware_culling, compare_center_patch):
colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
sorted_keys = cluster_cameras(colmap_path, camera_order)
points3d = []
points3d_rgb = []
for key in sorted(colmap_points3D.keys()):
points3d.append(colmap_points3D[key].xyz)
points3d_rgb.append(colmap_points3D[key].rgb)
points3d = np.array(points3d)
points3d_rgb = np.array(points3d_rgb)
image_sample = cv2.imread(os.path.join(image_path, colmap_images[sorted_keys[0]].name))
intrinsics_camera = compute_intrinsics(colmap_cameras, image_sample.shape[1], image_sample.shape[0])
rotations_image, translations_image = compute_extrinsics(colmap_images)
count = 0
roots = {}
pbar = tqdm(range(len(sorted_keys)))
for view_idx in pbar:
view = sorted_keys[view_idx]
view_root, augmented_count = image_quadtree_augmentation(
view,
image_path,
colmap_cameras,
colmap_images,
colmap_points3D,
points3d,
points3d_rgb,
intrinsics_camera,
rotations_image,
translations_image,
visibility_aware_culling=visibility_aware_culling,
)
count += augmented_count
pbar.set_description(f"{count} points augmented")
roots[view] = view_root
for view1_idx in tqdm(range(len(sorted_keys))):
for view2_idx in [view1_idx + offset for offset in (6, 5, 4, 3, 2, 1, -1, -2, -3, -4, -5, -6)]:
if view2_idx > len(sorted_keys) - 1:
view2_idx = view2_idx - len(sorted_keys)
view1 = sorted_keys[view1_idx]
view2 = sorted_keys[view2_idx]
view1_root = roots[view1]
view2_root = roots[view2]
image_view2 = cv2.imread(os.path.join(image_path, colmap_images[view2].name))
view1_sample_points_world, view1_sample_points_rgb = transform_sample_3d(view1,
view1_root,
colmap_images,
colmap_cameras,
intrinsics_camera,
rotations_image,
translations_image)
view1_sample_points_view2, view1_sample_points_view2_depth = project_3d_to_2d(view1_sample_points_world,
intrinsics_camera[colmap_images[view2].camera_id],
np.concatenate((np.array(rotations_image[view2]),
np.array(translations_image[view2]).reshape(3,1)),
axis=1))
points3d_view2_pixcoord, points3d_view2_depth = project_3d_to_2d(points3d,
intrinsics_camera[colmap_images[view2].camera_id],
np.concatenate((np.array(rotations_image[view2]),
np.array(translations_image[view2]).reshape(3,1)),
axis=1))
matching_log = []
for i in range(view1_sample_points_world.shape[0]):
x, y = view1_sample_points_view2[i]
corresponding_node_type = None
error = None
if (view1_sample_points_view2_depth[i] < 0) | \
(view1_sample_points_view2[i, 0] < 0) | \
(view1_sample_points_view2[i, 0] >= image_view2.shape[1]) | \
(view1_sample_points_view2[i, 1] < 0) | \
(view1_sample_points_view2[i, 1] >= image_view2.shape[0]) | \
np.isnan(view1_sample_points_view2[i]).any(axis=0):
corresponding_node_type = "culled"
matching_log.append([view2, corresponding_node_type, error])
continue
else:
view2_corresponding_node = find_leaf_node(view2_root, x, y)
if view2_corresponding_node is None:
corresponding_node_type = "missing"
else:
if view2_corresponding_node.unoccupied:
if view2_corresponding_node.depth_interpolated:
error = np.linalg.norm(view1_sample_points_view2_depth[i] - view2_corresponding_node.sampled_point_depth)
if error < 0.2 * view2_corresponding_node.sampled_point_depth:
if compare_center_patch:
try:
view1_sample_point_patch = image_view2[int(view1_sample_points_view2[i, 1])-1:\
int(view1_sample_points_view2[i,1])+2,
int(view1_sample_points_view2[i, 0])-1:\
int(view1_sample_points_view2[i,0])+2]
view2_corresponding_node_patch = image_view2[int(view2_corresponding_node.sampled_point_uv[1])-1:\
int(view2_corresponding_node.sampled_point_uv[1])+2,
int(view2_corresponding_node.sampled_point_uv[0])-1:\
int(view2_corresponding_node.sampled_point_uv[0])+2]
if compare_local_texture(view1_sample_point_patch, view2_corresponding_node_patch) > 0.5:
corresponding_node_type = "sampledrejected"
else:
corresponding_node_type = "sampled"
except IndexError:
corresponding_node_type = "sampledrejected"
else:
corresponding_node_type = "sampled"
else:
corresponding_node_type = "sampledrejected"
else:
corresponding_node_type = "depthrejected"
else:
corresponding_3d_depth = np.array(view2_corresponding_node.points3d_depths)
error = np.abs(view1_sample_points_view2_depth[i] - corresponding_3d_depth)
if np.min(error) < 0.2 * corresponding_3d_depth[np.argmin(error)]:
if compare_center_patch:
try:
point_3d_coord = points3d_view2_pixcoord[view2_corresponding_node.points3d_indices[np.argmin(error)]]
point_3d_patch = image_view2[int(point_3d_coord[1])-1:\
int(point_3d_coord[1])+2,
int(point_3d_coord[0])-1:\
int(point_3d_coord[0])+2]
view1_sample_point_patch = image_view2[int(view1_sample_points_view2[i, 1])-1:\
int(view1_sample_points_view2[i,1])+2,
int(view1_sample_points_view2[i, 0])-1:\
int(view1_sample_points_view2[i,0])+2]
if compare_local_texture(view1_sample_point_patch, point_3d_patch) > 0.5:
corresponding_node_type = "rejectedoccupied3d"
else:
corresponding_node_type = "occupied3d"
except IndexError:
corresponding_node_type = "rejectedoccupied3d"
else:
corresponding_node_type = "occupied3d"
else:
corresponding_node_type = "rejectedoccupied3d"
matching_log.append([view2, corresponding_node_type, error])
sampled_points_total = []
sampled_points_rgb_total = []
sampled_points_uv_total = []
sampled_points_neighbors_uv_total = []
for view in sorted_keys:
view_root = roots[view]
leaf_nodes = []
gather_leaf_nodes(view_root, leaf_nodes)
for node in leaf_nodes:
if node.unoccupied:
if node.depth_interpolated:
if node.inference_count > 0:
if node.inference_count - node.rejection_count >= 1:
sampled_points_total.append([node.sampled_point_world])
sampled_points_rgb_total.append([node.sampled_point_rgb])
sampled_points_uv_total.append([node.sampled_point_uv])
sampled_points_neighbors_uv_total.append([node.sampled_point_neighbors_uv])
print("total_Sampled_points: ", len(sampled_points_total))
xyz = np.concatenate(sampled_points_total, axis=0)
rgb = np.concatenate(sampled_points_rgb_total, axis=0)
last_index = write_points3D_colmap_binary(colmap_points3D, xyz, rgb, augment_path)
print("last_index: ", last_index)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--colmap_path", type=str, required=True)
parser.add_argument("--image_path", type=str, required=True)
parser.add_argument("--augment_path", type=str, required=True)
parser.add_argument("--camera_order", type=str, required=True, default="covisibility")
parser.add_argument("--visibility_aware_culling",
action="store_true",
default=False)
parser.add_argument("--compare_center_patch",
action="store_true",
default=False)
args = parser.parse_args()
print("args.colmap_path", args.colmap_path)
print("args.image_path", args.image_path)
print("args.augment_path", args.augment_path)
print("args.camera_order", args.camera_order)
print("args.visibility_aware_culling", args.visibility_aware_culling)
print("args.compare_center_patch", args.compare_center_patch)
augment(args.colmap_path, args.image_path, args.augment_path, args.camera_order, args.visibility_aware_culling, args.compare_center_patch)
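
Usage sketch (not part of the commit): one plausible way to drive the new augment() entry point programmatically. All paths are placeholders; note that get_colmap_data reads the binaries directly from colmap_path while cluster_cameras joins 'sparse/0' onto whatever it is given, so the exact value passed here depends on how those two callers are meant to be reconciled.

# Hypothetical driver for the new augmentation script; every path below is a placeholder.
from augment import augment

augment(
    colmap_path="data/scene/sparse/0",                    # COLMAP binaries read by get_colmap_data (see note above)
    image_path="data/scene/images",                       # undistorted training images
    augment_path="data/scene/sparse/0/points3D_aug.bin",  # output written by write_points3D_colmap_binary
    camera_order="covisibility",                          # or "PCA"; see utils/bundle_utils.cluster_cameras
    visibility_aware_culling=True,                        # cull projected points whose color disagrees with the image
    compare_center_patch=True,                            # extra 3x3 patch check before accepting a cross-view match
)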

train.py (115 lines changed)

@@ -12,7 +12,9 @@
import os
import torch
from random import randint
from utils.loss_utils import l1_loss, ssim
from utils.loss_utils import l1_loss, ssim, InvDepthSmoothnessLoss, laplacian_pyramid_loss
from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster
from utils.aug_utils import *
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
@@ -40,7 +42,20 @@ try:
except:
SPARSE_ADAM_AVAILABLE = False
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
def training(dataset,
opt,
pipe,
testing_iterations,
saving_iterations,
checkpoint_iterations,
checkpoint,
debug_from,
camera_order,
bundle_training,
enable_ds_lap,
lambda_ds,
lambda_lap,
):
if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":
sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
@@ -68,6 +83,11 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
ema_loss_for_log = 0.0
ema_Ll1depth_for_log = 0.0
n_interval = 0  # read in the densification branch even when bundle_training is off
if bundle_training:
sorted_keys = cluster_cameras(dataset.source_path, camera_order)
start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20)
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
for iteration in range(first_iter, opt.iterations + 1):
@@ -94,13 +114,20 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
if iteration % 1000 == 0:
gaussians.oneupSHdegree()
# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_indices = list(range(len(viewpoint_stack)))
rand_idx = randint(0, len(viewpoint_indices) - 1)
viewpoint_cam = viewpoint_stack.pop(rand_idx)
vind = viewpoint_indices.pop(rand_idx)
if bundle_training and iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and cluster_sizes[n_interval] < 160 and iteration % 100 > 80:
densification_viewpoint_stack = scene.getTrainCameras().copy()
selected_idx = adaptive_cluster(start_indices[n_interval], sorted_keys, cluster_sizes[n_interval])
if selected_idx >= len(densification_viewpoint_stack):
selected_idx = selected_idx % len(densification_viewpoint_stack)
viewpoint_cam = densification_viewpoint_stack[selected_idx]
else:
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_indices = list(range(len(viewpoint_stack)))
rand_idx = randint(0, len(viewpoint_indices) - 1)
viewpoint_cam = viewpoint_stack.pop(rand_idx)
vind = viewpoint_indices.pop(rand_idx)
# Render
if (iteration - 1) == debug_from:
@@ -123,7 +150,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
else:
ssim_value = ssim(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
if enable_ds_lap:
ds_loss = InvDepthSmoothnessLoss()(render_pkg["depth"], image)
lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
else:
ds_loss = 0.0
lap_loss = 0.0
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + lambda_ds * ds_loss + lambda_lap * lap_loss
# Depth regularization
Ll1depth_pure = 0.0
@@ -155,7 +189,24 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
progress_bar.close()
# Log and save
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)
training_report(tb_writer,
iteration,
Ll1,
loss,
l1_loss,
iter_start.elapsed_time(iter_end),
testing_iterations,
scene,
render,
(pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp),
dataset.train_test_exp,
1.0 - ssim_value,
ds_loss,
lap_loss,
lambda_ds,
lambda_lap
)
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)
@@ -169,6 +220,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
n_interval += 1
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()
@@ -211,10 +263,30 @@ def prepare_output_and_logger(args):
print("Tensorboard not available: not logging progress")
return tb_writer
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
def training_report(tb_writer,
iteration,
Ll1,
loss,
l1_loss,
elapsed,
testing_iterations,
scene : Scene,
renderFunc,
renderArgs,
train_test_exp,
ssim_loss,
ds_loss,
lap_loss,
lambda_ds,
lambda_lap):
if tb_writer:
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
tb_writer.add_scalar('train_loss_patches/ssim_loss', ssim_loss.item(), iteration)
tb_writer.add_scalar('train_loss_patches/ds_loss', ds_loss.item() if torch.is_tensor(ds_loss) else ds_loss, iteration)
tb_writer.add_scalar('train_loss_patches/lap_loss', lap_loss.item() if torch.is_tensor(lap_loss) else lap_loss, iteration)
tb_writer.add_scalar('train_loss_patches/lambda_ds', lambda_ds, iteration)
tb_writer.add_scalar('train_loss_patches/lambda_lap', lambda_lap, iteration)
tb_writer.add_scalar('iter_time', elapsed, iteration)
# Report test and samples of training set
@@ -267,6 +339,11 @@ if __name__ == "__main__":
parser.add_argument('--disable_viewer', action='store_true', default=False)
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
parser.add_argument("--bundle_training", action='store_true', default=False)
parser.add_argument("--camera_order", type=str, default='covisibility')
parser.add_argument("--enable_ds_lap", action='store_true', default=False)
parser.add_argument("--lambda_ds", type=float, default=0.0)
parser.add_argument("--lambda_lap", type=float, default=0.0)
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
@@ -279,7 +356,19 @@ if __name__ == "__main__":
if not args.disable_viewer:
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
training(lp.extract(args),
op.extract(args),
pp.extract(args),
args.test_iterations,
args.save_iterations,
args.checkpoint_iterations,
args.start_checkpoint,
args.debug_from,
args.camera_order,
args.bundle_training,
args.enable_ds_lap,
args.lambda_ds,
args.lambda_lap)
# All done
print("\nTraining complete.")

utils/aug_utils.py (new file, 387 lines added)

@@ -0,0 +1,387 @@
import os
import numpy as np
from sklearn.neighbors import BallTree
from colmap.scripts.python.read_write_model import *
import cv2
from tqdm import tqdm
from sklearn.decomposition import PCA
import rtree
from shapely.geometry import Point, box
from collections import defaultdict
from utils.colmap_utils import compute_extrinsics, get_colmap_data
from matplotlib import pyplot as plt
class Node:
def __init__(self, x0, y0, width, height):
self.x0 = x0
self.y0 = y0
self.width = width
self.height = height
self.children = []
self.unoccupied = True
self.sampled_point_uv = None
self.sampled_point_rgb = None
self.sampled_point_depth = None
self.depth_interpolated = None
self.sampled_point_neighbours_indices = None
self.inference_count = 0
self.rejection_count = 0
self.sampled_point_world = None
self.matching_log = {}
self.points3d_indices = []
self.points3d_depths = []
self.points3d_rgb = []
self.sampled_point_neighbours_uv = None
def get_error(self, img):
# Calculate the standard deviation of the region as an error metric
region = img[self.y0:self.y0+self.height, self.x0:self.x0+self.width]
return np.std(region)
def recursive_subdivide(node, threshold, min_pixel_size, img):
if node.get_error(img) <= threshold:
return
w_1 = node.width // 2
w_2 = node.width - w_1
h_1 = node.height // 2
h_2 = node.height - h_1
if w_1 <= min_pixel_size or h_1 <= min_pixel_size:
return
# Create four children nodes
x1 = Node(node.x0, node.y0, w_1, h_1) # top left
recursive_subdivide(x1, threshold, min_pixel_size, img)
x2 = Node(node.x0, node.y0 + h_1, w_1, h_2) # bottom left
recursive_subdivide(x2, threshold, min_pixel_size, img)
x3 = Node(node.x0 + w_1, node.y0, w_2, h_1) # top right
recursive_subdivide(x3, threshold, min_pixel_size, img)
x4 = Node(node.x0 + w_1, node.y0 + h_1, w_2, h_2) # bottom right
recursive_subdivide(x4, threshold, min_pixel_size, img)
node.children = [x1, x2, x3, x4]
def quadtree_decomposition(img, threshold, min_pixel_size):
root = Node(0, 0, img.shape[1], img.shape[0])
recursive_subdivide(root, threshold, min_pixel_size, img)
return root
def gather_leaf_nodes(node, leaf_nodes):
if not node.children:
leaf_nodes.append(node)
else:
for child in node.children:
gather_leaf_nodes(child, leaf_nodes)
def find_leaf_node(root, pixel_x, pixel_y):
if not (root.x0 <= pixel_x < root.x0 + root.width and
root.y0 <= pixel_y < root.y0 + root.height):
return None  # the pixel lies outside the root node's bounds
current = root
while current.children:
for child in current.children:
if (child.x0 <= pixel_x < child.x0 + child.width and
child.y0 <= pixel_y < child.y0 + child.height):
current = child
break
else:
# no child contains the pixel; return the current node
return current
return current
def draw_quadtree(node, ax):
if not node.children:
rect = plt.Rectangle((node.x0, node.y0), node.width, node.height, fill=False, edgecolor='red')
ax.add_patch(rect)
else:
for child in node.children:
draw_quadtree(child, ax)
def pixel_to_3d(point, depth, intrinsics):
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
x, y = point
z = depth
x_3d = (x - cx) * z / fx
y_3d = (y - cy) * z / fy
return np.array([x_3d, y_3d, z])
def project_3d_to_2d(points3d, intrinsics, extrinsics):
"""
Project a 3D point to 2D using camera intrinsics and extrinsics.
Args:
- point_3d: 3D point as a numpy array (x, y, z).
- intrinsics: Camera intrinsics matrix (3x3).
- extrinsics: Camera extrinsics matrix (4x4).
Returns:
- 2D point as a tuple (x, y) in pixel coordinates.
"""
point_3d_homogeneous = np.hstack((points3d, np.ones((points3d.shape[0], 1)))) # Convert to homogeneous coordinates
point_camera = extrinsics @ point_3d_homogeneous.T
point_image_homogeneous = intrinsics @ point_camera
point_2d = point_image_homogeneous[:2] / point_image_homogeneous[2]
return point_2d.T, point_image_homogeneous[2:].T
def check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices):
rect_index = rtree.index.Index()
rectangles = [((node.x0, node.y0, node.x0 + node.width, node.y0 + node.height), i) for i, node in enumerate(leaf_nodes)]
for rect, i in rectangles:
rect_index.insert(i, rect)
rectangle_indices = []
for i, (x, y) in enumerate(points3d_pix):
if near_culled_indices[i]:
continue
point = Point(x, y)
matches = list(rect_index.intersection((x, y, x, y)))
for match in matches:
if box(*rectangles[match][0]).contains(point):
rectangle_indices.append([match, i])
for index, point3d_idx in rectangle_indices:
leaf_nodes[index].unoccupied = False
leaf_nodes[index].points3d_indices.append(point3d_idx)
leaf_nodes[index].points3d_depths.append(points3d_depth[point3d_idx])
leaf_nodes[index].points3d_rgb.append(points3d_rgb[point3d_idx])
return leaf_nodes
def compute_normal_vector(points3d_cameraframe, points3d_depth, indices):
points3d_cameraframe = np.concatenate((points3d_cameraframe[:,:2]*points3d_depth.reshape(-1, 1), points3d_depth.reshape(-1, 1)), axis=1)
p1 = points3d_cameraframe[indices[:, 0]]
p2 = points3d_cameraframe[indices[:, 1]]
p3 = points3d_cameraframe[indices[:, 2]]
v1 = p2 - p1
v2 = p3 - p1
normal = np.cross(v1, v2)
a, b, c = normal[:, 0], normal[:, 1], normal[:, 2]
d = -a * p1[:, 0] - b * p1[:, 1] - c * p1[:, 2]
return a, b, c, d
def compute_depth(sample_points_cameraframe, a, b, c, d, cosine_threshold=0.01):
direction_vectors = sample_points_cameraframe
t = -d.reshape(-1,1) / np.sum(np.concatenate((a.reshape(-1,1), b.reshape(-1,1), c.reshape(-1,1)), axis=1) * direction_vectors, axis=1).reshape(-1,1)
normal_vectors = np.concatenate((a.reshape(-1,1), b.reshape(-1,1), c.reshape(-1,1)), axis=1)
cosine_similarity = np.abs(np.sum(normal_vectors * direction_vectors, axis=1)).reshape(-1,1) / np.linalg.norm(normal_vectors, axis=1).reshape(-1,1) / np.linalg.norm(direction_vectors, axis=1).reshape(-1,1)
rejected_indices = cosine_similarity < cosine_threshold
depth = t.reshape(-1,1)*direction_vectors[:, 2:]
return depth, rejected_indices.reshape(-1)
def find_perpendicular_triangle_indices(a, b, c, cosine_threshold=0.01):
normal_vectors = np.concatenate((a.reshape(-1,1), b.reshape(-1,1), c.reshape(-1,1)), axis=1)
camera_normal = np.array([0, 0, 1]).reshape(1, 3)
cosine_similarity = np.dot(normal_vectors, camera_normal.T) / np.linalg.norm(normal_vectors, axis=1).reshape(-1,1)
cosine_culled_indices = (cosine_similarity > -cosine_threshold) & (cosine_similarity < cosine_threshold)
return cosine_culled_indices
def find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera, rotations_image, translations_image, depth_cutoff = 2.):
sampled_points = []
for node in leaf_nodes:
if node.unoccupied:
sampled_points.append(node.sampled_point_uv)
sampled_points = np.array(sampled_points)
# Later we should store the 3D points' original indices
# This is because the 3D points cannot be masked during NN search.
# Near_culled_indices indicates the indices of the 3D points that are outside the camera frustum or have negative depth.
original_indices = np.arange(points3d_pix.shape[0])
original_indices = original_indices[~near_culled_indices]
# Only infrustum 3D points are used for depth interpolation.
points3d_pix = points3d_pix[~near_culled_indices]
points3d_depth = points3d_depth[~near_culled_indices]
# Construct a BallTree for nearest neighbor search.
tree = BallTree(points3d_pix, leaf_size=40)
# Query the nearest 3 neighbors for each sampled point.
distances, indices = tree.query(sampled_points, k=3)
# Inverse the camera intrinsic matrix for generating the camera rays.
inverse_intrinsics = np.linalg.inv(intrinsics_camera)
# Generate the camera rays in camera coordinate system.
sampled_points_homogeneous = np.concatenate((sampled_points, np.ones((sampled_points.shape[0], 1))), axis=1)
sampled_points_cameraframe = (inverse_intrinsics @ sampled_points_homogeneous.T).T
# Transform the 3D points to the camera coordinate system with precomputed depths.
points3d_pix_homogeneous = np.concatenate((points3d_pix, np.ones((points3d_pix.shape[0], 1))), axis=1)
points3d_cameraframe = (inverse_intrinsics @ points3d_pix_homogeneous.T).T
# Compute the normal vector and the distance of the triangle formed by the nearest 3 neighbors.
a, b, c, d = compute_normal_vector(points3d_cameraframe, points3d_depth, indices)
# Compute the depth of the sampled points from the normal vector.
sampled_points_depth, cosine_culled_indices = compute_depth(sampled_points_cameraframe, a, b, c, d)
# Reject points whose interpolated depth falls below the cutoff, is NaN, or whose supporting plane is nearly parallel to the camera ray.
depth_rejected_indices = (sampled_points_depth < depth_cutoff).reshape(-1) | np.isnan(sampled_points_cameraframe).any(axis=1) | cosine_culled_indices.reshape(-1)
# Transform the sampled points to the world coordinate system.
sampled_points_world = transform_camera_to_world(np.concatenate((sampled_points_cameraframe[:,:2]*sampled_points_depth.reshape(-1,1), sampled_points_depth.reshape(-1,1)), axis=1), rotations_image, translations_image)
sampled_points_world = sampled_points_world[:, :3]
augmented_count = 0
depth_index = 0
for node in leaf_nodes:
if node.unoccupied:
if depth_rejected_indices[depth_index]:
node.depth_interpolated = False
else:
node.sampled_point_depth = sampled_points_depth[depth_index]
node.sampled_point_world = sampled_points_world[depth_index]
node.depth_interpolated = True
node.sampled_point_neighbours_indices = original_indices[indices[depth_index]]
node.sampled_point_neighbours_uv = points3d_pix[indices[depth_index]]
augmented_count += 1
depth_index += 1
return leaf_nodes, augmented_count
def transform_camera_to_world(sampled_points_cameraframe, rotations_image, translations_image):
extrinsics_image = np.concatenate((np.array(rotations_image), np.array(translations_image).reshape(3,1)), axis=1)
extrinsics_4x4 = np.concatenate((extrinsics_image, np.array([0, 0, 0, 1]).reshape(1, 4)), axis=0)
sampled_points_cameraframe_homogeneous = np.concatenate((sampled_points_cameraframe, np.ones((sampled_points_cameraframe.shape[0], 1))), axis=1)
sampled_points_worldframe = np.linalg.inv(extrinsics_4x4) @ sampled_points_cameraframe_homogeneous.T
return sampled_points_worldframe.T
def pixelwise_rgb_diff(point_a, point_b, threshold=0.3):
# Cast to float before subtracting so the uint8 pixel values cannot wrap around.
return np.linalg.norm(point_a.astype(np.float64) - point_b.astype(np.float64)) / 255. > threshold
def image_quadtree_augmentation(image_key,
image_dir,
colmap_cameras,
colmap_images,
colmap_points3D,
points3d, points3d_rgb,
intrinsics_camera,
rotations_image,
translations_image,
quadtree_std_threshold=7,
quadtree_min_pixel_size=5,
visibility_aware_culling=False):
image_name = colmap_images[image_key].name
image_path = os.path.join(image_dir, image_name)
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
image_grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
quadtree_root = quadtree_decomposition(image_grayscale, quadtree_std_threshold, quadtree_min_pixel_size)
leaf_nodes = []
gather_leaf_nodes(quadtree_root, leaf_nodes)
# Project 3D points onto the image plane.
points3d_pix, points3d_depth = project_3d_to_2d(points3d, intrinsics_camera[colmap_images[image_key].camera_id], np.concatenate((np.array(rotations_image[image_key]), np.array(translations_image[image_key]).reshape(3,1)), axis=1))
# Cull the points that are outside the camera frustum and have negative depth.
near_culled_indices = (points3d_pix[:, 0] < 0) | (points3d_pix[:, 0] >= image.shape[1]) | (points3d_pix[:, 1] < 0) | (points3d_pix[:, 1] >= image.shape[0]) | (points3d_depth.reshape(-1) < 0) | np.isnan(points3d_pix).any(axis=1)
points3d_pix_rgb_diff = np.zeros(points3d_pix.shape[0], dtype=np.bool_)
if visibility_aware_culling:
for i, point in enumerate(points3d_pix):
if near_culled_indices[i]:
continue
point_a = image[int(point[1]), int(point[0])]
point_b = points3d_rgb[i]
points3d_pix_rgb_diff[i] = pixelwise_rgb_diff(point_a, point_b)
near_culled_indices = near_culled_indices | points3d_pix_rgb_diff.reshape(-1)
# Check every node in the quadtree that contains projected 3D points.
leaf_nodes = check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices)
# Sample points from the leaf nodes that are not occupied by projected 3D points.
sampled_points = []
sampled_points_rgb = []
for node in leaf_nodes:
if node.unoccupied:
node.sampled_point_uv = np.array([node.x0, node.y0]) + np.random.sample(2) * np.array([node.width, node.height])
node.sampled_point_rgb = image[int(node.sampled_point_uv[1]), int(node.sampled_point_uv[0])]
sampled_points.append(node.sampled_point_uv)
sampled_points_rgb.append(node.sampled_point_rgb)
sampled_points = np.array(sampled_points)
sampled_points_rgb = np.array(sampled_points_rgb)
# Interpolate the depth of the sampled points from the nearest 3D points.
leaf_nodes, augmented_count = find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera[colmap_images[image_key].camera_id], rotations_image[image_key], translations_image[image_key])
return quadtree_root, augmented_count
def transform_sample_3d(image_key, root, colmap_images, colmap_cameras, intrinsics_camera, rotations_image, translations_image):
# Gather leaf nodes and transform the sampled 2D points to the 3D world coordinates.
leaf_nodes = []
gather_leaf_nodes(root, leaf_nodes)
sample_points_imageframe = []
sample_points_depth = []
sample_points_rgb = []
for node in leaf_nodes:
if node.unoccupied:
if node.depth_interpolated:
sample_points_imageframe.append(node.sampled_point_uv)
sample_points_depth.append(node.sampled_point_depth)
sample_points_rgb.append(node.sampled_point_rgb)
sample_points_imageframe = np.array(sample_points_imageframe)
sample_points_depth = np.array(sample_points_depth)
sample_points_rgb = np.array(sample_points_rgb)
sample_points_cameraframe = (np.linalg.inv(intrinsics_camera[colmap_images[image_key].camera_id]) @ np.concatenate((sample_points_imageframe, np.ones((sample_points_imageframe.shape[0], 1))), axis=1).T).T
sample_points_cameraframe = np.concatenate((sample_points_cameraframe[:,:2]*sample_points_depth.reshape(-1,1), sample_points_depth.reshape(-1, 1)), axis=1)
sample_points_worldframe = transform_camera_to_world(sample_points_cameraframe, rotations_image[image_key], translations_image[image_key])
return sample_points_worldframe[:,:-1], sample_points_rgb
def write_points3D_colmap_binary(points3D, xyz, rgb, file_path):
with open(file_path, "wb") as fid:
write_next_bytes(fid, len(points3D) + len(xyz), "Q")
for j, (_, pt) in enumerate(points3D.items()):
write_next_bytes(fid, pt.id, "Q")
write_next_bytes(fid, pt.xyz.tolist(), "ddd")
write_next_bytes(fid, pt.rgb.tolist(), "BBB")
write_next_bytes(fid, pt.error, "d")
track_length = pt.image_ids.shape[0]
write_next_bytes(fid, track_length, "Q")
for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
write_next_bytes(fid, [image_id, point2D_id], "ii")
id = points3D[max(points3D.keys())].id + 1
print("starts from id=", id)
for i in range(len(xyz)):
write_next_bytes(fid, id + i, "Q")
write_next_bytes(fid, xyz[i].tolist(), "ddd")
write_next_bytes(fid, rgb[i].tolist(), "BBB")
write_next_bytes(fid, 0.0, "d")
write_next_bytes(fid, 1, "Q")
write_next_bytes(fid, [0, 0], "ii")
return i + id
def compare_local_texture(patch1, patch2):
"""
Compare two patches using a Gaussian kernel.
If the patch size is not 3x3, apply zero padding.
"""
# Define a 3x3 Gaussian kernel (sigma=1.0)
gaussian_kernel = np.array([
[0.077847, 0.123317, 0.077847],
[0.123317, 0.195346, 0.123317],
[0.077847, 0.123317, 0.077847]
])
# Check the patch size and apply zero padding if it is not 3x3.
def pad_to_3x3(patch):
if patch.shape[:2] != (3,3):
padded = np.zeros((3,3,3) if len(patch.shape)==3 else (3,3))
h, w = patch.shape[:2]
y_start = (3-h)//2
x_start = (3-w)//2
padded[y_start:y_start+h, x_start:x_start+w] = patch
return padded
return patch
patch1 = pad_to_3x3(patch1).astype(np.float64)
patch2 = pad_to_3x3(patch2).astype(np.float64)
if len(patch1.shape) == 3: # RGB image
# Apply weights to each channel and calculate the difference.
weighted_diff = np.sum([
np.sum(gaussian_kernel * (patch1[:,:,c] - patch2[:,:,c])**2)
for c in range(3)
])
else: # Grayscale image
weighted_diff = np.sum(gaussian_kernel * (patch1 - patch2)**2)
return np.sqrt(weighted_diff)/255.
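
Illustrative only, not part of the commit: the augmentation is built around the quadtree decomposition above, so this sketch decomposes a grayscale image with the same defaults that image_quadtree_augmentation uses and inspects the resulting leaves. The image path is a placeholder.

import cv2
import numpy as np
from utils.aug_utils import quadtree_decomposition, gather_leaf_nodes, find_leaf_node

gray = cv2.imread("data/scene/images/00001.jpg", cv2.IMREAD_GRAYSCALE)   # placeholder path
# Split a cell while its intensity std-dev exceeds 7; stop once a child edge would shrink to 5 px or below.
root = quadtree_decomposition(gray, 7, 5)

leaves = []
gather_leaf_nodes(root, leaves)
print(len(leaves), "leaf cells, mean area",
      np.mean([n.width * n.height for n in leaves]), "px")

# Leaf lookup for an arbitrary pixel, as used by the cross-view consistency check.
node = find_leaf_node(root, 123.0, 56.0)
if node is not None:
    print("containing cell:", node.x0, node.y0, node.width, node.height)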

utils/bundle_utils.py (new file, 156 lines added)

@@ -0,0 +1,156 @@
import random
import os
from colmap.scripts.python.read_write_model import *
import numpy as np
from collections import defaultdict
from utils.colmap_utils import compute_extrinsics, compute_intrinsics, get_colmap_data
import cv2
from sklearn.decomposition import PCA
def build_covisibility_matrix(images, points3D):
n_images = len(images)
covisibility_matrix = np.zeros((n_images, n_images))
id_to_idx = {img_id: idx for idx, img_id in enumerate(images.keys())}
idx_to_id = {idx: img_id for img_id, idx in id_to_idx.items()}
for point3D in points3D.values():
image_ids = point3D.image_ids
for i in range(len(image_ids)):
for j in range(i+1, len(image_ids)):
id1, id2 = image_ids[i], image_ids[j]
idx1, idx2 = id_to_idx[id1], id_to_idx[id2]
covisibility_matrix[idx1, idx2] += 1
covisibility_matrix[idx2, idx1] += 1
return covisibility_matrix, id_to_idx, idx_to_id
def create_covisibility_graph(covisibility_matrix, idx_to_id):
graph = defaultdict(dict)
for i in range(len(covisibility_matrix)):
for j in range(len(covisibility_matrix)):
if i != j and covisibility_matrix[i,j] > 0:
id1 = idx_to_id[i]
id2 = idx_to_id[j]
graph[id1][id2] = covisibility_matrix[i,j]
return graph
def create_sequence_from_covisibility_graph(covisibility_graph, min_covisibility=20):
if not covisibility_graph:
print("No covisibility graph found")
return []
start_node = max(covisibility_graph.keys(),
key=lambda k: sum(1 for v in covisibility_graph[k].values() if v >= min_covisibility))
visited = set([start_node])
sequence = [start_node]
current = start_node
while len(sequence) < len(covisibility_graph):
next_node = None
max_covisibility = -1
for neighbor, covisibility in covisibility_graph[current].items():
if neighbor not in visited and covisibility > max_covisibility and covisibility >= min_covisibility:
next_node = neighbor
max_covisibility = covisibility
if next_node is None:
for node in covisibility_graph:
if node not in visited:
for seq_node in sequence:
covisibility = covisibility_graph[node].get(seq_node, 0)
if covisibility > max_covisibility and covisibility >= min_covisibility:
max_covisibility = covisibility
next_node = node
if next_node is None:
next_node = min(set(covisibility_graph.keys()) - visited)
current = next_node
visited.add(current)
sequence.append(current)
return sequence
def cluster_cameras(model_path, camera_order):
colmap_path = os.path.join(model_path, 'sparse/0')
colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
if camera_order == 'covisibility':
covisibility_matrix, id_to_idx, idx_to_id = build_covisibility_matrix(colmap_images, colmap_points3D)
covisibility_graph = create_covisibility_graph(covisibility_matrix, idx_to_id)
covisibility_sequence = create_sequence_from_covisibility_graph(covisibility_graph)
image_id_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
image_id_name_sorted = sorted(image_id_name, key=lambda x: x[1])
test_id = []
train_id = []
count = 0
train_only_idx = []
for i, (id, name) in enumerate(image_id_name_sorted):
if i % 8 == 0:
test_id.append(id)
else:
train_id.append(id)
train_only_idx.append(count)
count+=1
rotations_image, translations_image = compute_extrinsics(colmap_images)
train_only_visibility_idx = []
for id in covisibility_sequence:
if id in train_id:
train_only_visibility_idx.append(id)
train_only_visibility_idx = np.array(train_only_visibility_idx)
sorted_keys = train_only_visibility_idx  # assign to sorted_keys rather than sorted_indices
elif camera_order == 'PCA':
image_idx_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
image_idx_name_sorted = sorted(image_idx_name, key=lambda x: x[1])
test_idx = []
train_idx = []
for i, (idx, name) in enumerate(image_idx_name_sorted):
if i % 8 == 0:
test_idx.append(idx)
else:
train_idx.append(idx)
rotations_image, translations_image = compute_extrinsics(colmap_images)
cam_center = []
key = []
for idx in train_idx:
cam_center.append((-rotations_image[idx].T @ translations_image[idx].reshape(3,1)))
key.append(idx)
cam_center = np.array(cam_center)[:,:,0]
pca = PCA(n_components=2)
cam_center_2d = pca.fit_transform(cam_center)
center_cam_center = np.mean(cam_center_2d, axis=0)
centered_cam_center = cam_center_2d - center_cam_center
angles = np.arctan2(centered_cam_center[:, 1], centered_cam_center[:, 0])
sorted_indices = np.argsort(angles)
sorted_cam_centers = cam_center_2d[sorted_indices]
sorted_keys = np.array(key)[sorted_indices]
return sorted_keys
def bundle_start_index_generator(sorted_keys, initial_interval):
start_indices = []
cluster_sizes = []
for i in range(200):
start_indices.append((initial_interval*i)%len(sorted_keys))
cluster_sizes.append(initial_interval)
return start_indices, cluster_sizes
def adaptive_cluster(start_idx, sorted_keys, cluster_size = 40, offset = 0):
idx = start_idx
indices = [sorted_keys[index % len(sorted_keys)] for index in range(idx, idx + cluster_size)]
random_index = random.choice(indices)
return random_index
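
Not part of the commit: a short sketch of how train.py consumes these helpers during bundle training, with a placeholder dataset path.

from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster

source_path = "data/scene"   # placeholder; must contain sparse/0 with the COLMAP binaries
# Training-view image ids ordered by covisibility (or "PCA" for an angular ordering of camera centers).
sorted_keys = cluster_cameras(source_path, "covisibility")
start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20)

# During densification, draw a random image id from the current bundle window.
view_id = adaptive_cluster(start_indices[0], sorted_keys, cluster_sizes[0])
print("sampled view id:", view_id)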

utils/colmap_utils.py (new file, 78 lines added)

@@ -0,0 +1,78 @@
from colmap.scripts.python.read_write_model import *
def get_colmap_data(dataset_path):
"""
Load COLMAP data from the given dataset path.
Args:
dataset_path (str): Path to the dataset directory containing images, sparse/0, and dense/0 folders.
Returns:
images: colmap image infos
points3D: colmap 3D points
cameras: colmap camera infos
"""
images = read_images_binary(os.path.join(dataset_path, 'images.bin'))
points3D = read_points3D_binary(os.path.join(dataset_path, 'points3D.bin'))
cameras = read_cameras_binary(os.path.join(dataset_path, 'cameras.bin'))
return images, points3D, cameras
def quaternion_rotation_matrix(qw, qx, qy, qz):
"""
Convert a quaternion to a rotation matrix.
colmap uses wxyz order for quaternions.
"""
# First row of the rotation matrix
r00 = 2 * (qw * qw + qx * qx) - 1
r01 = 2 * (qx * qy - qw * qz)
r02 = 2 * (qw * qy + qx * qz)
# Second row of the rotation matrix
r10 = 2 * (qx * qy + qw * qz)
r11 = 2 * (qw * qw + qy * qy) - 1
r12 = 2 * (qy * qz - qw * qx)
# Third row of the rotation matrix
r20 = 2 * (qz * qx - qw * qy)
r21 = 2 * (qz * qy + qw * qx)
r22 = 2 * (qw * qw + qz * qz) - 1
# 3x3 rotation matrix
rot_matrix = np.array([[r00, r01, r02],
[r10, r11, r12],
[r20, r21, r22]])
return rot_matrix
def compute_intrinsic_matrix(fx, fy, cx, cy, image_width, image_height):
sx = cx/(image_width/2)
sy = cy/(image_height/2)
intrinsic_matrix = np.array([[fx/sx, 0, cx/sx], [0, fy/sy, cy/sy], [0, 0, 1]])
return intrinsic_matrix
def compute_intrinsics(colmap_cameras, image_width, image_height):
intrinsics = {}
for cam_key in colmap_cameras.keys():
intrinsic_parameters = colmap_cameras[cam_key].params
assert colmap_cameras[cam_key].model == 'PINHOLE'
intrinsic = compute_intrinsic_matrix(intrinsic_parameters[0],
intrinsic_parameters[1],
intrinsic_parameters[2],
intrinsic_parameters[3],
image_width,
image_height)
intrinsics[cam_key] = intrinsic
return intrinsics
def compute_extrinsics(colmap_images):
rotations = {}
translations = {}
for image_key in colmap_images.keys():
rotation = quaternion_rotation_matrix(colmap_images[image_key].qvec[0],
colmap_images[image_key].qvec[1],
colmap_images[image_key].qvec[2],
colmap_images[image_key].qvec[3])
translation = colmap_images[image_key].tvec
rotations[image_key] = rotation
translations[image_key] = translation
return rotations, translations
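
A small sanity check, not part of the commit: the identity quaternion maps to the identity rotation, and the per-image outputs assemble into the usual 3x4 world-to-camera [R|t] matrix. The dataset path is a placeholder.

import numpy as np
from utils.colmap_utils import quaternion_rotation_matrix, get_colmap_data, compute_extrinsics

assert np.allclose(quaternion_rotation_matrix(1.0, 0.0, 0.0, 0.0), np.eye(3))

images, points3D, cameras = get_colmap_data("data/scene/sparse/0")   # placeholder path
rotations, translations = compute_extrinsics(images)
image_id = next(iter(images))
extrinsic = np.concatenate((rotations[image_id], np.array(translations[image_id]).reshape(3, 1)), axis=1)
print("[R|t] for image", image_id, "\n", extrinsic)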

utils/loss_utils.py

@@ -10,6 +10,7 @@
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
@@ -17,6 +18,7 @@ try:
from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward
except:
pass
import math
C1 = 0.01 ** 2
C2 = 0.03 ** 2
@@ -89,3 +91,133 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
def fast_ssim(img1, img2):
ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2)
return ssim_map.mean()
def build_gaussian_kernel(kernel_size=5, sigma=1.0, channels=3):
# Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
x_coord = torch.arange(kernel_size)
x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
y_grid = x_grid.t()
xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
mean = (kernel_size - 1)/2.
variance = sigma**2.
# Calculate the 2-dimensional gaussian kernel
gaussian_kernel = (1./(2.*math.pi*variance)) * \
torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / \
(2*variance))
# Make sure sum of values in gaussian kernel equals 1.
gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
# Reshape to 2D depthwise convolutional weight
gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
return gaussian_kernel
def gaussian_blur(x, kernel):
"""Apply gaussian blur to input tensor"""
padding = (kernel.shape[-1] - 1) // 2
return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
def create_laplacian_pyramid(image, max_levels=4):
"""Create Laplacian pyramid for an image"""
pyramids = []
current = image
kernel = build_gaussian_kernel().to(image.device)
for _ in range(max_levels):
# Blur and downsample
blurred = gaussian_blur(current, kernel)
down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
# Upsample and subtract
up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
laplace = current - up
pyramids.append(laplace)
current = down
pyramids.append(current) # Add the final residual
return pyramids
def laplacian_pyramid_loss(pred, target, max_levels=4, weights=None):
"""Compute Laplacian Pyramid Loss between predicted and target images"""
if weights is None:
weights = [1.0] * (max_levels + 1)
pred_pyramids = create_laplacian_pyramid(pred, max_levels)
target_pyramids = create_laplacian_pyramid(target, max_levels)
loss = 0
for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
loss += weight * torch.abs(pred_lap - target_lap).mean()
return loss
class LaplacianPyramidLoss(torch.nn.Module):
def __init__(self, max_levels=4, channels=3, kernel_size=5, sigma=1.0):
super().__init__()
self.max_levels = max_levels
self.kernel = build_gaussian_kernel(kernel_size, sigma, channels)
def forward(self, pred, target, weights=None):
if weights is None:
weights = [1.0] * (self.max_levels + 1)
# Move kernel to the same device as input
kernel = self.kernel.to(pred.device)
pred_pyramids = self.create_laplacian_pyramid(pred, kernel, self.max_levels)
target_pyramids = self.create_laplacian_pyramid(target, kernel, self.max_levels)
loss = 0
for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
loss += weight * torch.abs(pred_lap - target_lap).mean()
return loss
@staticmethod
def create_laplacian_pyramid(image, kernel, max_levels=4):
pyramids = []
current = image
for _ in range(max_levels):
# Apply Gaussian blur before downsampling to prevent aliasing
blurred = gaussian_blur(current, kernel)
down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
# Upsample and subtract from the original image
up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
laplace = current - gaussian_blur(up, kernel) # Apply blur to upsampled image
pyramids.append(laplace)
current = down
pyramids.append(current) # Add the final residual
return pyramids
class InvDepthSmoothnessLoss(nn.Module):
def __init__(self, alpha=10):
super(InvDepthSmoothnessLoss, self).__init__()
self.alpha = alpha  # hyperparameter controlling the strength of the edge-aware weighting
def forward(self, inv_depth, image):
# gradients of the inverse-depth map
dx_inv_depth = torch.abs(inv_depth[:, :, :-1] - inv_depth[:, :, 1:])
dy_inv_depth = torch.abs(inv_depth[:, :-1, :] - inv_depth[:, 1:, :])
# gradients of the guidance image
dx_image = torch.mean(torch.abs(image[:, :, :-1] - image[:, :, 1:]), 1, keepdim=True)
dy_image = torch.mean(torch.abs(image[:, :-1, :] - image[:, 1:, :]), 1, keepdim=True)
# edge-aware weights derived from the image gradients
weight_x = torch.exp(-self.alpha * dx_image)
weight_y = torch.exp(-self.alpha * dy_image)
# weighted smoothness loss
smoothness_x = dx_inv_depth * weight_x
smoothness_y = dy_inv_depth * weight_y
return torch.mean(smoothness_x) + torch.mean(smoothness_y)
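
Not part of the diff: a minimal smoke test of the two new regularizers on random tensors, matching the shapes train.py passes in (a batched image pair for the Laplacian loss, an unbatched inverse-depth map plus guidance image for the smoothness loss).

import torch
from utils.loss_utils import laplacian_pyramid_loss, LaplacianPyramidLoss, InvDepthSmoothnessLoss

pred = torch.rand(1, 3, 128, 128)      # rendered image, (N, C, H, W)
target = torch.rand(1, 3, 128, 128)    # ground-truth image
inv_depth = torch.rand(1, 128, 128)    # rendered inverse depth, (1, H, W)
image = torch.rand(3, 128, 128)        # guidance image for the edge-aware weights

print("laplacian (functional):", laplacian_pyramid_loss(pred, target).item())
print("laplacian (module):    ", LaplacianPyramidLoss()(pred, target).item())
print("inv-depth smoothness:  ", InvDepthSmoothnessLoss(alpha=10)(inv_depth, image).item())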