mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-04-29 18:32:49 +00:00
Manage central point different from center of image
In some cases, calibration gives central point cx,cy != (0.5,0.5), or it can be decided to crop the input images. In those cases, it is necessary to split fovx to fovXleft,fovXright and fovy to fovYtop,fovYbottom Note that the export of cameras to cameras.json merges those values back to the basic fovx,fovy. This aims at avoiding the modification of diff_gaussian_rasterization branch used for SIBR_gaussianViewer_app. Signed-off-by: Matthieu Gendrin <matthieu.gendrin@orange.com>
This commit is contained in:
parent
b17ded92b5
commit
5db5c254f4
@ -30,8 +30,8 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
|
||||
pass
|
||||
|
||||
# Set up rasterization configuration
|
||||
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
||||
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
||||
tanfovx = math.tan((viewpoint_camera.FovXright - viewpoint_camera.FovXleft) * 0.5)
|
||||
tanfovy = math.tan((viewpoint_camera.FovYbottom - viewpoint_camera.FovYtop) * 0.5)
|
||||
|
||||
raster_settings = GaussianRasterizationSettings(
|
||||
image_height=int(viewpoint_camera.image_height),
|
||||
|
@ -15,7 +15,7 @@ import numpy as np
|
||||
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
|
||||
|
||||
class Camera(nn.Module):
|
||||
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
||||
def __init__(self, colmap_id, R, T, FovXleft, FovXright, FovYtop, FovYbottom, image, gt_alpha_mask,
|
||||
image_name, uid,
|
||||
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
|
||||
):
|
||||
@ -25,8 +25,10 @@ class Camera(nn.Module):
|
||||
self.colmap_id = colmap_id
|
||||
self.R = R
|
||||
self.T = T
|
||||
self.FoVx = FoVx
|
||||
self.FoVy = FoVy
|
||||
self.FovXleft = FovXleft
|
||||
self.FovXright = FovXright
|
||||
self.FovYtop = FovYtop
|
||||
self.FovYbottom = FovYbottom
|
||||
self.image_name = image_name
|
||||
|
||||
try:
|
||||
@ -52,7 +54,7 @@ class Camera(nn.Module):
|
||||
self.scale = scale
|
||||
|
||||
self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
|
||||
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
|
||||
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovXleft=FovXleft, fovXright=FovXright, fovYtop=FovYtop, fovYbottom=FovYbottom).transpose(0,1).cuda()
|
||||
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
||||
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
||||
|
||||
@ -60,8 +62,8 @@ class MiniCam:
|
||||
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
|
||||
self.image_width = width
|
||||
self.image_height = height
|
||||
self.FoVy = fovy
|
||||
self.FoVx = fovx
|
||||
self.FovYtop = fovy
|
||||
self.FovXleft = fovx
|
||||
self.znear = znear
|
||||
self.zfar = zfar
|
||||
self.world_view_transform = world_view_transform
|
||||
|
@ -15,7 +15,7 @@ from PIL import Image
|
||||
from typing import NamedTuple
|
||||
from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
|
||||
read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
|
||||
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
|
||||
from utils.graphics_utils import getWorld2View2, focal2sidefov, focal2fov, fov2focal
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
@ -27,8 +27,10 @@ class CameraInfo(NamedTuple):
|
||||
uid: int
|
||||
R: np.array
|
||||
T: np.array
|
||||
FovY: np.array
|
||||
FovX: np.array
|
||||
FovYtop: np.array
|
||||
FovYbottom: np.array
|
||||
FovXleft: np.array
|
||||
FovXright: np.array
|
||||
image: np.array
|
||||
image_path: str
|
||||
image_name: str
|
||||
@ -82,15 +84,23 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
|
||||
R = np.transpose(qvec2rotmat(extr.qvec))
|
||||
T = np.array(extr.tvec)
|
||||
|
||||
cx = intr.params[-2]
|
||||
cy = intr.params[-1]
|
||||
|
||||
if intr.model=="SIMPLE_PINHOLE":
|
||||
focal_length_x = intr.params[0]
|
||||
FovY = focal2fov(focal_length_x, height)
|
||||
FovX = focal2fov(focal_length_x, width)
|
||||
focal_length_y = intr.params[0]
|
||||
FovYtop = focal2sidefov(focal_length_y, -cy) # usually negative
|
||||
FovYbottom = focal2sidefov(focal_length_y, height - cy) # usually positive
|
||||
FovXleft = focal2sidefov(focal_length_x, -cx) # usually negative
|
||||
FovXright = focal2sidefov(focal_length_x, width - cx) # usually positive
|
||||
elif intr.model=="PINHOLE":
|
||||
focal_length_x = intr.params[0]
|
||||
focal_length_y = intr.params[1]
|
||||
FovY = focal2fov(focal_length_y, height)
|
||||
FovX = focal2fov(focal_length_x, width)
|
||||
FovYtop = focal2sidefov(focal_length_y, -cy) # usually negative
|
||||
FovYbottom = focal2sidefov(focal_length_y, height - cy) # usually positive
|
||||
FovXleft = focal2sidefov(focal_length_x, -cx) # usually negative
|
||||
FovXright = focal2sidefov(focal_length_x, width - cx) # usually positive
|
||||
else:
|
||||
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
|
||||
|
||||
@ -98,7 +108,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
|
||||
image_name = os.path.basename(image_path).split(".")[0]
|
||||
image = Image.open(image_path)
|
||||
|
||||
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
||||
cam_info = CameraInfo(uid=uid, R=R, T=T, FovYtop=FovYtop, FovYbottom=FovYbottom, FovXleft=FovXleft, FovXright=FovXright, image=image,
|
||||
image_path=image_path, image_name=image_name, width=width, height=height)
|
||||
cam_infos.append(cam_info)
|
||||
sys.stdout.write('\n')
|
||||
@ -181,7 +191,6 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
|
||||
|
||||
with open(os.path.join(path, transformsfile)) as json_file:
|
||||
contents = json.load(json_file)
|
||||
fovx = contents["camera_angle_x"]
|
||||
|
||||
frames = contents["frames"]
|
||||
for idx, frame in enumerate(frames):
|
||||
@ -209,13 +218,21 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
|
||||
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
|
||||
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
|
||||
|
||||
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
||||
FovY = fovy
|
||||
FovX = fovx
|
||||
cx = frame["cx"] if "cx" in frame else contents["cx"] if "cx" in contents else image.size[0] / 2
|
||||
cy = frame["cy"] if "cy" in frame else contents["cy"] if "cy" in contents else image.size[1] / 2
|
||||
fl_y = frame["fl_y"] if "fl_y" in frame else contents["fl_y"] if "fl_y" in contents else None
|
||||
fl_x = frame["fl_x"] if "fl_x" in frame else contents["fl_x"] if "fl_x" in contents else fl_y
|
||||
fovx = frame["camera_angle_x"] if "camera_angle_x" in frame else contents["camera_angle_x"] if "camera_angle_x" in contents else None
|
||||
fovy = frame["camera_angle_y"] if "camera_angle_y" in frame else contents["camera_angle_y"] if "camera_angle_y" in contents else focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
||||
# priority is given to ("fl_x", "cx") over "camera_angle_x" because it can be frame specific:
|
||||
fovYtop = focal2sidefov(fl_y, -cy) if fl_y else focal2sidefov(fov2focal(fovy, image.size[1]), -cy) # usually negative
|
||||
fovYbottom = focal2sidefov(fl_y, image.size[1] - cy) if fl_y else focal2sidefov(fov2focal(fovy, image.size[1]), image.size[1] - cy) # usually positive
|
||||
fovXleft = focal2sidefov(fl_x, -cx) if fl_x else focal2sidefov(fov2focal(fovx, image.size[0]), -cx) # usually negative
|
||||
fovXright = focal2sidefov(fl_x, image.size[0] - cx) if fl_x else focal2sidefov(fov2focal(fovx, image.size[0]), image.size[0] - cx) # usually positive
|
||||
|
||||
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
||||
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovYtop=fovYtop, FovYbottom=fovYbottom, FovXleft=fovXleft, FovXright=fovXright, image=image,
|
||||
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
|
||||
|
||||
|
||||
return cam_infos
|
||||
|
||||
def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
|
||||
|
@ -13,6 +13,7 @@ from scene.cameras import Camera
|
||||
import numpy as np
|
||||
from utils.general_utils import PILtoTorch
|
||||
from utils.graphics_utils import fov2focal
|
||||
import math
|
||||
|
||||
WARNED = False
|
||||
|
||||
@ -47,7 +48,7 @@ def loadCam(args, id, cam_info, resolution_scale):
|
||||
loaded_mask = resized_image_rgb[3:4, ...]
|
||||
|
||||
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
||||
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
||||
FovXleft=cam_info.FovXleft, FovXright=cam_info.FovXright, FovYtop=cam_info.FovYtop, FovYbottom=cam_info.FovYbottom,
|
||||
image=gt_image, gt_alpha_mask=loaded_mask,
|
||||
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
|
||||
|
||||
@ -76,7 +77,7 @@ def camera_to_JSON(id, camera : Camera):
|
||||
'height' : camera.height,
|
||||
'position': pos.tolist(),
|
||||
'rotation': serializable_array_2d,
|
||||
'fy' : fov2focal(camera.FovY, camera.height),
|
||||
'fx' : fov2focal(camera.FovX, camera.width)
|
||||
'fy' : camera.height / (2 * math.tan((camera.FovYbottom - camera.FovYtop) / 2)),
|
||||
'fx' : camera.width / (2 * math.tan((camera.FovXright - camera.FovXleft) / 2))
|
||||
}
|
||||
return camera_entry
|
||||
|
@ -48,30 +48,39 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
|
||||
Rt = np.linalg.inv(C2W)
|
||||
return np.float32(Rt)
|
||||
|
||||
def getProjectionMatrix(znear, zfar, fovX, fovY):
|
||||
tanHalfFovY = math.tan((fovY / 2))
|
||||
tanHalfFovX = math.tan((fovX / 2))
|
||||
def getProjectionMatrix(znear, zfar, fovXleft, fovXright, fovYtop, fovYbottom):
|
||||
tanHalfFovYtop = math.tan(fovYtop)
|
||||
tanHalfFovYbottom = math.tan(fovYbottom)
|
||||
tanHalfFovXleft = math.tan(fovXleft)
|
||||
tanHalfFovXright = math.tan(fovXright)
|
||||
|
||||
top = tanHalfFovY * znear
|
||||
bottom = -top
|
||||
right = tanHalfFovX * znear
|
||||
left = -right
|
||||
top = tanHalfFovYtop * znear
|
||||
bottom = tanHalfFovYbottom * znear
|
||||
left = tanHalfFovXleft * znear
|
||||
right = tanHalfFovXright * znear
|
||||
|
||||
P = torch.zeros(4, 4)
|
||||
|
||||
z_sign = 1.0
|
||||
|
||||
# note that my conventions are (fovXleft,fovYtop) negative and (fovXright,fovYbottom) positive
|
||||
P[0, 0] = 2.0 * znear / (right - left)
|
||||
P[1, 1] = 2.0 * znear / (top - bottom)
|
||||
P[0, 2] = (right + left) / (right - left)
|
||||
P[1, 2] = (top + bottom) / (top - bottom)
|
||||
P[1, 1] = 2.0 * znear / (bottom - top)
|
||||
P[0, 2] = -(right + left) / (right - left)
|
||||
P[1, 2] = -(top + bottom) / (bottom - top)
|
||||
P[3, 2] = z_sign
|
||||
P[2, 2] = z_sign * zfar / (zfar - znear)
|
||||
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
||||
return P
|
||||
|
||||
def fov2focal(fov, pixels):
|
||||
return pixels / (2 * math.tan(fov / 2))
|
||||
return sidefov2focal(fov / 2, pixels / 2)
|
||||
|
||||
def focal2fov(focal, pixels):
|
||||
return 2*math.atan(pixels/(2*focal))
|
||||
return 2 * focal2sidefov(focal, pixels / 2)
|
||||
|
||||
def sidefov2focal(sidefov, sidepixels):
|
||||
return sidepixels / math.tan(sidefov)
|
||||
|
||||
def focal2sidefov(focal, sidepixels):
|
||||
return math.atan(sidepixels / focal)
|
||||
|
Loading…
Reference in New Issue
Block a user