diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py index 98c73e8..5e250aa 100644 --- a/scene/dataset_readers.py +++ b/scene/dataset_readers.py @@ -29,11 +29,11 @@ class CameraInfo(NamedTuple): T: np.array FovY: np.array FovX: np.array - image: np.array image_path: str image_name: str width: int height: int + is_RGBA: bool class SceneInfo(NamedTuple): point_cloud: BasicPointCloud @@ -96,10 +96,9 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): image_path = os.path.join(images_folder, os.path.basename(extr.name)) image_name = os.path.basename(image_path).split(".")[0] - image = Image.open(image_path) - cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=width, height=height) + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, + image_path=image_path, image_name=image_name, width=width, height=height, is_RGBA=False) cam_infos.append(cam_info) sys.stdout.write('\n') return cam_infos @@ -196,20 +195,12 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension= image_name = Path(cam_name).stem image = Image.open(image_path) - im_data = np.array(image.convert("RGBA")) - - bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) - - norm_data = im_data / 255.0 - arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) - image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") - fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) FovY = fovy FovX = fovx - cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, - image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1])) + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, + image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], is_RGBA=True)) return cam_infos diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0a..57cb707 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -10,6 +10,7 @@ # from scene.cameras import Camera +from PIL import Image import numpy as np from utils.general_utils import PILtoTorch from utils.graphics_utils import fov2focal @@ -17,8 +18,17 @@ from utils.graphics_utils import fov2focal WARNED = False def loadCam(args, id, cam_info, resolution_scale): - orig_w, orig_h = cam_info.image.size + image = Image.open(cam_info.image_path) + if cam_info.is_RGBA: + im_data = np.array(image.convert("RGBA")) + bg = np.array([1,1,1]) if args.white_background else np.array([0, 0, 0]) + + norm_data = im_data / 255.0 + arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) + image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + orig_w, orig_h = image.size + if args.resolution in [1, 2, 4, 8]: resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) else: # should be a type that converts to float @@ -38,7 +48,7 @@ def loadCam(args, id, cam_info, resolution_scale): scale = float(global_down) * float(resolution_scale) resolution = (int(orig_w / scale), int(orig_h / scale)) - resized_image_rgb = PILtoTorch(cam_info.image, resolution) + resized_image_rgb = PILtoTorch(image, resolution) gt_image = resized_image_rgb[:3, ...] loaded_mask = None