diff --git a/augment.py b/augment.py
new file mode 100644
index 0000000..f9d5e60
--- /dev/null
+++ b/augment.py
@@ -0,0 +1,208 @@
+
+
+import argparse
+import numpy as np
+from tqdm import tqdm
+import cv2
+import os
+
+from utils.colmap_utils import *
+from utils.bundle_utils import cluster_cameras
+from utils.aug_utils import *
+
+
+def augment(colmap_path, image_path, augment_path, camera_order, visibility_aware_culling, compare_center_patch):
+    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
+    sorted_keys = cluster_cameras(colmap_path, camera_order)
+
+    points3d = []
+    points3d_rgb = []
+    for key in sorted(colmap_points3D.keys()):
+        points3d.append(colmap_points3D[key].xyz)
+        points3d_rgb.append(colmap_points3D[key].rgb)
+    points3d = np.array(points3d)
+    points3d_rgb = np.array(points3d_rgb)
+
+    image_sample = cv2.imread(os.path.join(image_path, colmap_images[sorted_keys[0]].name))
+    intrinsics_camera = compute_intrinsics(colmap_cameras, image_sample.shape[1], image_sample.shape[0])
+    rotations_image, translations_image = compute_extrinsics(colmap_images)
+
+    count = 0
+    roots = {}
+    pbar = tqdm(range(len(sorted_keys)))
+    for view_idx in pbar:
+        view = sorted_keys[view_idx]
+        view_root, augmented_count = image_quadtree_augmentation(
+            view, image_path,
+            colmap_cameras,
+            colmap_images,
+            colmap_points3D,
+            points3d,
+            points3d_rgb,
+            intrinsics_camera,
+            rotations_image,
+            translations_image,
+            visibility_aware_culling=visibility_aware_culling,
+        )
+        count += augmented_count
+        pbar.set_description(f"{count} points augmented")
+        roots[view] = view_root
+
+    for view1_idx in tqdm(range(len(sorted_keys))):
+        for view2_idx in [view1_idx + 6,
+                          view1_idx + 5,
+                          view1_idx + 4,
+                          view1_idx + 3,
+                          view1_idx + 2,
+                          view1_idx + 1,
+                          view1_idx - 1,
+                          view1_idx - 2,
+                          view1_idx - 3,
+                          view1_idx - 4,
+                          view1_idx - 5,
+                          view1_idx - 6]:
+            if view2_idx > len(sorted_keys) - 1:
+                view2_idx = view2_idx - len(sorted_keys)
+            view1 = sorted_keys[view1_idx]
+            view2 = sorted_keys[view2_idx]
+            view1_root = roots[view1]
+            view2_root = roots[view2]
+
+            image_view2 = cv2.imread(os.path.join(image_path, colmap_images[view2].name))
+
+            view1_sample_points_world, view1_sample_points_rgb = transform_sample_3d(view1,
+                                                                                     view1_root,
+                                                                                     colmap_images,
+                                                                                     colmap_cameras,
+                                                                                     intrinsics_camera,
+                                                                                     rotations_image,
+                                                                                     translations_image)
+            view1_sample_points_view2, view1_sample_points_view2_depth = project_3d_to_2d(
+                view1_sample_points_world,
+                intrinsics_camera[colmap_images[view2].camera_id],
+                np.concatenate((np.array(rotations_image[view2]),
+                                np.array(translations_image[view2]).reshape(3, 1)), axis=1))
+            points3d_view2_pixcoord, points3d_view2_depth = project_3d_to_2d(
+                points3d,
+                intrinsics_camera[colmap_images[view2].camera_id],
+                np.concatenate((np.array(rotations_image[view2]),
+                                np.array(translations_image[view2]).reshape(3, 1)), axis=1))
+
+            matching_log = []
+            for i in range(view1_sample_points_world.shape[0]):
+                x, y = view1_sample_points_view2[i]
+                corresponding_node_type = None
+                error = None
+                if (view1_sample_points_view2_depth[i] < 0) | \
+                   (view1_sample_points_view2[i, 0] < 0) | \
+                   (view1_sample_points_view2[i, 0] >= image_view2.shape[1]) | \
+                   (view1_sample_points_view2[i, 1] < 0) | \
+                   (view1_sample_points_view2[i, 1] >= image_view2.shape[0]) | \
+                   np.isnan(view1_sample_points_view2[i]).any(axis=0):
+                    corresponding_node_type = "culled"
+                    matching_log.append([view2, corresponding_node_type, error])
+                    continue
+                else:
+                    view2_corresponding_node = find_leaf_node(view2_root, x, y)
+                    if view2_corresponding_node is None:
+                        corresponding_node_type = "missing"
+                    else:
+                        if view2_corresponding_node.unoccupied:
+                            if view2_corresponding_node.depth_interpolated:
+                                error = np.linalg.norm(view1_sample_points_view2_depth[i] - view2_corresponding_node.sampled_point_depth)
+                                if error < 0.2 * view2_corresponding_node.sampled_point_depth:
+                                    if compare_center_patch:
+                                        try:
+                                            view1_sample_point_patch = image_view2[
+                                                int(view1_sample_points_view2[i, 1]) - 1:int(view1_sample_points_view2[i, 1]) + 2,
+                                                int(view1_sample_points_view2[i, 0]) - 1:int(view1_sample_points_view2[i, 0]) + 2]
+                                            view2_corresponding_node_patch = image_view2[
+                                                int(view2_corresponding_node.sampled_point_uv[1]) - 1:int(view2_corresponding_node.sampled_point_uv[1]) + 2,
+                                                int(view2_corresponding_node.sampled_point_uv[0]) - 1:int(view2_corresponding_node.sampled_point_uv[0]) + 2]
+                                            if compare_local_texture(view1_sample_point_patch, view2_corresponding_node_patch) > 0.5:
+                                                corresponding_node_type = "sampledrejected"
+                                            else:
+                                                corresponding_node_type = "sampled"
+                                        except IndexError:
+                                            corresponding_node_type = "sampledrejected"
+                                    else:
+                                        corresponding_node_type = "sampled"
+
+                                else:
+                                    corresponding_node_type = "sampledrejected"
+                            else:
+                                corresponding_node_type = "depthrejected"
+                        else:
+                            corresponding_3d_depth = np.array(view2_corresponding_node.points3d_depths)
+                            error = np.abs(view1_sample_points_view2_depth[i] - corresponding_3d_depth)
+
+                            if np.min(error) < 0.2 * corresponding_3d_depth[np.argmin(error)]:
+                                if compare_center_patch:
+                                    try:
+                                        point_3d_coord = points3d_view2_pixcoord[view2_corresponding_node.points3d_indices[np.argmin(error)]]
+                                        point_3d_patch = image_view2[
+                                            int(point_3d_coord[1]) - 1:int(point_3d_coord[1]) + 2,
+                                            int(point_3d_coord[0]) - 1:int(point_3d_coord[0]) + 2]
+                                        view1_sample_point_patch = image_view2[
+                                            int(view1_sample_points_view2[i, 1]) - 1:int(view1_sample_points_view2[i, 1]) + 2,
+                                            int(view1_sample_points_view2[i, 0]) - 1:int(view1_sample_points_view2[i, 0]) + 2]
+                                        if compare_local_texture(view1_sample_point_patch, point_3d_patch) > 0.5:
+                                            corresponding_node_type = "rejectedoccupied3d"
+                                        else:
+                                            corresponding_node_type = "occupied3d"
+                                    except IndexError:
+                                        corresponding_node_type = "rejectedoccupied3d"
+                                else:
+                                    corresponding_node_type = "occupied3d"
+                            else:
+                                corresponding_node_type = "rejectedoccupied3d"
+                matching_log.append([view2, corresponding_node_type, error])
+
+    sampled_points_total = []
+    sampled_points_rgb_total = []
+    sampled_points_uv_total = []
+    sampled_points_neighbors_uv_total = []
+    for view in sorted_keys:
+        view_root = roots[view]
+        leaf_nodes = []
+        gather_leaf_nodes(view_root, leaf_nodes)
+        for node in leaf_nodes:
+            if node.unoccupied:
+                if node.depth_interpolated:
+                    if node.inference_count > 0:
+                        if node.inference_count - node.rejection_count >= 1:
+                            sampled_points_total.append([node.sampled_point_world])
+                            sampled_points_rgb_total.append([node.sampled_point_rgb])
+                            sampled_points_uv_total.append([node.sampled_point_uv])
+                            sampled_points_neighbors_uv_total.append([node.sampled_point_neighbours_uv])
+    print("total_sampled_points: ", len(sampled_points_total))
+    xyz = np.concatenate(sampled_points_total, axis=0)
+    rgb = np.concatenate(sampled_points_rgb_total, axis=0)
+
+    last_index = write_points3D_colmap_binary(colmap_points3D, xyz, rgb, augment_path)
+    print("last_index: ", last_index)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--colmap_path", type=str, required=True)
+    parser.add_argument("--image_path", type=str, required=True)
parser.add_argument("--augment_path", type=str, required=True) + parser.add_argument("--camera_order", type=str, required=True, default="covisibility") + parser.add_argument("--visibility_aware_culling", + action="store_true", + default=False) + parser.add_argument("--compare_center_patch", + action="store_true", + default=False) + args = parser.parse_args() + print("args.colmap_path", args.colmap_path) + print("args.image_path", args.image_path) + print("args.augment_path", args.augment_path) + print("args.camera_order", args.camera_order) + print("args.visibility_aware_culling", args.visibility_aware_culling) + print("args.compare_center_patch", args.compare_center_patch) + augment(args.colmap_path, args.image_path, args.augment_path, args.camera_order, args.visibility_aware_culling, args.compare_center_patch) \ No newline at end of file diff --git a/train.py b/train.py index 8206903..813c628 100644 --- a/train.py +++ b/train.py @@ -12,7 +12,9 @@ import os import torch from random import randint -from utils.loss_utils import l1_loss, ssim +from utils.loss_utils import l1_loss, ssim, InvDepthSmoothnessLoss, laplacian_pyramid_loss +from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster +from utils.aug_utils import * from gaussian_renderer import render, network_gui import sys from scene import Scene, GaussianModel @@ -40,7 +42,20 @@ try: except: SPARSE_ADAM_AVAILABLE = False -def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from): +def training(dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, + camera_order, + bundle_training, + enable_ds_lap, + lambda_ds, + lambda_lap, + ): if not SPARSE_ADAM_AVAILABLE and opt.optimizer_type == "sparse_adam": sys.exit(f"Trying to use sparse adam but it is not installed, please install the correct rasterizer using pip install [3dgs_accel].") @@ -68,6 +83,11 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi ema_loss_for_log = 0.0 ema_Ll1depth_for_log = 0.0 + if bundle_training: + sorted_keys = cluster_cameras(dataset.source_path, camera_order) + start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20) + n_interval = 0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") first_iter += 1 for iteration in range(first_iter, opt.iterations + 1): @@ -94,13 +114,20 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi if iteration % 1000 == 0: gaussians.oneupSHdegree() - # Pick a random Camera - if not viewpoint_stack: - viewpoint_stack = scene.getTrainCameras().copy() - viewpoint_indices = list(range(len(viewpoint_stack))) - rand_idx = randint(0, len(viewpoint_indices) - 1) - viewpoint_cam = viewpoint_stack.pop(rand_idx) - vind = viewpoint_indices.pop(rand_idx) + if bundle_training and iteration < opt.densify_until_iter and iteration > opt.densify_from_iter and cluster_sizes[n_interval] < 160 and iteration % 100 > 80: + densification_viewpoint_stack = scene.getTrainCameras().copy() + selected_idx = adaptive_cluster(start_indices[n_interval], sorted_keys, cluster_sizes[n_interval]) + if selected_idx >= len(densification_viewpoint_stack): + selected_idx = selected_idx % len(densification_viewpoint_stack) + viewpoint_cam = densification_viewpoint_stack[selected_idx] + + else: + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_indices = 
list(range(len(viewpoint_stack)))
+                rand_idx = randint(0, len(viewpoint_indices) - 1)
+                viewpoint_cam = viewpoint_stack.pop(rand_idx)
+                vind = viewpoint_indices.pop(rand_idx)
 
         # Render
         if (iteration - 1) == debug_from:
@@ -123,7 +150,14 @@
         else:
             ssim_value = ssim(image, gt_image)
 
-        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value)
+        if enable_ds_lap:
+            ds_loss = InvDepthSmoothnessLoss()(render_pkg["depth"], image)
+            lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
+        else:
+            ds_loss = 0.0
+            lap_loss = 0.0
+
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim_value) + lambda_ds * ds_loss + lambda_lap * lap_loss
 
         # Depth regularization
         Ll1depth_pure = 0.0
@@ -155,7 +189,24 @@
             progress_bar.close()
 
             # Log and save
-            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp), dataset.train_test_exp)
+            training_report(tb_writer,
+                            iteration,
+                            Ll1,
+                            loss,
+                            l1_loss,
+                            iter_start.elapsed_time(iter_end),
+                            testing_iterations,
+                            scene,
+                            render,
+                            (pipe, background, 1., SPARSE_ADAM_AVAILABLE, None, dataset.train_test_exp),
+                            dataset.train_test_exp,
+                            1.0 - ssim_value,
+                            ds_loss,
+                            lap_loss,
+                            lambda_ds,
+                            lambda_lap
+                            )
+
             if (iteration in saving_iterations):
                 print("\n[ITER {}] Saving Gaussians".format(iteration))
                 scene.save(iteration)
@@ -169,6 +220,7 @@
                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                     size_threshold = 20 if iteration > opt.opacity_reset_interval else None
                     gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold, radii)
+                    n_interval += 1
 
                 if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
                     gaussians.reset_opacity()
@@ -211,10 +263,30 @@ def prepare_output_and_logger(args):
         print("Tensorboard not available: not logging progress")
     return tb_writer
 
-def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs, train_test_exp):
+def training_report(tb_writer,
+                    iteration,
+                    Ll1,
+                    loss,
+                    l1_loss,
+                    elapsed,
+                    testing_iterations,
+                    scene : Scene,
+                    renderFunc,
+                    renderArgs,
+                    train_test_exp,
+                    ssim_loss,
+                    ds_loss,
+                    lap_loss,
+                    lambda_ds,
+                    lambda_lap):
     if tb_writer:
         tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
         tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ssim_loss', ssim_loss.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/ds_loss', float(ds_loss), iteration)
+        tb_writer.add_scalar('train_loss_patches/lap_loss', float(lap_loss), iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_ds', lambda_ds, iteration)
+        tb_writer.add_scalar('train_loss_patches/lambda_lap', lambda_lap, iteration)
         tb_writer.add_scalar('iter_time', elapsed, iteration)
 
     # Report test and samples of training set
@@ -267,6 +339,11 @@ if __name__ == "__main__":
     parser.add_argument('--disable_viewer', action='store_true', default=False)
     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
     parser.add_argument("--start_checkpoint", type=str, default = None)
+    parser.add_argument("--bundle_training", action='store_true', default=False)
+    parser.add_argument("--camera_order", type=str, default='covisibility')
+    parser.add_argument("--enable_ds_lap", action='store_true', default=False)
+    parser.add_argument("--lambda_ds", type=float, default=0.0)
+    parser.add_argument("--lambda_lap", type=float, default=0.0)
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)
 
@@ -279,7 +356,19 @@
     if not args.disable_viewer:
         network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
+    training(lp.extract(args),
+             op.extract(args),
+             pp.extract(args),
+             args.test_iterations,
+             args.save_iterations,
+             args.checkpoint_iterations,
+             args.start_checkpoint,
+             args.debug_from,
+             args.camera_order,
+             args.bundle_training,
+             args.enable_ds_lap,
+             args.lambda_ds,
+             args.lambda_lap)
 
     # All done
     print("\nTraining complete.")
diff --git a/utils/aug_utils.py b/utils/aug_utils.py
new file mode 100644
index 0000000..ba4c184
--- /dev/null
+++ b/utils/aug_utils.py
@@ -0,0 +1,387 @@
+import os
+import numpy as np
+from sklearn.neighbors import BallTree
+from colmap.scripts.python.read_write_model import *
+import cv2
+from tqdm import tqdm
+from sklearn.decomposition import PCA
+import rtree
+from shapely.geometry import Point, box
+from collections import defaultdict
+from utils.colmap_utils import compute_extrinsics, get_colmap_data
+from matplotlib import pyplot as plt
+
+class Node:
+    def __init__(self, x0, y0, width, height):
+        self.x0 = x0
+        self.y0 = y0
+        self.width = width
+        self.height = height
+        self.children = []
+        self.unoccupied = True
+        self.sampled_point_uv = None
+        self.sampled_point_rgb = None
+        self.sampled_point_depth = None
+        self.depth_interpolated = None
+        self.sampled_point_neighbours_indices = None
+        self.inference_count = 0
+        self.rejection_count = 0
+        self.sampled_point_world = None
+        self.matching_log = {}
+        self.points3d_indices = []
+        self.points3d_depths = []
+        self.points3d_rgb = []
+        self.sampled_point_neighbours_uv = None
+
+    def get_error(self, img):
+        # Calculate the standard deviation of the region as an error metric
+        region = img[self.y0:self.y0 + self.height, self.x0:self.x0 + self.width]
+        return np.std(region)
+
+def recursive_subdivide(node, threshold, min_pixel_size, img):
+    if node.get_error(img) <= threshold:
+        return
+
+    w_1 = node.width // 2
+    w_2 = node.width - w_1
+    h_1 = node.height // 2
+    h_2 = node.height - h_1
+
+    if w_1 <= min_pixel_size or h_1 <= min_pixel_size:
+        return
+
+    # Create four children nodes
+    x1 = Node(node.x0, node.y0, w_1, h_1)  # top left
+    recursive_subdivide(x1, threshold, min_pixel_size, img)
+
+    x2 = Node(node.x0, node.y0 + h_1, w_1, h_2)  # bottom left
+    recursive_subdivide(x2, threshold, min_pixel_size, img)
+
+    x3 = Node(node.x0 + w_1, node.y0, w_2, h_1)  # top right
+    recursive_subdivide(x3, threshold, min_pixel_size, img)
+
+    x4 = Node(node.x0 + w_1, node.y0 + h_1, w_2, h_2)  # bottom right
+    recursive_subdivide(x4, threshold, min_pixel_size, img)
+
+    node.children = [x1, x2, x3, x4]
+
+def quadtree_decomposition(img, threshold, min_pixel_size):
+    root = Node(0, 0, img.shape[1], img.shape[0])
+    recursive_subdivide(root, threshold, min_pixel_size, img)
+    return root
+
+def gather_leaf_nodes(node, leaf_nodes):
+    if not node.children:
+        leaf_nodes.append(node)
+    else:
+        for child in node.children:
+            gather_leaf_nodes(child, leaf_nodes)
+
+def find_leaf_node(root, pixel_x, pixel_y):
+    if not (root.x0 <= pixel_x < root.x0 + root.width and
+            root.y0 <= pixel_y < root.y0 + root.height):
+        return None  # the pixel lies outside the root node's extent
+
+    current = root
+    while current.children:
+        for child in current.children:
+            if (child.x0 <= pixel_x < child.x0 + child.width and
+                child.y0 <= pixel_y < child.y0 + child.height):
+                current = child
+                break
+        else:
+            # no child contains the pixel; return the current node
+            return current
+
+    return current
+
+def draw_quadtree(node, ax):
+    if not node.children:
+        rect = plt.Rectangle((node.x0, node.y0), node.width, node.height, fill=False, edgecolor='red')
+        ax.add_patch(rect)
+    else:
+        for child in node.children:
+            draw_quadtree(child, ax)
+
+def pixel_to_3d(point, depth, intrinsics):
+    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
+    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
+
+    x, y = point
+    z = depth
+    x_3d = (x - cx) * z / fx
+    y_3d = (y - cy) * z / fy
+    return np.array([x_3d, y_3d, z])
+
+def project_3d_to_2d(points3d, intrinsics, extrinsics):
+    """
+    Project 3D points to 2D pixel coordinates using camera intrinsics and extrinsics.
+
+    Args:
+    - points3d: (N, 3) array of 3D points in world coordinates.
+    - intrinsics: camera intrinsics matrix (3x3).
+    - extrinsics: world-to-camera extrinsics matrix [R | t] (3x4).
+
+    Returns:
+    - (N, 2) array of pixel coordinates and (N, 1) array of depths.
+    """
+    point_3d_homogeneous = np.hstack((points3d, np.ones((points3d.shape[0], 1))))  # Convert to homogeneous coordinates
+    point_camera = extrinsics @ point_3d_homogeneous.T
+
+    point_image_homogeneous = intrinsics @ point_camera
+    point_2d = point_image_homogeneous[:2] / point_image_homogeneous[2]
+
+    return point_2d.T, point_image_homogeneous[2:].T
+
+def check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices):
+    rect_index = rtree.index.Index()
+    rectangles = [((node.x0, node.y0, node.x0 + node.width, node.y0 + node.height), i) for i, node in enumerate(leaf_nodes)]
+    for rect, i in rectangles:
+        rect_index.insert(i, rect)
+    rectangle_indices = []
+    for i, (x, y) in enumerate(points3d_pix):
+        if near_culled_indices[i]:
+            continue
+        point = Point(x, y)
+        matches = list(rect_index.intersection((x, y, x, y)))
+        for match in matches:
+            if box(*rectangles[match][0]).contains(point):
+                rectangle_indices.append([match, i])
+
+    for index, point3d_idx in rectangle_indices:
+        leaf_nodes[index].unoccupied = False
+        leaf_nodes[index].points3d_indices.append(point3d_idx)
+        leaf_nodes[index].points3d_depths.append(points3d_depth[point3d_idx])
+        leaf_nodes[index].points3d_rgb.append(points3d_rgb[point3d_idx])
+    return leaf_nodes
+
+def compute_normal_vector(points3d_cameraframe, points3d_depth, indices):
+    points3d_cameraframe = np.concatenate((points3d_cameraframe[:, :2] * points3d_depth.reshape(-1, 1), points3d_depth.reshape(-1, 1)), axis=1)
+    p1 = points3d_cameraframe[indices[:, 0]]
+    p2 = points3d_cameraframe[indices[:, 1]]
+    p3 = points3d_cameraframe[indices[:, 2]]
+    v1 = p2 - p1
+    v2 = p3 - p1
+
+    normal = np.cross(v1, v2)
+    a, b, c = normal[:, 0], normal[:, 1], normal[:, 2]
+    d = -a * p1[:, 0] - b * p1[:, 1] - c * p1[:, 2]
+
+    return a, b, c, d
+
+def compute_depth(sample_points_cameraframe, a, b, c, d, cosine_threshold=0.01):
+    direction_vectors = sample_points_cameraframe
+    t = -d.reshape(-1, 1) / np.sum(np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1) * direction_vectors, axis=1).reshape(-1, 1)
+    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
+    cosine_similarity = np.abs(np.sum(normal_vectors * direction_vectors, axis=1)).reshape(-1, 1) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1) / np.linalg.norm(direction_vectors, axis=1).reshape(-1, 1)
+    rejected_indices = cosine_similarity < cosine_threshold
+    depth = t.reshape(-1, 1) * direction_vectors[:, 2:]
+    return depth, rejected_indices.reshape(-1)
+
+def find_perpendicular_triangle_indices(a, b, c, cosine_threshold=0.01):
+    normal_vectors = np.concatenate((a.reshape(-1, 1), b.reshape(-1, 1), c.reshape(-1, 1)), axis=1)
+    camera_normal = np.array([0, 0, 1]).reshape(1, 3)
+    cosine_similarity = np.dot(normal_vectors, camera_normal.T) / np.linalg.norm(normal_vectors, axis=1).reshape(-1, 1)
+    cosine_culled_indices = np.abs(cosine_similarity) < cosine_threshold
+    return cosine_culled_indices
+
+
+def find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera, rotations_image, translations_image, depth_cutoff=2.):
+    sampled_points = []
+    for node in leaf_nodes:
+        if node.unoccupied:
+            sampled_points.append(node.sampled_point_uv)
+    sampled_points = np.array(sampled_points)
+
+    # Keep the 3D points' original indices so they can be recovered later: the BallTree
+    # below is built only on the in-frustum subset, so its indices do not match the full array.
+    # near_culled_indices marks 3D points outside the camera frustum or with negative depth.
+    original_indices = np.arange(points3d_pix.shape[0])
+    original_indices = original_indices[~near_culled_indices]
+
+    # Only in-frustum 3D points are used for depth interpolation.
+    points3d_pix = points3d_pix[~near_culled_indices]
+    points3d_depth = points3d_depth[~near_culled_indices]
+    # Construct a BallTree for nearest neighbor search.
+    tree = BallTree(points3d_pix, leaf_size=40)
+    # Query the nearest 3 neighbors for each sampled point.
+    distances, indices = tree.query(sampled_points, k=3)
+    # Invert the camera intrinsic matrix for generating the camera rays.
+    inverse_intrinsics = np.linalg.inv(intrinsics_camera)
+    # Generate the camera rays in the camera coordinate system.
+    sampled_points_homogeneous = np.concatenate((sampled_points, np.ones((sampled_points.shape[0], 1))), axis=1)
+    sampled_points_cameraframe = (inverse_intrinsics @ sampled_points_homogeneous.T).T
+    # Transform the 3D points to the camera coordinate system with precomputed depths.
+    points3d_pix_homogeneous = np.concatenate((points3d_pix, np.ones((points3d_pix.shape[0], 1))), axis=1)
+    points3d_cameraframe = (inverse_intrinsics @ points3d_pix_homogeneous.T).T
+    # Compute the normal vector and the distance of the triangle formed by the nearest 3 neighbors.
+    a, b, c, d = compute_normal_vector(points3d_cameraframe, points3d_depth, indices)
+    # Compute the depth of the sampled points from the normal vector.
+    sampled_points_depth, cosine_culled_indices = compute_depth(sampled_points_cameraframe, a, b, c, d)
+    # Reject points whose interpolated depth falls below the near cutoff, is NaN, or whose
+    # neighbour plane is nearly parallel to the camera ray.
+    depth_rejected_indices = (sampled_points_depth < depth_cutoff).reshape(-1) | np.isnan(sampled_points_cameraframe).any(axis=1) | cosine_culled_indices.reshape(-1)
+    # Transform the sampled points to the world coordinate system.
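# --- Illustrative aside (not part of the patch): compute_normal_vector/compute_depth above
# interpolate a sample's depth by fitting a plane to its 3 nearest projected COLMAP points
# and intersecting the pixel ray with that plane. A minimal standalone sketch of that
# geometry, assuming the same pinhole convention (ray = ((u - cx)/fx, (v - cy)/fy, 1)):
import numpy as np

def interpolate_depth(ray_dir, neighbors_cam):
    # neighbors_cam: (3, 3) array holding the three neighbouring points in camera coordinates.
    p1, p2, p3 = neighbors_cam
    normal = np.cross(p2 - p1, p3 - p1)      # plane normal n
    d = -normal.dot(p1)                      # plane: n . x + d = 0
    t = -d / normal.dot(ray_dir)             # ray: x = t * ray_dir
    return t * ray_dir[2]                    # depth = z component of the intersection

# A fronto-parallel plane at z = 4 yields depth 4 for any ray that hits it.
ray = np.array([0.1, -0.2, 1.0])
plane_pts = np.array([[0.0, 0.0, 4.0], [1.0, 0.0, 4.0], [0.0, 1.0, 4.0]])
print(interpolate_depth(ray, plane_pts))     # -> 4.0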
+ sampled_points_world = transform_camera_to_world(np.concatenate((sampled_points_cameraframe[:,:2]*sampled_points_depth.reshape(-1,1), sampled_points_depth.reshape(-1,1)), axis=1), rotations_image, translations_image) + sampled_points_world = sampled_points_world[:, :3] + + augmented_count = 0 + depth_index = 0 + for node in leaf_nodes: + if node.unoccupied: + if depth_rejected_indices[depth_index]: + node.depth_interpolated = False + else: + node.sampled_point_depth = sampled_points_depth[depth_index] + node.sampled_point_world = sampled_points_world[depth_index] + node.depth_interpolated = True + node.sampled_point_neighbours_indices = original_indices[indices[depth_index]] + node.sampled_point_neighbours_uv = points3d_pix[indices[depth_index]] + augmented_count += 1 + depth_index += 1 + + return leaf_nodes, augmented_count + +def transform_camera_to_world(sampled_points_cameraframe, rotations_image, translations_image): + extrinsics_image = np.concatenate((np.array(rotations_image), np.array(translations_image).reshape(3,1)), axis=1) + extrinsics_4x4 = np.concatenate((extrinsics_image, np.array([0, 0, 0, 1]).reshape(1, 4)), axis=0) + sampled_points_cameraframe_homogeneous = np.concatenate((sampled_points_cameraframe, np.ones((sampled_points_cameraframe.shape[0], 1))), axis=1) + sampled_points_worldframe = np.linalg.inv(extrinsics_4x4) @ sampled_points_cameraframe_homogeneous.T + + return sampled_points_worldframe.T + +def pixelwise_rgb_diff(point_a, point_b, threshold=0.3): + return np.linalg.norm(point_a - point_b) / 255. > threshold + +def image_quadtree_augmentation(image_key, + image_dir, + colmap_cameras, + colmap_images, + colmap_points3D, + points3d, points3d_rgb, + intrinsics_camera, + rotations_image, + translations_image, + quadtree_std_threshold=7, + quadtree_min_pixel_size=5, + visibility_aware_culling=False): + image_name = colmap_images[image_key].name + image_path = os.path.join(image_dir, image_name) + image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) + image_grayscale = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) + + quadtree_root = quadtree_decomposition(image_grayscale, quadtree_std_threshold, quadtree_min_pixel_size) + + leaf_nodes = [] + gather_leaf_nodes(quadtree_root, leaf_nodes) + # Project 3D points onto the image plane. + points3d_pix, points3d_depth = project_3d_to_2d(points3d, intrinsics_camera[colmap_images[image_key].camera_id], np.concatenate((np.array(rotations_image[image_key]), np.array(translations_image[image_key]).reshape(3,1)), axis=1)) + # Cull the points that are outside the camera frustum and have negative depth. + near_culled_indices = (points3d_pix[:, 0] < 0) | (points3d_pix[:, 0] >= image.shape[1]) | (points3d_pix[:, 1] < 0) | (points3d_pix[:, 1] >= image.shape[0]) | (points3d_depth.reshape(-1) < 0) | np.isnan(points3d_pix).any(axis=1) + points3d_pix_rgb_diff = np.zeros(points3d_pix.shape[0], dtype=np.bool_) + if visibility_aware_culling: + for i, point in enumerate(points3d_pix): + if near_culled_indices[i]: + continue + point_a = image[int(point[1]), int(point[0])] + point_b = points3d_rgb[i] + points3d_pix_rgb_diff[i] = pixelwise_rgb_diff(point_a, point_b) + near_culled_indices = near_culled_indices | points3d_pix_rgb_diff.reshape(-1) + # Check every node in the quadtree that contains projected 3D points. + leaf_nodes = check_points_in_quadtree(points3d_pix, points3d_depth, points3d_rgb, leaf_nodes, near_culled_indices) + # Sample points from the leaf nodes that are not occupied by projected 3D points. 
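# --- Illustrative aside (not part of the patch): every unoccupied leaf receives one candidate
# point drawn uniformly inside its rectangle, and its colour is then read from the image at
# that pixel. A minimal sketch of the draw for a hypothetical node at (x0, y0) of size
# (width, height):
import numpy as np

rng = np.random.default_rng(0)
x0, y0, width, height = 32, 64, 16, 16
uv = np.array([x0, y0]) + rng.random(2) * np.array([width, height])
assert x0 <= uv[0] < x0 + width and y0 <= uv[1] < y0 + height
# rgb = image[int(uv[1]), int(uv[0])]   # colour lookup as done in the loop below
print(uv)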
+ sampled_points = [] + sampled_points_rgb = [] + for node in leaf_nodes: + if node.unoccupied: + node.sampled_point_uv = np.array([node.x0, node.y0]) + np.random.sample(2) * np.array([node.width, node.height]) + node.sampled_point_rgb = image[int(node.sampled_point_uv[1]), int(node.sampled_point_uv[0])] + sampled_points.append(node.sampled_point_uv) + sampled_points_rgb.append(node.sampled_point_rgb) + sampled_points = np.array(sampled_points) + sampled_points_rgb = np.array(sampled_points_rgb) + # Interpolate the depth of the sampled points from the nearest 3D points. + leaf_nodes, augmented_count = find_depth_from_nn(image, leaf_nodes, points3d_pix, points3d_depth, near_culled_indices, intrinsics_camera[colmap_images[image_key].camera_id], rotations_image[image_key], translations_image[image_key]) + + return quadtree_root, augmented_count + +def transform_sample_3d(image_key, root, colmap_images, colmap_cameras, intrinsics_camera, rotations_image, translations_image): + # Gather leaf nodes and transform the sampled 2D points to the 3D world coordinates. + leaf_nodes = [] + gather_leaf_nodes(root, leaf_nodes) + sample_points_imageframe = [] + sample_points_depth = [] + sample_points_rgb = [] + for node in leaf_nodes: + if node.unoccupied: + if node.depth_interpolated: + sample_points_imageframe.append(node.sampled_point_uv) + sample_points_depth.append(node.sampled_point_depth) + sample_points_rgb.append(node.sampled_point_rgb) + sample_points_imageframe = np.array(sample_points_imageframe) + sample_points_depth = np.array(sample_points_depth) + sample_points_rgb = np.array(sample_points_rgb) + sample_points_cameraframe = (np.linalg.inv(intrinsics_camera[colmap_images[image_key].camera_id]) @ np.concatenate((sample_points_imageframe, np.ones((sample_points_imageframe.shape[0], 1))), axis=1).T).T + sample_points_cameraframe = np.concatenate((sample_points_cameraframe[:,:2]*sample_points_depth.reshape(-1,1), sample_points_depth.reshape(-1, 1)), axis=1) + sample_points_worldframe = transform_camera_to_world(sample_points_cameraframe, rotations_image[image_key], translations_image[image_key]) + return sample_points_worldframe[:,:-1], sample_points_rgb + +def write_points3D_colmap_binary(points3D, xyz, rgb, file_path): + with open(file_path, "wb") as fid: + write_next_bytes(fid, len(points3D) + len(xyz), "Q") + for j, (_, pt) in enumerate(points3D.items()): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + id = points3D[max(points3D.keys())].id + 1 + print("starts from id=", id) + for i in range(len(xyz)): + write_next_bytes(fid, id + i, "Q") + write_next_bytes(fid, xyz[i].tolist(), "ddd") + write_next_bytes(fid, rgb[i].tolist(), "BBB") + write_next_bytes(fid, 0.0, "d") + write_next_bytes(fid, 1, "Q") + write_next_bytes(fid, [0, 0], "ii") + return i + id + +def compare_local_texture(patch1, patch2): + """ + Compare two patches using a Gaussian kernel. + If the patch size is not 3x3, apply zero padding. + """ + # Define a 3x3 Gaussian kernel (sigma=1.0) + gaussian_kernel = np.array([ + [0.077847, 0.123317, 0.077847], + [0.123317, 0.195346, 0.123317], + [0.077847, 0.123317, 0.077847] + ]) + + # Check the patch size and apply zero padding if it is not 3x3. 
+    def pad_to_3x3(patch):
+        if patch.shape[:2] != (3, 3):
+            padded = np.zeros((3, 3, 3) if len(patch.shape) == 3 else (3, 3))
+            h, w = patch.shape[:2]
+            y_start = (3 - h) // 2
+            x_start = (3 - w) // 2
+            padded[y_start:y_start + h, x_start:x_start + w] = patch
+            return padded
+        return patch
+
+    # Cast to float so the squared differences below cannot wrap around for uint8 inputs.
+    patch1 = pad_to_3x3(patch1).astype(np.float32)
+    patch2 = pad_to_3x3(patch2).astype(np.float32)
+
+    if len(patch1.shape) == 3:  # RGB image
+        # Apply weights to each channel and calculate the difference.
+        weighted_diff = np.sum([
+            np.sum(gaussian_kernel * (patch1[:, :, c] - patch2[:, :, c]) ** 2)
+            for c in range(3)
+        ])
+    else:  # Grayscale image
+        weighted_diff = np.sum(gaussian_kernel * (patch1 - patch2) ** 2)
+
+    return np.sqrt(weighted_diff) / 255.
\ No newline at end of file
diff --git a/utils/bundle_utils.py b/utils/bundle_utils.py
new file mode 100644
index 0000000..5502b10
--- /dev/null
+++ b/utils/bundle_utils.py
@@ -0,0 +1,156 @@
+import random
+import os
+from colmap.scripts.python.read_write_model import *
+import numpy as np
+from collections import defaultdict
+from utils.colmap_utils import compute_extrinsics, compute_intrinsics, get_colmap_data
+import cv2
+from sklearn.decomposition import PCA
+
+
+
+def build_covisibility_matrix(images, points3D):
+    n_images = len(images)
+    covisibility_matrix = np.zeros((n_images, n_images))
+
+    id_to_idx = {img_id: idx for idx, img_id in enumerate(images.keys())}
+    idx_to_id = {idx: img_id for img_id, idx in id_to_idx.items()}
+
+    for point3D in points3D.values():
+        image_ids = point3D.image_ids
+        for i in range(len(image_ids)):
+            for j in range(i + 1, len(image_ids)):
+                id1, id2 = image_ids[i], image_ids[j]
+                idx1, idx2 = id_to_idx[id1], id_to_idx[id2]
+                covisibility_matrix[idx1, idx2] += 1
+                covisibility_matrix[idx2, idx1] += 1
+
+    return covisibility_matrix, id_to_idx, idx_to_id
+
+def create_covisibility_graph(covisibility_matrix, idx_to_id):
+    graph = defaultdict(dict)
+    for i in range(len(covisibility_matrix)):
+        for j in range(len(covisibility_matrix)):
+            if i != j and covisibility_matrix[i, j] > 0:
+                id1 = idx_to_id[i]
+                id2 = idx_to_id[j]
+                graph[id1][id2] = covisibility_matrix[i, j]
+
+    return graph
+
+def create_sequence_from_covisibility_graph(covisibility_graph, min_covisibility=20):
+    if not covisibility_graph:
+        print("No covisibility graph found")
+        return []
+
+    start_node = max(covisibility_graph.keys(),
+                     key=lambda k: sum(1 for v in covisibility_graph[k].values() if v >= min_covisibility))
+
+    visited = set([start_node])
+    sequence = [start_node]
+    current = start_node
+
+    while len(sequence) < len(covisibility_graph):
+        next_node = None
+        max_covisibility = -1
+
+        for neighbor, covisibility in covisibility_graph[current].items():
+            if neighbor not in visited and covisibility > max_covisibility and covisibility >= min_covisibility:
+                next_node = neighbor
+                max_covisibility = covisibility
+
+        if next_node is None:
+            for node in covisibility_graph:
+                if node not in visited:
+                    for seq_node in sequence:
+                        covisibility = covisibility_graph[node].get(seq_node, 0)
+                        if covisibility > max_covisibility and covisibility >= min_covisibility:
+                            max_covisibility = covisibility
+                            next_node = node
+
+        if next_node is None:
+            next_node = min(set(covisibility_graph.keys()) - visited)
+
+        current = next_node
+        visited.add(current)
+        sequence.append(current)
+
+    return sequence
+
+def cluster_cameras(model_path, camera_order):
+    colmap_path = os.path.join(model_path, 'sparse/0')
+    colmap_images, colmap_points3D, colmap_cameras = get_colmap_data(colmap_path)
+    if camera_order == 'covisibility':
+        covisibility_matrix, id_to_idx, idx_to_id = build_covisibility_matrix(colmap_images, colmap_points3D)
+        covisibility_graph = create_covisibility_graph(covisibility_matrix, idx_to_id)
+        covisibility_sequence = create_sequence_from_covisibility_graph(covisibility_graph)
+
+        image_id_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
+        image_id_name_sorted = sorted(image_id_name, key=lambda x: x[1])
+        test_id = []
+        train_id = []
+        count = 0
+        train_only_idx = []
+        for i, (id, name) in enumerate(image_id_name_sorted):
+            if i % 8 == 0:
+                test_id.append(id)
+            else:
+                train_id.append(id)
+                train_only_idx.append(count)
+            count += 1
+
+        rotations_image, translations_image = compute_extrinsics(colmap_images)
+
+        train_only_visibility_idx = []
+        for id in covisibility_sequence:
+            if id in train_id:
+                train_only_visibility_idx.append(id)
+        train_only_visibility_idx = np.array(train_only_visibility_idx)
+        sorted_keys = train_only_visibility_idx  # assign to sorted_keys rather than sorted_indices
+
+    elif camera_order == 'PCA':
+        image_idx_name = [[colmap_images[key].id, colmap_images[key].name] for key in colmap_images.keys()]
+        image_idx_name_sorted = sorted(image_idx_name, key=lambda x: x[1])
+        test_idx = []
+        train_idx = []
+        for i, (idx, name) in enumerate(image_idx_name_sorted):
+            if i % 8 == 0:
+                test_idx.append(idx)
+            else:
+                train_idx.append(idx)
+
+        rotations_image, translations_image = compute_extrinsics(colmap_images)
+
+        cam_center = []
+        key = []
+        for idx in train_idx:
+            cam_center.append((-rotations_image[idx].T @ translations_image[idx].reshape(3, 1)))
+            key.append(idx)
+
+        cam_center = np.array(cam_center)[:, :, 0]
+        pca = PCA(n_components=2)
+        cam_center_2d = pca.fit_transform(cam_center)
+
+        center_cam_center = np.mean(cam_center_2d, axis=0)
+        centered_cam_center = cam_center_2d - center_cam_center
+        angles = np.arctan2(centered_cam_center[:, 1], centered_cam_center[:, 0])
+        sorted_indices = np.argsort(angles)
+        sorted_cam_centers = cam_center_2d[sorted_indices]
+        sorted_keys = np.array(key)[sorted_indices]
+
+    return sorted_keys
+
+def bundle_start_index_generator(sorted_keys, initial_interval):
+    start_indices = []
+    cluster_sizes = []
+    for i in range(200):
+        start_indices.append((initial_interval * i) % len(sorted_keys))
+        cluster_sizes.append(initial_interval)
+
+    return start_indices, cluster_sizes
+
+def adaptive_cluster(start_idx, sorted_keys, cluster_size=40, offset=0):
+    idx = start_idx
+    indices = [sorted_keys[index % len(sorted_keys)] for index in range(idx, idx + cluster_size)]
+    random_index = random.choice(indices)
+    return random_index
diff --git a/utils/colmap_utils.py b/utils/colmap_utils.py
new file mode 100644
index 0000000..6cb3502
--- /dev/null
+++ b/utils/colmap_utils.py
@@ -0,0 +1,78 @@
+from colmap.scripts.python.read_write_model import *
+
+def get_colmap_data(dataset_path):
+    """
+    Load COLMAP data from the given dataset path.
+
+    Args:
+        dataset_path (str): Path to the COLMAP sparse model directory, i.e. the folder
+            that contains images.bin, points3D.bin and cameras.bin (e.g. <scene>/sparse/0).
+
+    Returns:
+        images: colmap image infos
+        points3D: colmap 3D points
+        cameras: colmap camera infos
+    """
+    images = read_images_binary(os.path.join(dataset_path, 'images.bin'))
+    points3D = read_points3D_binary(os.path.join(dataset_path, 'points3D.bin'))
+    cameras = read_cameras_binary(os.path.join(dataset_path, 'cameras.bin'))
+    return images, points3D, cameras
+
+def quaternion_rotation_matrix(qw, qx, qy, qz):
+    """
+    Convert a quaternion to a rotation matrix.
+    COLMAP uses wxyz order for quaternions.
+ """ + # First row of the rotation matrix + r00 = 2 * (qw * qw + qx * qx) - 1 + r01 = 2 * (qx * qy - qw * qz) + r02 = 2 * (qw * qy + qx * qz) + + # Second row of the rotation matrix + r10 = 2 * (qx * qy + qw * qz) + r11 = 2 * (qw * qw + qy * qy) - 1 + r12 = 2 * (qy * qz - qw * qx) + + # Third row of the rotation matrix + r20 = 2 * (qz * qx - qw * qy) + r21 = 2 * (qz * qy + qw * qx) + r22 = 2 * (qw * qw + qz * qz) - 1 + + # 3x3 rotation matrix + rot_matrix = np.array([[r00, r01, r02], + [r10, r11, r12], + [r20, r21, r22]]) + + return rot_matrix + +def compute_intrinsic_matrix(fx, fy, cx, cy, image_width, image_height): + sx = cx/(image_width/2) + sy = cy/(image_height/2) + intrinsic_matrix = np.array([[fx/sx, 0, cx/sx], [0, fy/sy, cy/sy], [0, 0, 1]]) + return intrinsic_matrix + +def compute_intrinsics(colmap_cameras, image_width, image_height): + intrinsics = {} + for cam_key in colmap_cameras.keys(): + intrinsic_parameters = colmap_cameras[cam_key].params + assert colmap_cameras[cam_key].model == 'PINHOLE' + intrinsic = compute_intrinsic_matrix(intrinsic_parameters[0], + intrinsic_parameters[1], + intrinsic_parameters[2], + intrinsic_parameters[3], + image_width, + image_height) + intrinsics[cam_key] = intrinsic + return intrinsics + +def compute_extrinsics(colmap_images): + rotations = {} + translations = {} + for image_key in colmap_images.keys(): + rotation = quaternion_rotation_matrix(colmap_images[image_key].qvec[0], + colmap_images[image_key].qvec[1], + colmap_images[image_key].qvec[2], + colmap_images[image_key].qvec[3]) + translation = colmap_images[image_key].tvec + rotations[image_key] = rotation + translations[image_key] = translation + return rotations, translations \ No newline at end of file diff --git a/utils/loss_utils.py b/utils/loss_utils.py index 60cf1f7..4cd5b8f 100644 --- a/utils/loss_utils.py +++ b/utils/loss_utils.py @@ -10,6 +10,7 @@ # import torch +import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from math import exp @@ -17,6 +18,7 @@ try: from diff_gaussian_rasterization._C import fusedssim, fusedssim_backward except: pass +import math C1 = 0.01 ** 2 C2 = 0.03 ** 2 @@ -89,3 +91,133 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): def fast_ssim(img1, img2): ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2) return ssim_map.mean() + +def build_gaussian_kernel(kernel_size=5, sigma=1.0, channels=3): + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_coord = torch.arange(kernel_size) + x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() + + mean = (kernel_size - 1)/2. + variance = sigma**2. + + # Calculate the 2-dimensional gaussian kernel + gaussian_kernel = (1./(2.*math.pi*variance)) * \ + torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / \ + (2*variance)) + + # Make sure sum of values in gaussian kernel equals 1. 
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2D depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    return gaussian_kernel
+
+def gaussian_blur(x, kernel):
+    """Apply gaussian blur to input tensor"""
+    padding = (kernel.shape[-1] - 1) // 2
+    return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
+
+def create_laplacian_pyramid(image, max_levels=4):
+    """Create Laplacian pyramid for an image"""
+    pyramids = []
+    current = image
+    kernel = build_gaussian_kernel().to(image.device)
+
+    for _ in range(max_levels):
+        # Blur and downsample
+        blurred = gaussian_blur(current, kernel)
+        down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+        # Upsample and subtract
+        up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+        laplace = current - up
+
+        pyramids.append(laplace)
+        current = down
+
+    pyramids.append(current)  # Add the final residual
+    return pyramids
+
+def laplacian_pyramid_loss(pred, target, max_levels=4, weights=None):
+    """Compute Laplacian Pyramid Loss between predicted and target images"""
+    if weights is None:
+        weights = [1.0] * (max_levels + 1)
+
+    pred_pyramids = create_laplacian_pyramid(pred, max_levels)
+    target_pyramids = create_laplacian_pyramid(target, max_levels)
+
+    loss = 0
+    for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+        loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+    return loss
+
+class LaplacianPyramidLoss(torch.nn.Module):
+    def __init__(self, max_levels=4, channels=3, kernel_size=5, sigma=1.0):
+        super().__init__()
+        self.max_levels = max_levels
+        self.kernel = build_gaussian_kernel(kernel_size, sigma, channels)
+
+    def forward(self, pred, target, weights=None):
+        if weights is None:
+            weights = [1.0] * (self.max_levels + 1)
+
+        # Move kernel to the same device as input
+        kernel = self.kernel.to(pred.device)
+
+        pred_pyramids = self.create_laplacian_pyramid(pred, kernel)
+        target_pyramids = self.create_laplacian_pyramid(target, kernel)
+
+        loss = 0
+        for pred_lap, target_lap, weight in zip(pred_pyramids, target_pyramids, weights):
+            loss += weight * torch.abs(pred_lap - target_lap).mean()
+
+        return loss
+
+    @staticmethod
+    def create_laplacian_pyramid(image, kernel, max_levels=4):
+        pyramids = []
+        current = image
+
+        for _ in range(max_levels):
+            # Apply Gaussian blur before downsampling to prevent aliasing
+            blurred = gaussian_blur(current, kernel)
+            down = F.interpolate(blurred, scale_factor=0.5, mode='bilinear', align_corners=False)
+
+            # Upsample and subtract from the original image
+            up = F.interpolate(down, size=current.shape[2:], mode='bilinear', align_corners=False)
+            laplace = current - gaussian_blur(up, kernel)  # Apply blur to upsampled image
+
+            pyramids.append(laplace)
+            current = down
+
+        pyramids.append(current)  # Add the final residual
+        return pyramids
+
+class InvDepthSmoothnessLoss(nn.Module):
+    def __init__(self, alpha=10):
+        super(InvDepthSmoothnessLoss, self).__init__()
+        self.alpha = alpha  # hyperparameter controlling the strength of the edge-aware weighting
+
+    def forward(self, inv_depth, image):
+        # Gradients of the inverse depth map
+        dx_inv_depth = torch.abs(inv_depth[:, :, :-1] - inv_depth[:, :, 1:])
+        dy_inv_depth = torch.abs(inv_depth[:, :-1, :] - inv_depth[:, 1:, :])
+
+        # Gradients of the image
+        dx_image = torch.mean(torch.abs(image[:, :, :-1] - image[:, :, 1:]), 1, keepdim=True)
+        dy_image = torch.mean(torch.abs(image[:, :-1, :] - image[:, 1:, :]), 1, keepdim=True)
+
+        # Edge-aware weights from the image gradients
+        weight_x = torch.exp(-self.alpha * dx_image)
+        weight_y = torch.exp(-self.alpha * dy_image)
+
+        # Smoothness loss
+        smoothness_x = dx_inv_depth * weight_x
+        smoothness_y = dy_inv_depth * weight_y
+
+        return torch.mean(smoothness_x) + torch.mean(smoothness_y)
\ No newline at end of file
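A minimal usage sketch (not part of the patch) of the two regularisers wired into train.py above; the tensor shapes are assumptions based on how they are called there (render_pkg["depth"] as a (1, H, W) depth map, images as (3, H, W) tensors):

    import torch
    from utils.loss_utils import InvDepthSmoothnessLoss, laplacian_pyramid_loss

    image = torch.rand(3, 64, 64)
    gt_image = torch.rand(3, 64, 64)
    inv_depth = torch.rand(1, 64, 64)

    ds_loss = InvDepthSmoothnessLoss()(inv_depth, image)
    lap_loss = laplacian_pyramid_loss(image.unsqueeze(0), gt_image.unsqueeze(0))
    print(ds_loss.item(), lap_loss.item())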
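A usage sketch (not part of the patch) of the covisibility ordering and bundle helpers as train.py calls them; "data/scene" is a placeholder for a scene root containing sparse/0, and the interval of 20 mirrors the value hard-coded in training():

    from utils.bundle_utils import cluster_cameras, bundle_start_index_generator, adaptive_cluster

    sorted_keys = cluster_cameras("data/scene", "covisibility")   # train image ids in covisibility order
    start_indices, cluster_sizes = bundle_start_index_generator(sorted_keys, 20)
    image_id = adaptive_cluster(start_indices[0], sorted_keys, cluster_sizes[0])
    print(len(sorted_keys), image_id)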
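A small sketch (not part of the patch) of compare_local_texture on two 3x3 patches; the augmentation loop treats values above 0.5 as a texture mismatch:

    import numpy as np
    from utils.aug_utils import compare_local_texture

    patch_black = np.zeros((3, 3, 3), dtype=np.float64)
    patch_white = np.full((3, 3, 3), 255.0)
    print(compare_local_texture(patch_black, patch_black))   # 0.0 for identical patches
    print(compare_local_texture(patch_black, patch_white))   # ~1.73, well above the 0.5 cut-off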