mirror of https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-06-26 18:18:11 +00:00

add aug_gs code

This commit is contained in:
parent 54c035f783
commit 48ceb9419b
208  augment.py  Normal file
@@ -0,0 +1,208 @@
import argparse
import numpy as np
from tqdm import tqdm
import cv2
import os

from utils.colmap_utils import *
from utils.bundle_utils import cluster_cameras
from utils.aug_utils import *


def augment(colmap_path, image_path, augment_path, camera_order, visibility_aware_culling, compare_center_patch):
    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
    sorted_keys = cluster_cameras(colmap_path, camera_order)

    points3d = []
    points3d_rgb = []
    for key in sorted(colmap_points3D.keys()):
        points3d.append(colmap_points3D[key].xyz)
        points3d_rgb.append(colmap_points3D[key].rgb)
    points3d = np.array(points3d)
    points3d_rgb = np.array(points3d_rgb)

    image_sample = cv2.imread(os.path.join(image_path, colmap_images[sorted_keys[0]].name))
    intrinsics_camera = compute_intrinsics(colmap_cameras, image_sample.shape[1], image_sample.shape[0])
    rotations_image, translations_image = compute_extrinsics(colmap_images)

    count = 0
    roots = {}
    pbar = tqdm(range(len(sorted_keys)))
    for view_idx in pbar:
        view = sorted_keys[view_idx]
        view_root, augmented_count = image_quadtree_augmentation(
            view,
            image_path,
            colmap_cameras,
            colmap_images,
            colmap_points3D,
            points3d,
            points3d_rgb,
            intrinsics_camera,
            rotations_image,
            translations_image,
            visibility_aware_culling=visibility_aware_culling,
        )
        count += augmented_count
        pbar.set_description(f"{count} points augmented")
        roots[view] = view_root

    for view1_idx in tqdm(range(len(sorted_keys))):
        for view2_idx in [view1_idx + 6,
                          view1_idx + 5,
                          view1_idx + 4,
                          view1_idx + 3,
                          view1_idx + 2,
                          view1_idx + 1,
                          view1_idx - 1,
                          view1_idx - 2,
                          view1_idx - 3,
                          view1_idx - 4,
                          view1_idx - 5,
                          view1_idx - 6]:
            if view2_idx > len(sorted_keys) - 1:
                view2_idx = view2_idx - len(sorted_keys)
            view1 = sorted_keys[view1_idx]
            view2 = sorted_keys[view2_idx]
            view1_root = roots[view1]
            view2_root = roots[view2]

            image_view2 = cv2.imread(os.path.join(image_path, colmap_images[view2].name))

            view1_sample_points_world, view1_sample_points_rgb = transform_sample_3d(view1,
                                                                                     view1_root,
                                                                                     colmap_images,
                                                                                     colmap_cameras,
                                                                                     intrinsics_camera,
                                                                                     rotations_image,
                                                                                     translations_image)
            view1_sample_points_view2, view1_sample_points_view2_depth = project_3d_to_2d(
                view1_sample_points_world,
                intrinsics_camera[colmap_images[view2].camera_id],
                np.concatenate((np.array(rotations_image[view2]),
                                np.array(translations_image[view2]).reshape(3, 1)), axis=1))
            points3d_view2_pixcoord, points3d_view2_depth = project_3d_to_2d(
                points3d,
                intrinsics_camera[colmap_images[view2].camera_id],
                np.concatenate((np.array(rotations_image[view2]),
                                np.array(translations_image[view2]).reshape(3, 1)), axis=1))

            matching_log = []
            for i in range(view1_sample_points_world.shape[0]):
                x, y = view1_sample_points_view2[i]
                corresponding_node_type = None
                error = None
                if (view1_sample_points_view2_depth[i] < 0) | \
                   (view1_sample_points_view2[i, 0] < 0) | \
                   (view1_sample_points_view2[i, 0] >= image_view2.shape[1]) | \
                   (view1_sample_points_view2[i, 1] < 0) | \
                   (view1_sample_points_view2[i, 1] >= image_view2.shape[0]) | \
                   np.isnan(view1_sample_points_view2[i]).any(axis=0):
                    corresponding_node_type = "culled"
                    matching_log.append([view2, corresponding_node_type, error])
                    continue
                else:
                    view2_corresponding_node = find_leaf_node(view2_root, x, y)
                    if view2_corresponding_node is None:
                        corresponding_node_type = "missing"
                    else:
                        if view2_corresponding_node.unoccupied:
                            if view2_corresponding_node.depth_interpolated:
                                error = np.linalg.norm(view1_sample_points_view2_depth[i] - view2_corresponding_node.sampled_point_depth)
                                if error < 0.2 * view2_corresponding_node.sampled_point_depth:
                                    if compare_center_patch:
                                        try:
                                            view1_sample_point_patch = image_view2[
                                                int(view1_sample_points_view2[i, 1])-1:int(view1_sample_points_view2[i, 1])+2,
                                                int(view1_sample_points_view2[i, 0])-1:int(view1_sample_points_view2[i, 0])+2]
                                            view2_corresponding_node_patch = image_view2[
                                                int(view2_corresponding_node.sampled_point_uv[1])-1:int(view2_corresponding_node.sampled_point_uv[1])+2,
                                                int(view2_corresponding_node.sampled_point_uv[0])-1:int(view2_corresponding_node.sampled_point_uv[0])+2]
                                            if compare_local_texture(view1_sample_point_patch, view2_corresponding_node_patch) > 0.5:
                                                corresponding_node_type = "sampledrejected"
                                            else:
                                                corresponding_node_type = "sampled"
                                        except IndexError:
                                            corresponding_node_type = "sampledrejected"
                                    else:
                                        corresponding_node_type = "sampled"
                                else:
                                    corresponding_node_type = "sampledrejected"
                            else:
                                corresponding_node_type = "depthrejected"
                        else:
                            corresponding_3d_depth = np.array(view2_corresponding_node.points3d_depths)
                            error = np.abs(view1_sample_points_view2_depth[i] - corresponding_3d_depth)

                            if np.min(error) < 0.2 * corresponding_3d_depth[np.argmin(error)]:
                                if compare_center_patch:
                                    try:
                                        point_3d_coord = points3d_view2_pixcoord[view2_corresponding_node.points3d_indices[np.argmin(error)]]
                                        point_3d_patch = image_view2[
                                            int(point_3d_coord[1])-1:int(point_3d_coord[1])+2,
                                            int(point_3d_coord[0])-1:int(point_3d_coord[0])+2]
                                        view1_sample_point_patch = image_view2[
                                            int(view1_sample_points_view2[i, 1])-1:int(view1_sample_points_view2[i, 1])+2,
                                            int(view1_sample_points_view2[i, 0])-1:int(view1_sample_points_view2[i, 0])+2]
                                        if compare_local_texture(view1_sample_point_patch, point_3d_patch) > 0.5:
                                            corresponding_node_type = "rejectedoccupied3d"
                                        else:
                                            corresponding_node_type = "occupied3d"
                                    except IndexError:
                                        corresponding_node_type = "rejectedoccupied3d"
                                else:
                                    corresponding_node_type = "occupied3d"
                            else:
                                corresponding_node_type = "rejectedoccupied3d"
                    matching_log.append([view2, corresponding_node_type, error])

    sampled_points_total = []
    sampled_points_rgb_total = []
    sampled_points_uv_total = []
    sampled_points_neighbors_uv_total = []
    for view in sorted_keys:
        view_root = roots[view]
        leaf_nodes = []
        gather_leaf_nodes(view_root, leaf_nodes)
        for node in leaf_nodes:
            if node.unoccupied:
                if node.depth_interpolated:
                    if node.inference_count > 0:
                        if node.inference_count - node.rejection_count >= 1:
                            sampled_points_total.append([node.sampled_point_world])
                            sampled_points_rgb_total.append([node.sampled_point_rgb])
                            sampled_points_uv_total.append([node.sampled_point_uv])
                            sampled_points_neighbors_uv_total.append([node.sampled_point_neighbours_uv])
    print("total_sampled_points: ", len(sampled_points_total))
    xyz = np.concatenate(sampled_points_total, axis=0)
    rgb = np.concatenate(sampled_points_rgb_total, axis=0)

    last_index = write_points3D_colmap_binary(colmap_points3D, xyz, rgb, augment_path)
    print("last_index: ", last_index)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--colmap_path", type=str, required=True)
    parser.add_argument("--image_path", type=str, required=True)
    parser.add_argument("--augment_path", type=str, required=True)
    parser.add_argument("--camera_order", type=str, default="covisibility")
    parser.add_argument("--visibility_aware_culling", action="store_true", default=False)
    parser.add_argument("--compare_center_patch", action="store_true", default=False)
    args = parser.parse_args()
    print("args.colmap_path", args.colmap_path)
    print("args.image_path", args.image_path)
    print("args.augment_path", args.augment_path)
    print("args.camera_order", args.camera_order)
    print("args.visibility_aware_culling", args.visibility_aware_culling)
    print("args.compare_center_patch", args.compare_center_patch)
    augment(args.colmap_path, args.image_path, args.augment_path, args.camera_order, args.visibility_aware_culling, args.compare_center_patch)
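
For orientation, a typical invocation of the script above would look something like the line below; the paths are placeholders rather than anything documented by this commit. Note that augment() hands colmap_path both to get_colmap_data (which expects the folder holding images.bin, points3D.bin and cameras.bin) and to cluster_cameras (which appends sparse/0 itself), so the expected layout should be double-checked against your dataset.

    python augment.py --colmap_path <scene>/sparse/0 --image_path <scene>/images \
        --augment_path <scene>/sparse/0/points3D_aug.bin --camera_order covisibility \
        --visibility_aware_culling --compare_center_patch
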
103  train.py
@@ -12,7 +12,9 @@
 import os
 import torch
 from random import randint
-from utils.loss_utils import l1_loss, ssim
+from utils.loss_utils import l1_loss, ssim, InvDepthSmoothnessLoss, laplacian_pyramid_loss
+from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster
+from utils.aug_utils import *
 from gaussian_renderer import render, network_gui
 import sys
 from scene import Scene, GaussianModel
@@ -40,7 +42,20 @@ try:
 except:
     SPARSE_ADAM_AVAILABLE = False

-def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
+def training(dataset,
+             opt,
+             pipe,
+             testing_iterations,
+             saving_iterations,
+             checkpoint_iterations,
+             checkpoint,
+             debug_from,
+             camera_order,
+             bundle_training,
+             enable_ds_lap,
+             lambda_ds,
+             lambda_lap,
+             ):

     if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam":
         sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].")
@@ -68,6 +83,11 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
     ema_loss_for_log = 0.0
     ema_Ll1depth_for_log = 0.0

+    if bundle_training:
+        sorted_keys = cluster_cameras(dataset.source_path, camera_order)
+        start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20)
+        n_interval = 0
+
     progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
     first_iter += 1
     for iteration in range(first_iter, opt.iterations + 1):
@@ -94,7 +114,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         if iteration % 1000 == 0:
             gaussians.oneupSHdegree()

-        # Pick a random Camera
+        if bundle_training and iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and cluster_sizes[n_interval] < 160 and iteration % 100 > 80:
+            densification_viewpoint_stack = scene.getTrainCameras().copy()
+            selected_idx = adaptive_cluster(start_indices[n_interval], sorted_keys, cluster_sizes[n_interval])
+            if selected_idx >= len(densification_viewpoint_stack):
+                selected_idx = selected_idx % len(densification_viewpoint_stack)
+            viewpoint_cam = densification_viewpoint_stack[selected_idx]
+
+        else:
             if not viewpoint_stack:
                 viewpoint_stack = scene.getTrainCameras().copy()
                 viewpoint_indices = list(range(len(viewpoint_stack)))
@@ -123,7 +150,14 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         else:
             ssim_value = ssim(image, gt_image)

-        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
+        if enable_ds_lap:
+            ds_loss = InvDepthSmoothnessLoss()(render_pkg["depth"], image)
+            lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
+        else:
+            ds_loss = 0.0
+            lap_loss = 0.0
+
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + lambda_ds * ds_loss + lambda_lap * lap_loss

         # Depth regularization
         Ll1depth_pure = 0.0
@@ -155,7 +189,24 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
                 progress_bar.close()

             # Log and save
-            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)
+            training_report(tb_writer,
+                            iteration,
+                            Ll1,
+                            loss,
+                            l1_loss,
+                            iter_start.elapsed_time(iter_end),
+                            testing_iterations,
+                            scene,
+                            render,
+                            (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp),
+                            dataset.train_test_exp,
+                            1.0 - ssim_value,
+                            ds_loss,
+                            lap_loss,
+                            lambda_ds,
+                            lambda_lap
+                            )
+
             if (iteration in saving_iterations):
                 print("\n[ITER {}] Saving Gaussians".format(iteration))
                 scene.save(iteration)
@@ -169,6 +220,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                     size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                     gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
+                    n_interval += 1

                 if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                     gaussians.reset_opacity()
@@ -211,10 +263,30 @@ def prepare_output_and_logger(args):
        print("Tensorboard not available: not logging progress")
    return tb_writer

-def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
+def training_report(tb_writer,
+                    iteration,
+                    Ll1,
+                    loss,
+                    l1_loss,
+                    elapsed,
+                    testing_iterations,
+                    scene : Scene,
+                    renderFunc,
+                    renderArgs,
+                    train_test_exp,
+                    ssim_loss,
+                    ds_loss,
+                    lap_loss,
+                    lambda_ds,
+                    lambda_lap):
     if tb_writer:
         tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
         tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ssim_loss', ssim_loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ds_loss', ds_loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/lap_loss', lap_loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_ds', lambda_ds, iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_lap', lambda_lap, iteration)
         tb_writer.add_scalar('iter_time', elapsed, iteration)

     # Report test and samples of training set
@@ -267,6 +339,11 @@ if __name__ == "__main__":
     parser.add_argument('--disable_viewer', action='store_true', default=False)
     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
     parser.add_argument("--start_checkpoint", type=str, default = None)
+    parser.add_argument("--bundle_training", action='store_true', default=False)
+    parser.add_argument("--camera_order", type=str, default='covisibility')
+    parser.add_argument("--enable_ds_lap", action='store_true', default=False)
+    parser.add_argument("--lambda_ds", type=float, default=0.0)
+    parser.add_argument("--lambda_lap", type=float, default=0.0)
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)

@@ -279,7 +356,19 @@ if __name__ == "__main__":
     if not args.disable_viewer:
         network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
+    training(lp.extract(args),
+             op.extract(args),
+             pp.extract(args),
+             args.test_iterations,
+             args.save_iterations,
+             args.checkpoint_iterations,
+             args.start_checkpoint,
+             args.debug_from,
+             args.camera_order,
+             args.bundle_training,
+             args.enable_ds_lap,
+             args.lambda_ds,
+             args.lambda_lap)

     # All done
     print("\nTraining complete.")
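
Combined with the unchanged 3DGS arguments (the dataset is still selected through the original ModelParams parser, e.g. -s/--source_path), the new flags would be passed roughly as follows; the lambda values here are placeholders for illustration, not tuned settings from this commit:

    python train.py -s <dataset_path> --bundle_training --camera_order covisibility \
        --enable_ds_lap --lambda_ds 0.1 --lambda_lap 0.1
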
387  utils/aug_utils.py  Normal file
@@ -0,0 +1,387 @@
import os
import numpy as np
from sklearn.neighbors import BallTree
from colmap.scripts.python.read_write_model import *
import cv2
from tqdm import tqdm
from sklearn.decomposition import PCA
import rtree
from shapely.geometry import Point, box
from collections import defaultdict
from colmap_utils import compute_extrinsics, get_colmap_data
from matplotlib import pyplot as plt


class Node:
    def __init__(self, x0, y0, width, height):
        self.x0 = x0
        self.y0 = y0
        self.width = width
        self.height = height
        self.children = []
        self.unoccupied = True
        self.sampled_point_uv = None
        self.sampled_point_rgb = None
        self.sampled_point_depth = None
        self.depth_interpolated = None
        self.sampled_point_neighbours_indices = None
        self.inference_count = 0
        self.rejection_count = 0
        self.sampled_point_world = None
        self.matching_log = {}
        self.points3d_indices = []
        self.points3d_depths = []
        self.points3d_rgb = []
        self.sampled_point_neighbours_uv = None

    def get_error(self, img):
        # Use the standard deviation of the region as the error metric.
        region = img[self.y0:self.y0+self.height, self.x0:self.x0+self.width]
        return np.std(region)


def recursive_subdivide(node, threshold, min_pixel_size, img):
    if node.get_error(img) <= threshold:
        return

    w_1 = node.width // 2
    w_2 = node.width - w_1
    h_1 = node.height // 2
    h_2 = node.height - h_1

    if w_1 <= min_pixel_size or h_1 <= min_pixel_size:
        return

    # Create four children nodes
    x1 = Node(node.x0, node.y0, w_1, h_1)  # top left
    recursive_subdivide(x1, threshold, min_pixel_size, img)

    x2 = Node(node.x0, node.y0 + h_1, w_1, h_2)  # bottom left
    recursive_subdivide(x2, threshold, min_pixel_size, img)

    x3 = Node(node.x0 + w_1, node.y0, w_2, h_1)  # top right
    recursive_subdivide(x3, threshold, min_pixel_size, img)

    x4 = Node(node.x0 + w_1, node.y0 + h_1, w_2, h_2)  # bottom right
    recursive_subdivide(x4, threshold, min_pixel_size, img)

    node.children = [x1, x2, x3, x4]


def quadtree_decomposition(img, threshold, min_pixel_size):
    root = Node(0, 0, img.shape[1], img.shape[0])
    recursive_subdivide(root, threshold, min_pixel_size, img)
    return root


def gather_leaf_nodes(node, leaf_nodes):
    if not node.children:
        leaf_nodes.append(node)
    else:
        for child in node.children:
            gather_leaf_nodes(child, leaf_nodes)


def find_leaf_node(root, pixel_x, pixel_y):
    if not (root.x0 <= pixel_x < root.x0 + root.width and
            root.y0 <= pixel_y < root.y0 + root.height):
        return None  # The pixel lies outside the bounds of the root node.

    current = root
    while current.children:
        for child in current.children:
            if (child.x0 <= pixel_x < child.x0 + child.width and
                    child.y0 <= pixel_y < child.y0 + child.height):
                current = child
                break
        else:
            # No suitable child node was found.
            return current

    return current


def draw_quadtree(node, ax):
    if not node.children:
        rect = plt.Rectangle((node.x0, node.y0), node.width, node.height, fill=False, edgecolor='red')
        ax.add_patch(rect)
    else:
        for child in node.children:
            draw_quadtree(child, ax)


def pixel_to_3d(point, depth, intrinsics):
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]

    x, y = point
    z = depth
    x_3d = (x - cx) * z / fx
    y_3d = (y - cy) * z / fy
    return np.array([x_3d, y_3d, z])


def project_3d_to_2d(points3d, intrinsics, extrinsics):
    """
    Project 3D points to 2D using camera intrinsics and extrinsics.

    Args:
    - points3d: 3D points as a numpy array of shape (N, 3).
    - intrinsics: Camera intrinsics matrix (3x3).
    - extrinsics: Camera extrinsics matrix (3x4), i.e. [R | t].

    Returns:
    - 2D pixel coordinates of shape (N, 2) and the corresponding camera-frame depths of shape (N, 1).
    """
    point_3d_homogeneous = np.hstack((points3d, np.ones((points3d.shape[0], 1))))  # Convert to homogeneous coordinates
    point_camera = extrinsics @ point_3d_homogeneous.T

    point_image_homogeneous = intrinsics @ point_camera
    point_2d = point_image_homogeneous[:2] / point_image_homogeneous[2]

    return point_2d.T, point_image_homogeneous[2:].T


def check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices):
    rect_index = rtree.index.Index()
    rectangles = [((node.x0, node.y0, node.x0 + node.width, node.y0 + node.height), i) for i, node in enumerate(leaf_nodes)]
    for rect, i in rectangles:
        rect_index.insert(i, rect)
    rectangle_indices = []
    for i, (x, y) in enumerate(points3d_pix):
        if near_culled_indices[i]:
            continue
        point = Point(x, y)
        matches = list(rect_index.intersection((x, y, x, y)))
        for match in matches:
            if box(*rectangles[match][0]).contains(point):
                rectangle_indices.append([match, i])

    for index, point3d_idx in rectangle_indices:
        leaf_nodes[index].unoccupied = False
        leaf_nodes[index].points3d_indices.append(point3d_idx)
        leaf_nodes[index].points3d_depths.append(points3d_depth[point3d_idx])
        leaf_nodes[index].points3d_rgb.append(points3d_rgb[point3d_idx])
    return leaf_nodes


def compute_normal_vector(points3d_cameraframe, points3d_depth, indices):
    points3d_cameraframe = np.concatenate((points3d_cameraframe[:, :2]*points3d_depth.reshape(-1, 1), points3d_depth.reshape(-1, 1)), axis=1)
    p1 = points3d_cameraframe[indices[:, 0]]
    p2 = points3d_cameraframe[indices[:, 1]]
    p3 = points3d_cameraframe[indices[:, 2]]
    v1 = p2 - p1
    v2 = p3 - p1

    normal = np.cross(v1, v2)
    a, b, c = normal[:, 0], normal[:, 1], normal[:, 2]
    d = -a * p1[:, 0] - b * p1[:, 1] - c * p1[:, 2]

    return a, b, c, d


def compute_depth(sample_points_cameraframe, a, b, c, d, cosine_threshold=0.01):
    direction_vectors = sample_points_cameraframe
    t = -d.reshape(-1, 1) / np.sum(np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1) * direction_vectors, axis=1).reshape(-1, 1)
    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
    cosine_similarity = np.abs(np.sum(normal_vectors * direction_vectors, axis=1)).reshape(-1, 1) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1) / np.linalg.norm(direction_vectors, axis=1).reshape(-1, 1)
    rejected_indices = cosine_similarity < cosine_threshold
    depth = t.reshape(-1, 1)*direction_vectors[:, 2:]
    return depth, rejected_indices.reshape(-1)


def find_perpendicular_triangle_indices(a, b, c, cosine_threshold=0.01):
    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
    camera_normal = np.array([0, 0, 1]).reshape(1, 3)
    cosine_similarity = np.dot(normal_vectors, camera_normal.T) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1)
    cosine_culled_indices = (cosine_similarity > -cosine_threshold) & (cosine_similarity < cosine_threshold)
    return cosine_culled_indices


def find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera, rotations_image, translations_image, depth_cutoff=2.):
    sampled_points = []
    for node in leaf_nodes:
        if node.unoccupied:
            sampled_points.append(node.sampled_point_uv)
    sampled_points = np.array(sampled_points)

    # Keep the 3D points' original indices: culled points are removed before the NN search,
    # so indices into the filtered arrays must be mapped back to the full point set.
    # near_culled_indices marks 3D points outside the camera frustum or with negative depth.
    original_indices = np.arange(points3d_pix.shape[0])
    original_indices = original_indices[~near_culled_indices]

    # Only in-frustum 3D points are used for depth interpolation.
    points3d_pix = points3d_pix[~near_culled_indices]
    points3d_depth = points3d_depth[~near_culled_indices]
    # Construct a BallTree for nearest neighbor search.
    tree = BallTree(points3d_pix, leaf_size=40)
    # Query the nearest 3 neighbors for each sampled point.
    distances, indices = tree.query(sampled_points, k=3)
    # Invert the camera intrinsic matrix to generate the camera rays.
    inverse_intrinsics = np.linalg.inv(intrinsics_camera)
    # Generate the camera rays in the camera coordinate system.
    sampled_points_homogeneous = np.concatenate((sampled_points, np.ones((sampled_points.shape[0], 1))), axis=1)
    sampled_points_cameraframe = (inverse_intrinsics @ sampled_points_homogeneous.T).T
    # Transform the 3D points to the camera coordinate system with precomputed depths.
    points3d_pix_homogeneous = np.concatenate((points3d_pix, np.ones((points3d_pix.shape[0], 1))), axis=1)
    points3d_cameraframe = (inverse_intrinsics @ points3d_pix_homogeneous.T).T
    # Compute the normal vector and offset of the triangle formed by the nearest 3 neighbors.
    a, b, c, d = compute_normal_vector(points3d_cameraframe, points3d_depth, indices)
    # Compute the depth of the sampled points from the plane equation.
    sampled_points_depth, cosine_culled_indices = compute_depth(sampled_points_cameraframe, a, b, c, d)
    # Reject points whose interpolated depth falls below the near cutoff, are NaN, or whose supporting plane is near-parallel to the camera ray.
    depth_rejected_indices = (sampled_points_depth < depth_cutoff).reshape(-1) | np.isnan(sampled_points_cameraframe).any(axis=1) | cosine_culled_indices.reshape(-1)
    # Transform the sampled points to the world coordinate system.
    sampled_points_world = transform_camera_to_world(np.concatenate((sampled_points_cameraframe[:, :2]*sampled_points_depth.reshape(-1, 1), sampled_points_depth.reshape(-1, 1)), axis=1), rotations_image, translations_image)
    sampled_points_world = sampled_points_world[:, :3]

    augmented_count = 0
    depth_index = 0
    for node in leaf_nodes:
        if node.unoccupied:
            if depth_rejected_indices[depth_index]:
                node.depth_interpolated = False
            else:
                node.sampled_point_depth = sampled_points_depth[depth_index]
                node.sampled_point_world = sampled_points_world[depth_index]
                node.depth_interpolated = True
                node.sampled_point_neighbours_indices = original_indices[indices[depth_index]]
                node.sampled_point_neighbours_uv = points3d_pix[indices[depth_index]]
                augmented_count += 1
            depth_index += 1

    return leaf_nodes, augmented_count


def transform_camera_to_world(sampled_points_cameraframe, rotations_image, translations_image):
    extrinsics_image = np.concatenate((np.array(rotations_image), np.array(translations_image).reshape(3, 1)), axis=1)
    extrinsics_4x4 = np.concatenate((extrinsics_image, np.array([0, 0, 0, 1]).reshape(1, 4)), axis=0)
    sampled_points_cameraframe_homogeneous = np.concatenate((sampled_points_cameraframe, np.ones((sampled_points_cameraframe.shape[0], 1))), axis=1)
    sampled_points_worldframe = np.linalg.inv(extrinsics_4x4) @ sampled_points_cameraframe_homogeneous.T

    return sampled_points_worldframe.T


def pixelwise_rgb_diff(point_a, point_b, threshold=0.3):
    return np.linalg.norm(point_a - point_b) / 255. > threshold


def image_quadtree_augmentation(image_key,
                                image_dir,
                                colmap_cameras,
                                colmap_images,
                                colmap_points3D,
                                points3d, points3d_rgb,
                                intrinsics_camera,
                                rotations_image,
                                translations_image,
                                quadtree_std_threshold=7,
                                quadtree_min_pixel_size=5,
                                visibility_aware_culling=False):
    image_name = colmap_images[image_key].name
    image_path = os.path.join(image_dir, image_name)
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image_grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    quadtree_root = quadtree_decomposition(image_grayscale, quadtree_std_threshold, quadtree_min_pixel_size)

    leaf_nodes = []
    gather_leaf_nodes(quadtree_root, leaf_nodes)
    # Project 3D points onto the image plane.
    points3d_pix, points3d_depth = project_3d_to_2d(points3d, intrinsics_camera[colmap_images[image_key].camera_id], np.concatenate((np.array(rotations_image[image_key]), np.array(translations_image[image_key]).reshape(3, 1)), axis=1))
    # Cull the points that are outside the camera frustum or have negative depth.
    near_culled_indices = (points3d_pix[:, 0] < 0) | (points3d_pix[:, 0] >= image.shape[1]) | (points3d_pix[:, 1] < 0) | (points3d_pix[:, 1] >= image.shape[0]) | (points3d_depth.reshape(-1) < 0) | np.isnan(points3d_pix).any(axis=1)
    points3d_pix_rgb_diff = np.zeros(points3d_pix.shape[0], dtype=np.bool_)
    if visibility_aware_culling:
        for i, point in enumerate(points3d_pix):
            if near_culled_indices[i]:
                continue
            point_a = image[int(point[1]), int(point[0])]
            point_b = points3d_rgb[i]
            points3d_pix_rgb_diff[i] = pixelwise_rgb_diff(point_a, point_b)
        near_culled_indices = near_culled_indices | points3d_pix_rgb_diff.reshape(-1)
    # Mark every quadtree node that contains projected 3D points.
    leaf_nodes = check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices)
    # Sample points from the leaf nodes that are not occupied by projected 3D points.
    sampled_points = []
    sampled_points_rgb = []
    for node in leaf_nodes:
        if node.unoccupied:
            node.sampled_point_uv = np.array([node.x0, node.y0]) + np.random.sample(2) * np.array([node.width, node.height])
            node.sampled_point_rgb = image[int(node.sampled_point_uv[1]), int(node.sampled_point_uv[0])]
            sampled_points.append(node.sampled_point_uv)
            sampled_points_rgb.append(node.sampled_point_rgb)
    sampled_points = np.array(sampled_points)
    sampled_points_rgb = np.array(sampled_points_rgb)
    # Interpolate the depth of the sampled points from the nearest 3D points.
    leaf_nodes, augmented_count = find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera[colmap_images[image_key].camera_id], rotations_image[image_key], translations_image[image_key])

    return quadtree_root, augmented_count


def transform_sample_3d(image_key, root, colmap_images, colmap_cameras, intrinsics_camera, rotations_image, translations_image):
    # Gather leaf nodes and transform the sampled 2D points to 3D world coordinates.
    leaf_nodes = []
    gather_leaf_nodes(root, leaf_nodes)
    sample_points_imageframe = []
    sample_points_depth = []
    sample_points_rgb = []
    for node in leaf_nodes:
        if node.unoccupied:
            if node.depth_interpolated:
                sample_points_imageframe.append(node.sampled_point_uv)
                sample_points_depth.append(node.sampled_point_depth)
                sample_points_rgb.append(node.sampled_point_rgb)
    sample_points_imageframe = np.array(sample_points_imageframe)
    sample_points_depth = np.array(sample_points_depth)
    sample_points_rgb = np.array(sample_points_rgb)
    sample_points_cameraframe = (np.linalg.inv(intrinsics_camera[colmap_images[image_key].camera_id]) @ np.concatenate((sample_points_imageframe, np.ones((sample_points_imageframe.shape[0], 1))), axis=1).T).T
    sample_points_cameraframe = np.concatenate((sample_points_cameraframe[:, :2]*sample_points_depth.reshape(-1, 1), sample_points_depth.reshape(-1, 1)), axis=1)
    sample_points_worldframe = transform_camera_to_world(sample_points_cameraframe, rotations_image[image_key], translations_image[image_key])
    return sample_points_worldframe[:, :-1], sample_points_rgb


def write_points3D_colmap_binary(points3D, xyz, rgb, file_path):
    with open(file_path, "wb") as fid:
        write_next_bytes(fid, len(points3D) + len(xyz), "Q")
        for j, (_, pt) in enumerate(points3D.items()):
            write_next_bytes(fid, pt.id, "Q")
            write_next_bytes(fid, pt.xyz.tolist(), "ddd")
            write_next_bytes(fid, pt.rgb.tolist(), "BBB")
            write_next_bytes(fid, pt.error, "d")
            track_length = pt.image_ids.shape[0]
            write_next_bytes(fid, track_length, "Q")
            for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
                write_next_bytes(fid, [image_id, point2D_id], "ii")
        id = points3D[max(points3D.keys())].id + 1
        print("starts from id=", id)
        for i in range(len(xyz)):
            write_next_bytes(fid, id + i, "Q")
            write_next_bytes(fid, xyz[i].tolist(), "ddd")
            write_next_bytes(fid, rgb[i].tolist(), "BBB")
            write_next_bytes(fid, 0.0, "d")
            write_next_bytes(fid, 1, "Q")
            write_next_bytes(fid, [0, 0], "ii")
    return i + id


def compare_local_texture(patch1, patch2):
    """
    Compare two patches using a Gaussian kernel.
    If a patch is not 3x3, apply zero padding.
    """
    # Define a 3x3 Gaussian kernel (sigma=1.0)
    gaussian_kernel = np.array([
        [0.077847, 0.123317, 0.077847],
        [0.123317, 0.195346, 0.123317],
        [0.077847, 0.123317, 0.077847]
    ])

    # Check the patch size and apply zero padding if it is not 3x3.
    def pad_to_3x3(patch):
        if patch.shape[:2] != (3, 3):
            padded = np.zeros((3, 3, 3) if len(patch.shape) == 3 else (3, 3))
            h, w = patch.shape[:2]
            y_start = (3-h)//2
            x_start = (3-w)//2
            padded[y_start:y_start+h, x_start:x_start+w] = patch
            return padded
        return patch

    patch1 = pad_to_3x3(patch1)
    patch2 = pad_to_3x3(patch2)

    if len(patch1.shape) == 3:  # RGB image
        # Apply the kernel weights to each channel and accumulate the squared difference.
        weighted_diff = np.sum([
            np.sum(gaussian_kernel * (patch1[:, :, c] - patch2[:, :, c])**2)
            for c in range(3)
        ])
    else:  # Grayscale image
        weighted_diff = np.sum(gaussian_kernel * (patch1 - patch2)**2)

    return np.sqrt(weighted_diff)/255.
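
To make the projection convention used by project_3d_to_2d concrete, here is a small standalone NumPy check (not part of the commit) that mirrors the same homogenize, extrinsics, intrinsics pipeline for a toy pinhole camera; the numbers are made up for illustration.

import numpy as np

# Toy pinhole camera: fx = fy = 100, principal point at (50, 50), identity pose.
K = np.array([[100.,   0., 50.],
              [  0., 100., 50.],
              [  0.,   0.,  1.]])
extrinsics = np.hstack([np.eye(3), np.zeros((3, 1))])   # 3x4 [R | t]

points3d = np.array([[0.0,  0.0, 2.0],    # straight ahead, 2 units away
                     [0.2, -0.1, 1.0]])   # slightly off-axis

# Same steps as project_3d_to_2d: homogenize, apply extrinsics, then intrinsics.
points_h = np.hstack([points3d, np.ones((points3d.shape[0], 1))])
cam = extrinsics @ points_h.T
img_h = K @ cam
pix = (img_h[:2] / img_h[2]).T
depth = img_h[2:].T

print(pix)    # [[50. 50.] [70. 40.]]  -- the on-axis point lands on the principal point
print(depth)  # [[2.] [1.]]            -- depth is the camera-frame z coordinate
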
156  utils/bundle_utils.py  Normal file
@@ -0,0 +1,156 @@
import random
import os
from colmap.scripts.python.read_write_model import *
import numpy as np
from collections import defaultdict
from colmap_utils import compute_extrinsics, compute_intrinsics, get_colmap_data
import cv2
from sklearn.decomposition import PCA


def build_covisibility_matrix(images, points3D):
    n_images = len(images)
    covisibility_matrix = np.zeros((n_images, n_images))

    id_to_idx = {img_id: idx for idx, img_id in enumerate(images.keys())}
    idx_to_id = {idx: img_id for img_id, idx in id_to_idx.items()}

    for point3D in points3D.values():
        image_ids = point3D.image_ids
        for i in range(len(image_ids)):
            for j in range(i+1, len(image_ids)):
                id1, id2 = image_ids[i], image_ids[j]
                idx1, idx2 = id_to_idx[id1], id_to_idx[id2]
                covisibility_matrix[idx1, idx2] += 1
                covisibility_matrix[idx2, idx1] += 1

    return covisibility_matrix, id_to_idx, idx_to_id


def create_covisibility_graph(covisibility_matrix, idx_to_id):
    graph = defaultdict(dict)
    for i in range(len(covisibility_matrix)):
        for j in range(len(covisibility_matrix)):
            if i != j and covisibility_matrix[i, j] > 0:
                id1 = idx_to_id[i]
                id2 = idx_to_id[j]
                graph[id1][id2] = covisibility_matrix[i, j]

    return graph


def create_sequence_from_covisibility_graph(covisibility_graph, min_covisibility=20):
    if not covisibility_graph:
        print("No covisibility graph found")
        return []

    start_node = max(covisibility_graph.keys(),
                     key=lambda k: sum(1 for v in covisibility_graph[k].values() if v >= min_covisibility))

    visited = set([start_node])
    sequence = [start_node]
    current = start_node

    while len(sequence) < len(covisibility_graph):
        next_node = None
        max_covisibility = -1

        for neighbor, covisibility in covisibility_graph[current].items():
            if neighbor not in visited and covisibility > max_covisibility and covisibility >= min_covisibility:
                next_node = neighbor
                max_covisibility = covisibility

        if next_node is None:
            for node in covisibility_graph:
                if node not in visited:
                    for seq_node in sequence:
                        covisibility = covisibility_graph[node].get(seq_node, 0)
                        if covisibility > max_covisibility and covisibility >= min_covisibility:
                            max_covisibility = covisibility
                            next_node = node

        if next_node is None:
            next_node = min(set(covisibility_graph.keys()) - visited)

        current = next_node
        visited.add(current)
        sequence.append(current)

    return sequence


def cluster_cameras(model_path, camera_order):
    colmap_path = os.path.join(model_path, 'sparse/0')
    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
    if camera_order == 'covisibility':
        covisibility_matrix, id_to_idx, idx_to_id = build_covisibility_matrix(colmap_images, colmap_points3D)
        covisibility_graph = create_covisibility_graph(covisibility_matrix, idx_to_id)
        covisibility_sequence = create_sequence_from_covisibility_graph(covisibility_graph)

        image_id_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
        image_id_name_sorted = sorted(image_id_name, key=lambda x: x[1])
        test_id = []
        train_id = []
        count = 0
        train_only_idx = []
        for i, (id, name) in enumerate(image_id_name_sorted):
            if i % 8 == 0:
                test_id.append(id)
            else:
                train_id.append(id)
                train_only_idx.append(count)
            count += 1

        rotations_image, translations_image = compute_extrinsics(colmap_images)

        train_only_visibility_idx = []
        for id in covisibility_sequence:
            if id in train_id:
                train_only_visibility_idx.append(id)
        train_only_visibility_idx = np.array(train_only_visibility_idx)
        sorted_keys = train_only_visibility_idx  # assign to sorted_keys instead of sorted_indices

    elif camera_order == 'PCA':
        image_idx_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
        image_idx_name_sorted = sorted(image_idx_name, key=lambda x: x[1])
        test_idx = []
        train_idx = []
        for i, (idx, name) in enumerate(image_idx_name_sorted):
            if i % 8 == 0:
                test_idx.append(idx)
            else:
                train_idx.append(idx)

        rotations_image, translations_image = compute_extrinsics(colmap_images)

        cam_center = []
        key = []
        for idx in train_idx:
            cam_center.append((-rotations_image[idx].T @ translations_image[idx].reshape(3, 1)))
            key.append(idx)

        cam_center = np.array(cam_center)[:, :, 0]
        pca = PCA(n_components=2)
        cam_center_2d = pca.fit_transform(cam_center)

        center_cam_center = np.mean(cam_center_2d, axis=0)
        centered_cam_center = cam_center_2d - center_cam_center
        angles = np.arctan2(centered_cam_center[:, 1], centered_cam_center[:, 0])
        sorted_indices = np.argsort(angles)
        sorted_cam_centers = cam_center_2d[sorted_indices]
        sorted_keys = np.array(key)[sorted_indices]

    return sorted_keys


def bundle_start_index_generator(sorted_keys, initial_interval):
    start_indices = []
    cluster_sizes = []
    for i in range(200):
        start_indices.append((initial_interval*i) % len(sorted_keys))
        cluster_sizes.append(initial_interval)

    return start_indices, cluster_sizes


def adaptive_cluster(start_idx, sorted_keys, cluster_size=40, offset=0):
    idx = start_idx
    indices = [sorted_keys[index % len(sorted_keys)] for index in range(idx, idx + cluster_size)]
    random_index = random.choice(indices)
    return random_index
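
A minimal sketch of the greedy ordering implemented by create_sequence_from_covisibility_graph, using a hand-made three-image graph; it assumes the repository's utils/ directory (and the colmap submodule its imports rely on) is importable, and the covisibility counts are invented for illustration.

from utils.bundle_utils import create_sequence_from_covisibility_graph

# image_id -> {neighbour_id: number of co-observed 3D points}
toy_graph = {
    1: {2: 5.0, 3: 1.0},
    2: {1: 5.0, 3: 4.0},
    3: {1: 1.0, 2: 4.0},
}

# Start at the best-connected image, then repeatedly hop to the unvisited
# neighbour with the highest covisibility (threshold lowered for the toy data).
print(create_sequence_from_covisibility_graph(toy_graph, min_covisibility=1))   # -> [1, 2, 3]
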
78  utils/colmap_utils.py  Normal file
@@ -0,0 +1,78 @@
from colmap.scripts.python.read_write_model import *


def get_colmap_data(dataset_path):
    """
    Load COLMAP data from the given dataset path.

    Args:
        dataset_path (str): Path to the COLMAP model directory containing
            images.bin, points3D.bin, and cameras.bin (typically <scene>/sparse/0).

    Returns:
        images: COLMAP image info
        points3D: COLMAP 3D points
        cameras: COLMAP camera info
    """
    images = read_images_binary(os.path.join(dataset_path, 'images.bin'))
    points3D = read_points3D_binary(os.path.join(dataset_path, 'points3D.bin'))
    cameras = read_cameras_binary(os.path.join(dataset_path, 'cameras.bin'))
    return images, points3D, cameras


def quaternion_rotation_matrix(qw, qx, qy, qz):
    """
    Convert a quaternion to a rotation matrix.
    COLMAP uses wxyz order for quaternions.
    """
    # First row of the rotation matrix
    r00 = 2 * (qw * qw + qx * qx) - 1
    r01 = 2 * (qx * qy - qw * qz)
    r02 = 2 * (qw * qy + qx * qz)

    # Second row of the rotation matrix
    r10 = 2 * (qx * qy + qw * qz)
    r11 = 2 * (qw * qw + qy * qy) - 1
    r12 = 2 * (qy * qz - qw * qx)

    # Third row of the rotation matrix
    r20 = 2 * (qz * qx - qw * qy)
    r21 = 2 * (qz * qy + qw * qx)
    r22 = 2 * (qw * qw + qz * qz) - 1

    # 3x3 rotation matrix
    rot_matrix = np.array([[r00, r01, r02],
                           [r10, r11, r12],
                           [r20, r21, r22]])

    return rot_matrix


def compute_intrinsic_matrix(fx, fy, cx, cy, image_width, image_height):
    sx = cx/(image_width/2)
    sy = cy/(image_height/2)
    intrinsic_matrix = np.array([[fx/sx, 0, cx/sx], [0, fy/sy, cy/sy], [0, 0, 1]])
    return intrinsic_matrix


def compute_intrinsics(colmap_cameras, image_width, image_height):
    intrinsics = {}
    for cam_key in colmap_cameras.keys():
        intrinsic_parameters = colmap_cameras[cam_key].params
        assert colmap_cameras[cam_key].model == 'PINHOLE'
        intrinsic = compute_intrinsic_matrix(intrinsic_parameters[0],
                                             intrinsic_parameters[1],
                                             intrinsic_parameters[2],
                                             intrinsic_parameters[3],
                                             image_width,
                                             image_height)
        intrinsics[cam_key] = intrinsic
    return intrinsics


def compute_extrinsics(colmap_images):
    rotations = {}
    translations = {}
    for image_key in colmap_images.keys():
        rotation = quaternion_rotation_matrix(colmap_images[image_key].qvec[0],
                                              colmap_images[image_key].qvec[1],
                                              colmap_images[image_key].qvec[2],
                                              colmap_images[image_key].qvec[3])
        translation = colmap_images[image_key].tvec
        rotations[image_key] = rotation
        translations[image_key] = translation
    return rotations, translations
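
As a quick sanity check of the wxyz convention above (again assuming the colmap submodule is present so the import succeeds): the identity quaternion maps to the identity matrix, and a 90-degree rotation about +z gives the expected permutation of the x and y axes.

import numpy as np
from utils.colmap_utils import quaternion_rotation_matrix

print(quaternion_rotation_matrix(1.0, 0.0, 0.0, 0.0))
# -> identity matrix

q = np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)   # 90 degrees about +z, wxyz order
print(np.round(quaternion_rotation_matrix(*q), 6))
# -> [[ 0. -1.  0.]
#     [ 1.  0.  0.]
#     [ 0.  0.  1.]]
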
utils/loss_utils.py
@@ -10,6 +10,7 @@
 #

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from math import exp
@@ -17,6 +18,7 @@ try:
     from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward
 except:
     pass
+import math

 C1 = 0.01 ** 2
 C2 = 0.03 ** 2
@@ -89,3 +91,133 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
 def fast_ssim(img1, img2):
     ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2)
     return ssim_map.mean()
+
+
+def build_gaussian_kernel(kernel_size=5, sigma=1.0, channels=3):
+    # Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+    x_coord = torch.arange(kernel_size)
+    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
+    y_grid = x_grid.t()
+    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+    mean = (kernel_size - 1)/2.
+    variance = sigma**2.
+
+    # Calculate the 2-dimensional gaussian kernel
+    gaussian_kernel = (1./(2.*math.pi*variance)) * \
+                      torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / \
+                      (2*variance))
+
+    # Make sure the sum of values in the gaussian kernel equals 1.
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2D depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    return gaussian_kernel
+
+
+def gaussian_blur(x, kernel):
+    """Apply gaussian blur to input tensor"""
+    padding = (kernel.shape[-1] - 1) // 2
+    return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
+
+
+def create_laplacian_pyramid(image, max_levels=4):
+    """Create Laplacian pyramid for an image"""
+    pyramids = []
+    current = image
+    kernel = build_gaussian_kernel().to(image.device)
+
+    for _ in range(max_levels):
+        # Blur and downsample
+        blurred = gaussian_blur(current, kernel)
+        down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+        # Upsample and subtract
+        up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+        laplace = current - up
+
+        pyramids.append(laplace)
+        current = down
+
+    pyramids.append(current)  # Add the final residual
+    return pyramids
+
+
+def laplacian_pyramid_loss(pred, target, max_levels=4, weights=None):
+    """Compute Laplacian Pyramid Loss between predicted and target images"""
+    if weights is None:
+        weights = [1.0] * (max_levels + 1)
+
+    pred_pyramids = create_laplacian_pyramid(pred, max_levels)
+    target_pyramids = create_laplacian_pyramid(target, max_levels)
+
+    loss = 0
+    for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+        loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+    return loss
+
+
+class LaplacianPyramidLoss(torch.nn.Module):
+    def __init__(self, max_levels=4, channels=3, kernel_size=5, sigma=1.0):
+        super().__init__()
+        self.max_levels = max_levels
+        self.kernel = build_gaussian_kernel(kernel_size, sigma, channels)
+
+    def forward(self, pred, target, weights=None):
+        if weights is None:
+            weights = [1.0] * (self.max_levels + 1)
+
+        # Move kernel to the same device as input
+        kernel = self.kernel.to(pred.device)
+
+        pred_pyramids = self.create_laplacian_pyramid(pred, kernel)
+        target_pyramids = self.create_laplacian_pyramid(target, kernel)
+
+        loss = 0
+        for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+            loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+        return loss
+
+    @staticmethod
+    def create_laplacian_pyramid(image, kernel, max_levels=4):
+        pyramids = []
+        current = image
+
+        for _ in range(max_levels):
+            # Apply Gaussian blur before downsampling to prevent aliasing
+            blurred = gaussian_blur(current, kernel)
+            down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+            # Upsample and subtract from the original image
+            up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+            laplace = current - gaussian_blur(up, kernel)  # Apply blur to the upsampled image
+
+            pyramids.append(laplace)
+            current = down
+
+        pyramids.append(current)  # Add the final residual
+        return pyramids
+
+
+class InvDepthSmoothnessLoss(nn.Module):
+    def __init__(self, alpha=10):
+        super(InvDepthSmoothnessLoss, self).__init__()
+        self.alpha = alpha  # Hyperparameter controlling the strength of the edge weighting
+
+    def forward(self, inv_depth, image):
+        # Gradients of the inverse depth map
+        dx_inv_depth = torch.abs(inv_depth[:, :, :-1] - inv_depth[:, :, 1:])
+        dy_inv_depth = torch.abs(inv_depth[:, :-1, :] - inv_depth[:, 1:, :])
+
+        # Gradients of the image
+        dx_image = torch.mean(torch.abs(image[:, :, :-1] - image[:, :, 1:]), 1, keepdim=True)
+        dy_image = torch.mean(torch.abs(image[:, :-1, :] - image[:, 1:, :]), 1, keepdim=True)
+
+        # Edge-aware weights based on the image gradients
+        weight_x = torch.exp(-self.alpha * dx_image)
+        weight_y = torch.exp(-self.alpha * dy_image)
+
+        # Smoothness loss
+        smoothness_x = dx_inv_depth * weight_x
+        smoothness_y = dy_inv_depth * weight_y
+
+        return torch.mean(smoothness_x) + torch.mean(smoothness_y)
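
A minimal smoke test for the two new loss terms, with shapes taken from the train.py call sites (a (3, H, W) rendered image and a (1, H, W) inverse-depth map); the random tensors are placeholders, and the repository must be on the import path.

import torch
from utils.loss_utils import laplacian_pyramid_loss, InvDepthSmoothnessLoss

image = torch.rand(3, 64, 64)       # rendered image
gt_image = torch.rand(3, 64, 64)    # ground-truth image
inv_depth = torch.rand(1, 64, 64)   # rendered inverse depth

# laplacian_pyramid_loss expects 4D (N, C, H, W) input, hence the unsqueeze.
lap = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
ds = InvDepthSmoothnessLoss()(inv_depth, image)
print(lap.item(), ds.item())
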