From 34bbab0e60230914f39f81e869719c04cdcc3c87 Mon Sep 17 00:00:00 2001 From: homee-dennis Date: Mon, 25 Mar 2024 09:24:47 +0000 Subject: [PATCH] 3DGS on online and offline ARKit dataset --- arguments/__init__.py | 1 + arkit_utils/arkit_pose2obj.py | 157 +++++++++++++++ arkit_utils/arkit_pose_to_colmap.py | 77 ++++++++ .../mesh_to_points3D/arkitmeshply2point3D.py | 69 +++++++ .../mesh_to_points3D/texture_obj2point3D.py | 63 ++++++ arkit_utils/pose2tum_evo.py | 182 ++++++++++++++++++ .../pose_optimization/optimize_pose_colmap.py | 182 ++++++++++++++++++ .../pose_optimization/optimize_pose_hloc.py | 141 ++++++++++++++ arkit_utils/undistort_images/undistort_gif.py | 37 ++++ .../undistort_images/undistort_image.py | 172 +++++++++++++++++ .../undistort_images/undistort_image_cuda.py | 176 +++++++++++++++++ convert.py | 1 + read_camera_binary.py | 5 + run_arkit_3dgs.sh | 34 ++++ scene/__init__.py | 2 +- scene/dataset_readers.py | 18 +- 16 files changed, 1308 insertions(+), 9 deletions(-) create mode 100644 arkit_utils/arkit_pose2obj.py create mode 100644 arkit_utils/arkit_pose_to_colmap.py create mode 100644 arkit_utils/mesh_to_points3D/arkitmeshply2point3D.py create mode 100644 arkit_utils/mesh_to_points3D/texture_obj2point3D.py create mode 100644 arkit_utils/pose2tum_evo.py create mode 100644 arkit_utils/pose_optimization/optimize_pose_colmap.py create mode 100644 arkit_utils/pose_optimization/optimize_pose_hloc.py create mode 100644 arkit_utils/undistort_images/undistort_gif.py create mode 100644 arkit_utils/undistort_images/undistort_image.py create mode 100644 arkit_utils/undistort_images/undistort_image_cuda.py create mode 100644 read_camera_binary.py create mode 100644 run_arkit_3dgs.sh diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a55..0ea6966 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -48,6 +48,7 @@ class ModelParams(ParamGroup): def __init__(self, parser, sentinel=False): self.sh_degree = 3 self._source_path = "" + self._type = "" self._model_path = "" self._images = "images" self._resolution = -1 diff --git a/arkit_utils/arkit_pose2obj.py b/arkit_utils/arkit_pose2obj.py new file mode 100644 index 0000000..e583f2f --- /dev/null +++ b/arkit_utils/arkit_pose2obj.py @@ -0,0 +1,157 @@ + +import numpy as np +import math +import argparse +import os + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def create_frustum_mesh(translation, quaternion): + """Creates data for a camera frustum given its pose and projection parameters. + + Args: + translation: 3D translation vector (x, y, z). + quaternion: 4D quaternion representing camera rotation (w, x, y, z). + fov: Field of view angle in degrees. + aspect_ratio: Aspect ratio of the frustum's image plane. + near: Near clipping plane distance. + far: Far clipping plane distance. + + Returns: + A tuple containing vertex and face data for the frustum. + """ + + # Convert quaternion to rotation matrix + # world frame : y-up(gravity align), x-right + # camera frame : y-up, x-right, z-point to user + Rwc = qvec2rotmat(quaternion) + twc = translation + # Calculate frustum corner points in camera space + w = 0.128/2 + h = 0.072/2 + s = 0.1/2 + top_left = [-w, -h, -s] # nagative s due to the z axis point to the user + top_right= [w, -h, -s] + bottom_right = [w, h, -s] + bottom_left = [-w, h, -s] + + # Transform corner points to world space + world_top_left = Rwc.dot(top_left) + twc + world_top_right = Rwc.dot(top_right) + twc + world_bottom_right = Rwc.dot(bottom_right) + twc + world_bottom_left = Rwc.dot(bottom_left) + twc + world_near_center = twc + + # Create vertex and face data for the frustum + vertices = [ + "v " + " ".join([str(x) for x in world_top_left]), + "v " + " ".join([str(x) for x in world_top_right]), + "v " + " ".join([str(x) for x in world_bottom_right]), + "v " + " ".join([str(x) for x in world_bottom_left]), + "v " + " ".join([str(x) for x in world_near_center]) + + ] + faces = [ + "f 1 2 3 4", # Front face + "f 1 2 5", # Left side face + "f 1 4 5", # Bottom side face + "f 5 4 3", # Right side face + "f 2 3 5", # Back face + ] + return vertices, faces + +def write_multi_frustum_obj(xyzs, qxyzs, filename="multi_frustums.obj"): + """Writes camera frustums from multiple poses into a single .obj file. + + Args: + camera_poses: List of dictionaries with "translation" and "quaternion" keys. + fov: Field of view angle in degrees. + aspect_ratio: Aspect ratio of the frustum's image plane. + near: Near clipping plane distance. + far: Far clipping plane distance. + filename: Output filename for the .obj file. + """ + + # obj_data = "" + + for i, pose in enumerate(xyzs): + translation = xyzs[i] + quaternion = qxyzs[i] + vertices, faces = create_frustum_mesh(translation, quaternion) + obj_data = "\n".join(vertices) + "\n"+ "\n".join(faces) + # # Offset vertex indices based on existing vertices + # offset = len(obj_data.split("v ")) - 1 # Number of existing vertices + # for v in vertices: + # print(v) + # for x in v: + # print(x) + # new_vertices = [" ".join([str(float(x) + offset) for x in v]) for v in vertices] + # print(new_vertices) + # obj_data += "v " + "\n".join(new_vertices) + "\n" + + # # Keep face indices as-is since they point to relative vertex positions + # new_faces = ["f " + f for f in faces] + # obj_data += "f " + "\n".join(new_faces) + "\n" + + # Write complete obj data to file + with open(filename+str(i)+"image.obj", "w") as f: + f.write(obj_data) + +def readPose2buffer(path): + num_frames = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_frames += 1 + + print(f"num of frames : {num_frames}") + xyzs = np.empty((num_frames, 3)) + qxyzs = np.empty((num_frames, 4)) + + count = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 : + elems = line.split() + qxyz = np.array(tuple(map(float, elems[1:5]))) + xyz = np.array(tuple(map(float, elems[5:8]))) + qxyzs[count] = qxyz + xyzs[count] = xyz + count+=1 + + return xyzs, qxyzs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="transform ARKit pose to obj for meshLab visulization") + parser.add_argument("--input_cameras_path", type=str) + parser.add_argument("--output_frustum_path", type=str) + args = parser.parse_args() + input_cameras_path = args.input_cameras_path + output_frustum_path = args.output_frustum_path + + if not os.path.exists(output_frustum_path): + os.makedirs(output_frustum_path) + + xyzs, qxyzs = readPose2buffer(input_cameras_path) + write_multi_frustum_obj(xyzs, qxyzs, output_frustum_path) + + \ No newline at end of file diff --git a/arkit_utils/arkit_pose_to_colmap.py b/arkit_utils/arkit_pose_to_colmap.py new file mode 100644 index 0000000..70320a5 --- /dev/null +++ b/arkit_utils/arkit_pose_to_colmap.py @@ -0,0 +1,77 @@ +import argparse +import numpy as np +import os +from pyquaternion import Quaternion +from hloc.utils.read_write_model import Image, write_images_text + +def convert_pose(C2W): + flip_yz = np.eye(4) + flip_yz[1, 1] = -1 + flip_yz[2, 2] = -1 + C2W = np.matmul(C2W, flip_yz) + return C2W + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def arkit_pose_to_colmap(dataset_base) : + images = {} + with open(dataset_base + "/sparse/0/images.txt", "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + + c2w = np.zeros((4, 4)) + c2w[:3, :3] = qvec2rotmat(qvec) + c2w[:3, 3] = tvec + c2w[3, 3] = 1.0 + c2w_cv = convert_pose(c2w) + + # transform to z-up world coordinate for better visulazation + c2w_cv = np.array([[1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) @ c2w_cv + w2c_cv = np.linalg.inv(c2w_cv) + R = w2c_cv[:3, :3] + q = Quaternion(matrix=R, atol=1e-06) + qvec = np.array([q.w, q.x, q.y, q.z]) + tvec = w2c_cv[:3, -1] + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + + write_images_text(images=images, path=dataset_base+"/post/sparse/online/images.txt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize ARkit pose using hloc and COLMAP") + parser.add_argument("--input_database_path", type=str) + + args = parser.parse_args() + input_database_path = args.input_database_path + arkit_pose_to_colmap(input_database_path) \ No newline at end of file diff --git a/arkit_utils/mesh_to_points3D/arkitmeshply2point3D.py b/arkit_utils/mesh_to_points3D/arkitmeshply2point3D.py new file mode 100644 index 0000000..d9fbe0c --- /dev/null +++ b/arkit_utils/mesh_to_points3D/arkitmeshply2point3D.py @@ -0,0 +1,69 @@ + +import numpy as np +import argparse + +def rotx(t): + ''' 3D Rotation about the x-axis. ''' + c = np.cos(t) + s = np.sin(t) + return np.array([[1, 0, 0], + [0, c, -s], + [0, s, c]]) + +def transformARkitPCL2COLMAPpoint3DwithZaxisUpward(input_ply_path, output_ply_path): + find_start_row = False + num_points = 0 + with open(input_ply_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0: + elems = line.split() + if find_start_row == False and elems[0] == "end_header": + find_start_row = True + continue + if find_start_row and len(elems)==3: + num_points += 1 + print(f"total num of point cloud : {num_points}") + xyzs = np.empty((num_points, 3)) + rgbs = np.zeros((num_points, 3)) + + count = 0 + + find_start_row = False # reset + + with open(input_ply_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 : + elems = line.split() + if find_start_row == False and elems[0] == "end_header": + find_start_row = True + continue + if find_start_row and len(elems)==3: + xyz = np.array(tuple(map(float, elems[0:3]))) + # rotated y-up world frame to z-up world frame + xyz = rotx(np.pi / 2) @ xyz + xyzs[count] = xyz + count+=1 + + + with open(output_ply_path, "w") as f: + for i in range(num_points): + line = str(i) + " " + str(xyzs[i][0]) + " " + str(xyzs[i][1])+ " " + str(xyzs[i][2])+ " " + str(int(rgbs[i][0])) + " " + str(int(rgbs[i][1]))+ " " + str(int(rgbs[i][2]))+ " " + str(0) + f.write(line + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="transform ARKit mesh point cloud to COLMAP point3D format with z-up coordinate") + parser.add_argument("--input_base_path", type=str, default="data/homee/colmap") + + args = parser.parse_args() + input_ply_path = args.input_base_path + "/sparse/ARKitmesh.ply" + output_ply_path = args.input_base_path + "/post/sparse/online/points3D.txt" + + transformARkitPCL2COLMAPpoint3DwithZaxisUpward(input_ply_path, output_ply_path) \ No newline at end of file diff --git a/arkit_utils/mesh_to_points3D/texture_obj2point3D.py b/arkit_utils/mesh_to_points3D/texture_obj2point3D.py new file mode 100644 index 0000000..535a278 --- /dev/null +++ b/arkit_utils/mesh_to_points3D/texture_obj2point3D.py @@ -0,0 +1,63 @@ + +import numpy as np +import argparse + +def rotx(t): + ''' 3D Rotation about the x-axis. ''' + c = np.cos(t) + s = np.sin(t) + return np.array([[1, 0, 0], + [0, c, -s], + [0, s, c]]) + + +def transformARkitRgbPCL2COLMAPpoint3DwithRgbAndZaxisUpward(input_obj_path, output_ply_path): + num_points = 0 + with open(input_obj_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + if elems[0] == "v": + num_points += 1 + print(num_points) + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + + count = 0 + with open(input_obj_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] == "v": + elems = line.split() + if elems[0] == "v": + xyz = np.array(tuple(map(float, elems[1:4]))) + # rotated y-up world frame to z-up world frame + xyz = rotx(np.pi / 2) @ xyz + xyzs[count] = xyz + rgb = np.array(tuple(map(float, elems[4:7]))) + rgbs[count] = rgb*255 + count+=1 + + + with open(output_ply_path, "w") as f: + for i in range(num_points): + line = str(i) + " " + str(xyzs[i][0]) + " " + str(xyzs[i][1])+ " " + str(xyzs[i][2])+ " " + str(int(rgbs[i][0])) + " " + str(int(rgbs[i][1]))+ " " + str(int(rgbs[i][2]))+ " " + str(0) + f.write(line + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="transform ARKit texture mesh point cloud to COLMAP point3D format with RGB value and z-up coordinate") + parser.add_argument("--input_obj_path", type=str, default="data/homee/colmap/3dgs.obj") + parser.add_argument("--output_ply_path", type=str, default="data/homee/colmap/point3D.txt") + + args = parser.parse_args() + input_obj_path = args.input_obj_path + output_ply_path = args.output_ply_path + + transformARkitRgbPCL2COLMAPpoint3DwithRgbAndZaxisUpward(input_obj_path, output_ply_path) \ No newline at end of file diff --git a/arkit_utils/pose2tum_evo.py b/arkit_utils/pose2tum_evo.py new file mode 100644 index 0000000..1144b86 --- /dev/null +++ b/arkit_utils/pose2tum_evo.py @@ -0,0 +1,182 @@ +import numpy as np +import struct + +def convert_pose(C2W): + flip_yz = np.eye(4) + flip_yz[1, 1] = -1 + flip_yz[2, 2] = -1 + C2W = np.matmul(C2W, flip_yz) + return C2W +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +def read_pose_txt(path): + txt_path = path + "images.txt" + num_frames = 0 + with open(txt_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_frames += 1 + + print(f"num of frames : {num_frames}") + xyzs = np.empty((num_frames, 3)) + qxyzs = np.empty((num_frames, 4)) + + count = 0 + with open(txt_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + qxyz = np.array(tuple(map(float, elems[1:5]))) + xyz = np.array(tuple(map(float, elems[5:8]))) + + Twc = np.zeros((4, 4)) + Twc[:3, :3] = qvec2rotmat(qxyz) + Twc[:3, 3] = xyz + Twc[3, 3] = 1.0 + Twc = convert_pose(Twc) + Twc = np.array([[1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) @ Twc + + R = Twc[:3, :3] + qvec = rotmat2qvec(R) + tvec = Twc[:3, -1] + + qxyzs[count] = qxyz + xyzs[count] = xyz + count+=1 + + write_2_TUM_format(num_frames, xyzs, qxyzs, path+"est_tum.txt") + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_pose_bin(path): + num_frames = 0 + bin_path = path + "images.bin" + with open(bin_path, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + num_frames += 1 + print(f"num of frames : {num_frames}") + xyzs = np.empty((num_frames, 3)) + qxyzs = np.empty((num_frames, 4)) + + count = 0 + + with open(bin_path, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + + # COLMAP pose is in Tcw, we need Twc + Tcw = np.zeros((4, 4)) + Tcw[:3, :3] = qvec2rotmat(qvec) + Tcw[:3, 3] = tvec + Tcw[3, 3] = 1.0 + + Twc = np.linalg.inv(Tcw) + R = Twc[:3, :3] + qvec = rotmat2qvec(R) + tvec = Twc[:3, -1] + + # binary won't read as increasing order + qxyzs[image_id-1] = qvec + xyzs[image_id-1] = tvec + count+=1 + + write_2_TUM_format(num_frames, xyzs, qxyzs, path+"gt_tum.txt") + + +def write_2_TUM_format(n, xyzs, qxyzs, path): + ''' + tum expect pose in Twc (camera to world) + ''' + with open(path, "w") as f: + for i in range(n): + line = str(i) + " " + str(xyzs[i][0]) + " " + str(xyzs[i][1])+ " " + str(xyzs[i][2])+ " " + str(qxyzs[i][1])+ " " + str(qxyzs[i][2])+ " " + str(qxyzs[i][3]) + " " + str(qxyzs[i][0]) + f.write(line + "\n") + +if __name__ == "__main__": + # parser = argparse.ArgumentParser(description="transform ARKit pose to obj for meshLab visulization") + # parser.add_argument("--input_cameras_path", type=str) + # args = parser.parse_args() + + # input_cameras_path = args.input_cameras_path + + # read_pose_txt("data/arkit_pose/meeting_room_loop_closure/arkit_colmap2/colmap_arkit/raw/") + read_pose_bin("data/arkit_pose/meeting_room_loop_closure/arkit_colmap/colmap_arkit/raw/colmap_ba/") + diff --git a/arkit_utils/pose_optimization/optimize_pose_colmap.py b/arkit_utils/pose_optimization/optimize_pose_colmap.py new file mode 100644 index 0000000..34cfe30 --- /dev/null +++ b/arkit_utils/pose_optimization/optimize_pose_colmap.py @@ -0,0 +1,182 @@ +import argparse +import logging +import numpy as np +import json +import os +import pycolmap +import shutil +from pyquaternion import Quaternion +from hloc.triangulation import create_db_from_model +from pathlib import Path +from hloc.utils.read_write_model import Camera, Image, Point3D, CAMERA_MODEL_NAMES +from hloc.utils.read_write_model import write_model, read_model +from hloc import extract_features, match_features, pairs_from_poses, triangulation + +def convert_pose(C2W): + flip_yz = np.eye(4) + flip_yz[1, 1] = -1 + flip_yz[2, 2] = -1 + C2W = np.matmul(C2W, flip_yz) + return C2W + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def arkit_transform2_COLMAP(dataset_base) : + dataset_dir = Path(dataset_base) + + # step1. Transorm ARKit images to COLAMP coordinate + images = {} + with open(dataset_base + "/sparse/0/images.txt", "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + + c2w = np.zeros((4, 4)) + c2w[:3, :3] = qvec2rotmat(qvec) + c2w[:3, 3] = tvec + c2w[3, 3] = 1.0 + c2w_cv = convert_pose(c2w) + + # transform to z-up world coordinate for better visulazation + c2w_cv = np.array([[1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1]]) @ c2w_cv + w2c_cv = np.linalg.inv(c2w_cv) + R = w2c_cv[:3, :3] + q = Quaternion(matrix=R, atol=1e-06) + qvec = np.array([q.w, q.x, q.y, q.z]) + tvec = w2c_cv[:3, -1] + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + + + # step2. Write ARKit undistorted intrinsic to COLMAP cameras + cameras = {} + with open(dataset_base + "/sparse/0/cameras.txt", "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + + + # Empty 3D points. + points3D = {} + + print('Writing the COLMAP model...') + colmap_arkit = dataset_dir / 'colmap_arkit' / 'raw' + colmap_arkit.mkdir(exist_ok=True, parents=True) + write_model(images=images, cameras=cameras, points3D=points3D, path=str(colmap_arkit), ext='.txt') + + + +def optimize_pose_by_COLMAP(dataset_base) : + feat_extracton_cmd = "colmap feature_extractor \ + --database_path " + dataset_base + "/database.db \ + --image_path " + dataset_base + "/images \ + --ImageReader.single_camera 1 \ + --ImageReader.camera_model PINHOLE \ + --SiftExtraction.use_gpu 1" + exit_code = os.system(feat_extracton_cmd) + if exit_code != 0: + logging.error(f"Feature extraction failed with code {exit_code}. Exiting.") + exit(exit_code) + + ## Feature matching + feat_matching_cmd = "colmap exhaustive_matcher \ + --database_path " + dataset_base + "/database.db \ + --SiftMatching.use_gpu 1" + exit_code = os.system(feat_matching_cmd) + if exit_code != 0: + logging.error(f"Feature matching failed with code {exit_code}. Exiting.") + exit(exit_code) + + os.makedirs(dataset_base + "/colmap_arkit/tri", exist_ok=True) + triangulate_cmd = "colmap point_triangulator \ + --database_path " + dataset_base + "/database.db \ + --image_path " + dataset_base + "/images \ + --input_path " + dataset_base + "/colmap_arkit/raw \ + --output_path " + dataset_base + "/colmap_arkit/tri" + exit_code = os.system(triangulate_cmd) + if exit_code != 0: + logging.error(f"Point triangulation failed with code {exit_code}. Exiting.") + exit(exit_code) + + os.makedirs(dataset_base + "/colmap_arkit/ba", exist_ok=True) + BA_cmd = "colmap bundle_adjuster \ + --input_path " + dataset_base + "/colmap_arkit/tri \ + --output_path " + dataset_base + "/colmap_arkit/ba" + exit_code = os.system(BA_cmd) + if exit_code != 0: + logging.error(f"BA failed with code {exit_code}. Exiting.") + exit(exit_code) + + os.makedirs(dataset_base + "/colmap_arkit/tri2", exist_ok=True) + triangulate_cmd = "colmap point_triangulator \ + --database_path " + dataset_base + "/database.db \ + --image_path " + dataset_base + "/images \ + --input_path " + dataset_base + "/colmap_arkit/ba \ + --output_path " + dataset_base + "/colmap_arkit/tri2" + exit_code = os.system(triangulate_cmd) + if exit_code != 0: + logging.error(f"Point triangulation failed with code {exit_code}. Exiting.") + exit(exit_code) + + os.makedirs(dataset_base + "/colmap_arkit/ba2", exist_ok=True) + BA_cmd = "colmap bundle_adjuster \ + --input_path " + dataset_base + "/colmap_arkit/tri2 \ + --output_path " + dataset_base + "/colmap_arkit/ba2" + exit_code = os.system(BA_cmd) + if exit_code != 0: + logging.error(f"BA failed with code {exit_code}. Exiting.") + exit(exit_code) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize ARkit pose using COLMAP") + parser.add_argument("--input_database_path", type=str, default="data/arkit_pose/study_room/arkit_undis") + args = parser.parse_args() + + input_database_path = args.input_database_path + + arkit_transform2_COLMAP(input_database_path) + optimize_pose_by_COLMAP(input_database_path) diff --git a/arkit_utils/pose_optimization/optimize_pose_hloc.py b/arkit_utils/pose_optimization/optimize_pose_hloc.py new file mode 100644 index 0000000..98ef690 --- /dev/null +++ b/arkit_utils/pose_optimization/optimize_pose_hloc.py @@ -0,0 +1,141 @@ +import argparse +import numpy as np +import json +import os +import pycolmap +import shutil +from pyquaternion import Quaternion +from hloc.triangulation import create_db_from_model +from pathlib import Path +from hloc.utils.read_write_model import Camera, Image, Point3D, CAMERA_MODEL_NAMES +from hloc.utils.read_write_model import write_model, read_model +from hloc import extract_features, match_features, pairs_from_poses, triangulation + + +def prepare_pose_and_intrinsic_prior(dataset_base) : + dataset_dir = Path(dataset_base) + + # step1. Write ARKit pose (in COLMAP ccordinate) to COLMAP images + images = {} + with open(dataset_base + "/post/sparse/online/images.txt", "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + + # step2. Write ARKit undistorted intrinsic to COLMAP cameras + cameras = {} + with open(dataset_base + "/post/sparse/online/cameras.txt", "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + + + # Empty 3D points. + points3D = {} + + print('Writing the COLMAP model...') + colmap_arkit_base = dataset_dir / 'post' / 'sparse' /'offline' + colmap_arkit = colmap_arkit_base / 'raw' + colmap_arkit.mkdir(exist_ok=True, parents=True) + write_model(images=images, cameras=cameras, points3D=points3D, path=str(colmap_arkit), ext='.bin') + + return colmap_arkit + + + +def optimize_pose_by_hloc_and_COLMAP(dataset_base, n_ba_iterations, n_matched = 10) : + # step1. Extract feature using hloc + dataset_dir = Path(dataset_base) + colmap_arkit_base = dataset_dir / 'post' / 'sparse' /'offline' + colmap_arkit = colmap_arkit_base / 'raw' + outputs = colmap_arkit_base / 'hloc' + outputs.mkdir(exist_ok=True, parents=True) + + images = dataset_dir / 'post' / 'images' + sfm_pairs = outputs / 'pairs-sfm.txt' + features = outputs / 'features.h5' + matches = outputs / 'matches.h5' + + references = [str(p.relative_to(images)) for p in images.iterdir()] + feature_conf = extract_features.confs['superpoint_inloc'] + matcher_conf = match_features.confs['superglue'] + + extract_features.main(feature_conf, images, image_list=references, feature_path=features) + pairs_from_poses.main(colmap_arkit, sfm_pairs, n_matched) + match_features.main(matcher_conf, sfm_pairs, features=features, matches=matches) + + # step2. optimize pose + colmap_input = colmap_arkit + for i in range(n_ba_iterations): + colmap_sparse = outputs / 'colmap_sparse' + colmap_sparse.mkdir(exist_ok=True, parents=True) + reconstruction = triangulation.main( + colmap_sparse, # output model + colmap_input, # input model + images, + sfm_pairs, + features, + matches) + + colmap_ba = outputs / 'colmap_ba' + colmap_ba.mkdir(exist_ok=True, parents=True) + # BA with fix intinsics + BA_cmd = f'colmap bundle_adjuster \ + --BundleAdjustment.refine_focal_length 0 \ + --BundleAdjustment.refine_principal_point 0 \ + --BundleAdjustment.refine_extra_params 0 \ + --input_path {colmap_sparse} \ + --output_path {colmap_ba}' + os.system(BA_cmd) + + colmap_input = colmap_ba + + # step3. get ba result to outside folder + cameras, images, point3D = read_model(colmap_ba, ext=".bin") + write_model(cameras, images, point3D, colmap_arkit_base, ext=".txt") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Optimize ARkit pose using hloc and COLMAP") + parser.add_argument("--input_database_path", type=str, default="data/arkit_pose/study_room/arkit_undis") + parser.add_argument("--BA_iterations", type=int, default=5) + + args = parser.parse_args() + + input_database_path = args.input_database_path + ba_iterations = args.BA_iterations + prepare_pose_and_intrinsic_prior(input_database_path) + optimize_pose_by_hloc_and_COLMAP(input_database_path, ba_iterations) diff --git a/arkit_utils/undistort_images/undistort_gif.py b/arkit_utils/undistort_images/undistort_gif.py new file mode 100644 index 0000000..5bb87dd --- /dev/null +++ b/arkit_utils/undistort_images/undistort_gif.py @@ -0,0 +1,37 @@ +import cv2 +import imageio +import numpy as np +def main(): + # 讀取原始影像和 undistort 完畢的影像 + original_image = cv2.imread("data/homee/study_room_test/dis/0043.jpg") + undistorted_image = cv2.imread("data/homee/study_room_test/images/0043.jpg") + COLMAP_undistorted_image = cv2.imread("data/homee/study_room_colmap/images/0043.jpg") + + original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB) + undistorted_image = cv2.cvtColor(undistorted_image, cv2.COLOR_BGR2RGB) + COLMAP_undistorted_image = cv2.cvtColor(COLMAP_undistorted_image, cv2.COLOR_BGR2RGB) + height, width, channel = undistorted_image.shape + print(f"height : {height}, width : {width}") + + text = "undistorted" + font_face = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 1.0 + font_thickness = 2 + text_color = (255, 0, 0) # Red color in RGB format + cv2.putText(undistorted_image, text, (100, 100), font_face, font_scale, text_color, font_thickness) + text = "COLMAP undistorted" + cv2.putText(COLMAP_undistorted_image, text, (100, 100), font_face, font_scale, text_color, font_thickness) + + # 設定 GIF 動畫的參數 + total_duration = 10 # 動畫播放時間 + + # 建立一個空的 GIF 動畫 + with imageio.get_writer("before_and_after.gif", mode="I", duration=1000) as writer: + for i in range(total_duration): + writer.append_data(original_image) + writer.append_data(undistorted_image) + writer.append_data(COLMAP_undistorted_image) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/arkit_utils/undistort_images/undistort_image.py b/arkit_utils/undistort_images/undistort_image.py new file mode 100644 index 0000000..a89eebd --- /dev/null +++ b/arkit_utils/undistort_images/undistort_image.py @@ -0,0 +1,172 @@ +import cv2 +import numpy as np +import json +import os +import argparse +from concurrent.futures import ProcessPoolExecutor + +class Data: + def __init__(self, intrinsic_matrix, intrinsic_matrix_reference_dimensions, lens_distortion_center, inverse_lens_distortion_lookup_table, lens_distortion_lookup_table): + self.intrinsic_matrix = intrinsic_matrix + self.intrinsic_matrix_reference_dimensions = intrinsic_matrix_reference_dimensions + self.lens_distortion_center = lens_distortion_center + self.inverse_lens_distortion_lookup_table = inverse_lens_distortion_lookup_table + self.lens_distortion_lookup_table = lens_distortion_lookup_table + +def readCalibrationJson(path): + # Open the JSON file + with open(path, "r") as f: + # Read the contents of the file + data = json.load(f) + + # Access specific data from the dictionary + pixel_size = data["calibration_data"]["pixel_size"] + intrinsic_matrix = data["calibration_data"]["intrinsic_matrix"] + intrinsic_matrix_reference_dimensions = data["calibration_data"]["intrinsic_matrix_reference_dimensions"] + lens_distortion_center = data["calibration_data"]["lens_distortion_center"] + # Access specific elements from lists within the dictionary + inverse_lut = data["calibration_data"]["inverse_lens_distortion_lookup_table"] + lut = data["calibration_data"]["lens_distortion_lookup_table"] + + data = Data(intrinsic_matrix, intrinsic_matrix_reference_dimensions, lens_distortion_center, inverse_lut, lut) + # # Print some of the data for verification + # print(f"Pixel size: {pixel_size}") + # print(f"Intrinsic matrix:\n {intrinsic_matrix}") + # print(f"Lens distortion center: {lens_distortion_center}") + # print(f"Inverse lookup table length: {len(inverse_lut)}") + return data + +def get_lens_distortion_point(point, lookup_table, distortion_center, image_size): + radius_max_x = min(distortion_center[0], image_size[0] - distortion_center[0]) + radius_max_y = min(distortion_center[1], image_size[1] - distortion_center[1]) + radius_max = np.sqrt(radius_max_x**2 + radius_max_y**2) + + radius_point = np.sqrt(np.square(point[0] - distortion_center[0]) + np.square(point[1] - distortion_center[1])) + + magnification = lookup_table[-1] + if radius_point < radius_max: + relative_position = radius_point / radius_max * (len(lookup_table) - 1) + frac = relative_position - np.floor(relative_position) + lower_lookup = lookup_table[int(np.floor(relative_position))] + upper_lookup = lookup_table[int(np.ceil(relative_position))] + magnification = lower_lookup * (1.0 - frac) + upper_lookup * frac + + mapped_point = np.array([distortion_center[0] + (point[0] - distortion_center[0]) * (1.0 + magnification), + distortion_center[1] + (point[1] - distortion_center[1]) * (1.0 + magnification)]) + return mapped_point + +def rectify_single_image(image_path, output_path, distortion_param_json_path, crop_x, crop_y): + """Processes a single image with distortion correction.""" + image = cv2.imread(image_path) + height, width, channel = image.shape + rectified_image = np.zeros((height, width, channel), dtype=image.dtype) + + # read calibration data + data = readCalibrationJson(distortion_param_json_path) + lookup_table = data.inverse_lens_distortion_lookup_table# data.lens_distortion_lookup_table + distortion_center = data.lens_distortion_center + reference_dimensions = data.intrinsic_matrix_reference_dimensions + ratio_x = width / reference_dimensions[0] + ratio_y = height / reference_dimensions[1] + distortion_center[0] = distortion_center[0] * ratio_x + distortion_center[1] = distortion_center[1] * ratio_y + + for i in range(width): + for j in range(height): + rectified_index = np.array([i, j]) + original_index = get_lens_distortion_point( + rectified_index, lookup_table, distortion_center, [width, height]) + + if (original_index[0] < 0 or original_index[0] >= width or original_index[1] < 0 or original_index[1] >= height): + continue + + rectified_image[j, i] = image[int(original_index[1]), int(original_index[0])] + + # crop image + u_shift = crop_x + v_shift = crop_y + crop_image = rectified_image[v_shift:height-v_shift, u_shift:width-u_shift] + cv2.imwrite(output_path, crop_image) + print(f"finish process {image_path}") + +def rectify_all_image(image_folder_path, distortion_param_json_path, output_image_folder_path, crop_x, crop_y): + with ProcessPoolExecutor() as executor: + for filename in os.listdir(image_folder_path): + image_path = os.path.join(image_folder_path, filename) + output_path = os.path.join(output_image_folder_path, filename) + executor.submit( + rectify_single_image, image_path, output_path, distortion_param_json_path, crop_x, crop_y) + + +def rectified_intrinsic(input_path, output_path, crop_x, crop_y): + num_intrinsic = 0 + with open(input_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_intrinsic += 1 + print(num_intrinsic) + + camera_ids = np.empty((num_intrinsic, 1)) + widths = np.empty((num_intrinsic, 1)) + heights = np.empty((num_intrinsic, 1)) + paramss = np.empty((num_intrinsic, 4)) + + count = 0 + with open(input_path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2])-crop_x*2 + height = int(elems[3])-crop_y*2 + params = np.array(tuple(map(float, elems[4:]))) + params[2] = params[2] - crop_x + params[3] = params[3] - crop_y + + camera_ids[count] = camera_id + widths[count] = width + heights[count] = height + paramss[count] = params + + count = count+1 + + with open(output_path, "w") as f: + for i in range(num_intrinsic): + line = str(int(camera_ids[i])) + " " + "PINHOLE" + " " + str(int(widths[i]))+ " " + str(int(heights[i]))+ " " + str(paramss[i][0]) + " " + str(paramss[i][1])+ " " + str(paramss[i][2])+ " " + str(paramss[i][3]) + f.write(line + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="undistort ARKit image using distortion params get from AVfoundation") + parser.add_argument("--input_base", type=str) + parser.add_argument("--crop_x", type=int, default=10) + parser.add_argument("--crop_y", type=int, default=8) + + + args = parser.parse_args() + base_folder_path = args.input_base + input_image_folder_path = base_folder_path + "/distort_images" + distortion_param_json_path = base_folder_path + "/sparse/0/calibration.json" + output_image_folder_path = base_folder_path + "/post/images/" + crop_x = args.crop_x + crop_y = args.crop_y + input_camera = base_folder_path + "/sparse/0/distort_cameras.txt" + output_camera = base_folder_path + "/post/sparse/online/cameras.txt" + + + if not os.path.exists(output_image_folder_path): + os.makedirs(output_image_folder_path) + + rectify_all_image(input_image_folder_path, distortion_param_json_path, output_image_folder_path, crop_x, crop_y) + rectified_intrinsic(input_camera, output_camera, crop_x, crop_y) + + \ No newline at end of file diff --git a/arkit_utils/undistort_images/undistort_image_cuda.py b/arkit_utils/undistort_images/undistort_image_cuda.py new file mode 100644 index 0000000..a893b6a --- /dev/null +++ b/arkit_utils/undistort_images/undistort_image_cuda.py @@ -0,0 +1,176 @@ +import cv2 +import numpy as np +import json +import os +import argparse +from pycuda import driver, gpuarray + +class Data: + def __init__(self, intrinsic_matrix, intrinsic_matrix_reference_dimensions, lens_distortion_center, inverse_lens_distortion_lookup_table, lens_distortion_lookup_table): + self.intrinsic_matrix = intrinsic_matrix + self.intrinsic_matrix_reference_dimensions = intrinsic_matrix_reference_dimensions + self.lens_distortion_center = lens_distortion_center + self.inverse_lens_distortion_lookup_table = inverse_lens_distortion_lookup_table + self.lens_distortion_lookup_table = lens_distortion_lookup_table + +def readCalibrationJson(path): + # Open the JSON file + with open(path, "r") as f: + # Read the contents of the file + data = json.load(f) + + # Access specific data from the dictionary + pixel_size = data["calibration_data"]["pixel_size"] + intrinsic_matrix = data["calibration_data"]["intrinsic_matrix"] + intrinsic_matrix_reference_dimensions = data["calibration_data"]["intrinsic_matrix_reference_dimensions"] + lens_distortion_center = data["calibration_data"]["lens_distortion_center"] + # Access specific elements from lists within the dictionary + inverse_lut = data["calibration_data"]["inverse_lens_distortion_lookup_table"] + lut = data["calibration_data"]["lens_distortion_lookup_table"] + + data = Data(intrinsic_matrix, intrinsic_matrix_reference_dimensions, lens_distortion_center, inverse_lut, lut) + return data + + +# Function to check for CUDA error +def check_cuda_error(err): + if err != 0: + driver.Context.synchronize() # Synchronize to ensure proper error handling + print("CUDA error:", driver.Error(err)) + exit(1) + +def get_lens_distortion_point(point, lookup_table, distortion_center, image_size): + radius_max_x = min(distortion_center[0], image_size[0] - distortion_center[0]) + radius_max_y = min(distortion_center[1], image_size[1] - distortion_center[1]) + radius_max = np.sqrt(radius_max_x**2 + radius_max_y**2) + + radius_point = np.sqrt(np.square(point[0] - distortion_center[0]) + np.square(point[1] - distortion_center[1])) + + magnification = lookup_table[-1] + if radius_point < radius_max: + relative_position = radius_point / radius_max * (len(lookup_table) - 1) + frac = relative_position - np.floor(relative_position) + lower_lookup = lookup_table[int(np.floor(relative_position))] + upper_lookup = lookup_table[int(np.ceil(relative_position))] + magnification = lower_lookup * (1.0 - frac) + upper_lookup * frac + + mapped_point = np.array([distortion_center[0] + (point[0] - distortion_center[0]) * (1.0 + magnification), + distortion_center[1] + (point[1] - distortion_center[1]) * (1.0 + magnification)]) + return mapped_point + +def process_image(image_path, output_path, distortion_param_json_path): + """Processes a single image with distortion correction.""" + image = cv2.imread(image_path) + height, width, channel = image.shape + rectified_image = np.zeros((height, width, channel), dtype=image.dtype) + + # read calibration data + data = readCalibrationJson(distortion_param_json_path) + lookup_table = data.lens_distortion_lookup_table + distortion_center = data.lens_distortion_center + reference_dimensions = data.intrinsic_matrix_reference_dimensions + ratio_x = width / reference_dimensions[0] + ratio_y = height / reference_dimensions[1] + distortion_center[0] = distortion_center[0] * ratio_x + distortion_center[1] = distortion_center[1] * ratio_y + + for i in range(width): + for j in range(height): + rectified_index = np.array([i, j]) + original_index = get_lens_distortion_point( + rectified_index, lookup_table, distortion_center, [width, height]) + + if (original_index[0] < 0 or original_index[0] >= width or original_index[1] < 0 or original_index[1] >= height): + continue + + rectified_image[j, i] = image[int(original_index[1]), int(original_index[0])] + + cv2.imwrite(output_path, rectified_image) + print(f"finish process {image_path}") + + +def process_image_cuda(image_path, output_path, distortion_param_json_path): + # Load image data on host (CPU) + image = cv2.imread(image_path) + height, width, channel = image.shape + + # read calibration data + data = readCalibrationJson(distortion_param_json_path) + lookup_table = data.inverse_lens_distortion_lookup_table + distortion_center = data.lens_distortion_center + reference_dimensions = data.intrinsic_matrix_reference_dimensions + ratio_x = width / reference_dimensions[0] + ratio_y = height / reference_dimensions[1] + distortion_center[0] = distortion_center[0] * ratio_x + distortion_center[1] = distortion_center[1] * ratio_y + + # Allocate memory on GPU for image data, lookup table, and rectified image + image_gpu = gpuarray.to_gpu(image.astype(np.float32)) + lookup_table_gpu = gpuarray.to_gpu(lookup_table.astype(np.float32)) + rectified_image_gpu = gpuarray.empty((height, width, channel), np.float32) + + # Prepare CUDA kernel code (replace with your actual kernel implementation) + kernel_code = """ + __global__ void process_image(float* image, float* lookup_table, float* distortion_center, float* rectified_image, int width, int height) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= width * height) { + return; + } + + int y = idx / width; + int x = idx % width; + + // Replace with your actual distortion correction logic using `get_lens_distortion_point` + float rectified_x, rectified_y; + get_lens_distortion_point(x, y, lookup_table, distortion_center, &rectified_x, &rectified_y); + + if (rectified_x < 0 || rectified_x >= width || rectified_y < 0 || rectified_y >= height) { + return; + } + + int rectified_idx = (int)rectified_y * width + (int)rectified_x; + rectified_image[rectified_idx] = image[idx]; + } + """ + + # Compile the kernel + mod = driver.SourceModule(kernel_code) + process_image = mod.get_function("process_image") + + # Set kernel parameters and launch + threads_per_block = (16, 16) # Adjust block size as needed + grid_size = (width // threads_per_block[0] + 1, height // threads_per_block[1] + 1) + check_cuda_error(process_image( + image_gpu, lookup_table_gpu, driver.In(distortion_center), rectified_image_gpu, np.int32(width), np.int32(height), block=threads_per_block, grid=grid_size + )) + + # Transfer rectified image back to host and convert to uint8 (assuming original image format) + rectified_image = rectified_image_gpu.get_array().astype(np.uint8) + + # Save the processed image + cv2.imwrite(output_path, rectified_image) + + # Free GPU memory + image_gpu.free() + lookup_table_gpu.free() + rectified_image_gpu.free() + +def rectify_image(image_folder_path, distortion_param_json_path, output_image_folder_path): + for filename in os.listdir(image_folder_path): + image_path = os.path.join(image_folder_path, filename) + output_path = os.path.join(output_image_folder_path, filename) + process_image_cuda(image_path, output_path, distortion_param_json_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="undistort ARKit image using distortion params get from AVfoundation") + parser.add_argument("--input", type=str) + parser.add_argument("--json", type=str) + parser.add_argument("--output", type=str) + + args = parser.parse_args() + input_image_folder_path = args.input + distortion_param_json_path = args.json + output_image_folder_path = args.output + + rectify_image(input_image_folder_path, distortion_param_json_path, output_image_folder_path) \ No newline at end of file diff --git a/convert.py b/convert.py index 7894884..7dbb614 100644 --- a/convert.py +++ b/convert.py @@ -71,6 +71,7 @@ img_undist_cmd = (colmap_command + " image_undistorter \ --image_path " + args.source_path + "/input \ --input_path " + args.source_path + "/distorted/sparse/0 \ --output_path " + args.source_path + "\ + --blank_pixels 0 \ --output_type COLMAP") exit_code = os.system(img_undist_cmd) if exit_code != 0: diff --git a/read_camera_binary.py b/read_camera_binary.py new file mode 100644 index 0000000..0ee909d --- /dev/null +++ b/read_camera_binary.py @@ -0,0 +1,5 @@ +from scene.colmap_loader import read_cameras_binary, read_points3D_binary, read_extrinsics_binary + +# read_cameras_binary("data/arkit_pose/meeting_room_loop_closure/arkit_colmap2/colmap_arkit/tri/cameras.bin") +# read_points3D_binary("data/arkit_pose/meeting_room_loop_closure/arkit_colmap2/colmap_arkit/tri/points3D.bin") +read_extrinsics_binary("data/arkit_pose/meeting_room_loop_closure/arkit_colmap2/colmap_arkit/tri/images.bin") \ No newline at end of file diff --git a/run_arkit_3dgs.sh b/run_arkit_3dgs.sh new file mode 100644 index 0000000..318d3fc --- /dev/null +++ b/run_arkit_3dgs.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -e + +# Validate the input argument +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +input_base_path=$1 +ba_iterations=$2 + +echo "input_base_path: ${input_base_path}" + +echo "=== Preprocess ARkit data === " +mkdir ${input_base_path}/post +mkdir ${input_base_path}/post/sparse/ +mkdir ${input_base_path}/post/sparse/online +echo "1. undistort image uisng AVfoundation calibration data" +python arkit_utils/undistort_images/undistort_image.py --input_base ${input_base_path} +echo "2. Transform ARKit mesh to point3D" +python arkit_utils/mesh_to_points3D/arkitmeshply2point3D.py --input_base_path ${input_base_path} +echo "3. Transform ARKit pose to COLMAP coordinate" +python arkit_utils/arkit_pose_to_colmap.py --input_database_path ${input_base_path} + +echo "3. Optimize pose using hloc & COLMAP" +mkdir ${input_base_path}/post/sparse/offline +python arkit_utils/pose_optimization/optimize_pose_hloc.py --input_database_path ${input_base_path} + +echo "=== 3D gaussian splatting === " +echo "1. 3DGS on online data" +python train.py -s ${input_base_path}/post -t online -m ${input_base_path}/post/sparse/online/output --iterations 7000 +# echo "1. 3DGS on offline data" +# CUDA_VISIBLE_DEVICS=1 python train.py -s ${input_base_path}/post -t offline -m ${input_base_path}/post/sparse/offline/output --iterations 7000 diff --git a/scene/__init__.py b/scene/__init__.py index 2b31398..f4a8c00 100644 --- a/scene/__init__.py +++ b/scene/__init__.py @@ -41,7 +41,7 @@ class Scene: self.test_cameras = {} if os.path.exists(os.path.join(args.source_path, "sparse")): - scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval) + scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.type, args.images, args.eval) elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): print("Found transforms_train.json file, assuming Blender data set!") scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval) diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py index 2a6f904..f78a675 100644 --- a/scene/dataset_readers.py +++ b/scene/dataset_readers.py @@ -129,15 +129,17 @@ def storePly(path, xyz, rgb): ply_data = PlyData([vertex_element]) ply_data.write(path) -def readColmapSceneInfo(path, images, eval, llffhold=8): +def readColmapSceneInfo(path, type, images, eval, llffhold=8): + sparse_folder = os.path.join("sparse", type) + print(f"read sparse data from {sparse_folder}") try: - cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") - cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") + cameras_extrinsic_file = os.path.join(path, sparse_folder, "images.bin") + cameras_intrinsic_file = os.path.join(path, sparse_folder, "cameras.bin") cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) except: - cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") - cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") + cameras_extrinsic_file = os.path.join(path, sparse_folder, "images.txt") + cameras_intrinsic_file = os.path.join(path, sparse_folder, "cameras.txt") cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) @@ -154,9 +156,9 @@ def readColmapSceneInfo(path, images, eval, llffhold=8): nerf_normalization = getNerfppNorm(train_cam_infos) - ply_path = os.path.join(path, "sparse/0/points3D.ply") - bin_path = os.path.join(path, "sparse/0/points3D.bin") - txt_path = os.path.join(path, "sparse/0/points3D.txt") + ply_path = os.path.join(path, sparse_folder, "points3D.ply") + bin_path = os.path.join(path, sparse_folder, "points3D.bin") + txt_path = os.path.join(path, sparse_folder, "points3D.txt") if not os.path.exists(ply_path): print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") try: