mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-25 05:16:33 +00:00
Lazy loading images to reduce CPU memory overhead
This commit is contained in:
parent
39f54b758c
commit
ec2d4093ba
@ -29,11 +29,11 @@ class CameraInfo(NamedTuple):
|
|||||||
T: np.array
|
T: np.array
|
||||||
FovY: np.array
|
FovY: np.array
|
||||||
FovX: np.array
|
FovX: np.array
|
||||||
image: np.array
|
|
||||||
image_path: str
|
image_path: str
|
||||||
image_name: str
|
image_name: str
|
||||||
width: int
|
width: int
|
||||||
height: int
|
height: int
|
||||||
|
is_RGBA: bool
|
||||||
|
|
||||||
class SceneInfo(NamedTuple):
|
class SceneInfo(NamedTuple):
|
||||||
point_cloud: BasicPointCloud
|
point_cloud: BasicPointCloud
|
||||||
@ -96,10 +96,9 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
|
|||||||
|
|
||||||
image_path = os.path.join(images_folder, os.path.basename(extr.name))
|
image_path = os.path.join(images_folder, os.path.basename(extr.name))
|
||||||
image_name = os.path.basename(image_path).split(".")[0]
|
image_name = os.path.basename(image_path).split(".")[0]
|
||||||
image = Image.open(image_path)
|
|
||||||
|
|
||||||
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX,
|
||||||
image_path=image_path, image_name=image_name, width=width, height=height)
|
image_path=image_path, image_name=image_name, width=width, height=height, is_RGBA=False)
|
||||||
cam_infos.append(cam_info)
|
cam_infos.append(cam_info)
|
||||||
sys.stdout.write('\n')
|
sys.stdout.write('\n')
|
||||||
return cam_infos
|
return cam_infos
|
||||||
@ -196,20 +195,12 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
|
|||||||
image_name = Path(cam_name).stem
|
image_name = Path(cam_name).stem
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
|
|
||||||
im_data = np.array(image.convert("RGBA"))
|
|
||||||
|
|
||||||
bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
|
|
||||||
|
|
||||||
norm_data = im_data / 255.0
|
|
||||||
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
|
|
||||||
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
|
|
||||||
|
|
||||||
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
||||||
FovY = fovy
|
FovY = fovy
|
||||||
FovX = fovx
|
FovX = fovx
|
||||||
|
|
||||||
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
|
||||||
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
|
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], is_RGBA=True))
|
||||||
|
|
||||||
return cam_infos
|
return cam_infos
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from scene.cameras import Camera
|
from scene.cameras import Camera
|
||||||
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from utils.general_utils import PILtoTorch
|
from utils.general_utils import PILtoTorch
|
||||||
from utils.graphics_utils import fov2focal
|
from utils.graphics_utils import fov2focal
|
||||||
@ -17,8 +18,17 @@ from utils.graphics_utils import fov2focal
|
|||||||
WARNED = False
|
WARNED = False
|
||||||
|
|
||||||
def loadCam(args, id, cam_info, resolution_scale):
|
def loadCam(args, id, cam_info, resolution_scale):
|
||||||
orig_w, orig_h = cam_info.image.size
|
image = Image.open(cam_info.image_path)
|
||||||
|
if cam_info.is_RGBA:
|
||||||
|
im_data = np.array(image.convert("RGBA"))
|
||||||
|
|
||||||
|
bg = np.array([1,1,1]) if args.white_background else np.array([0, 0, 0])
|
||||||
|
|
||||||
|
norm_data = im_data / 255.0
|
||||||
|
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
|
||||||
|
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
|
||||||
if args.resolution in [1, 2, 4, 8]:
|
if args.resolution in [1, 2, 4, 8]:
|
||||||
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
|
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
|
||||||
else: # should be a type that converts to float
|
else: # should be a type that converts to float
|
||||||
@ -38,7 +48,7 @@ def loadCam(args, id, cam_info, resolution_scale):
|
|||||||
scale = float(global_down) * float(resolution_scale)
|
scale = float(global_down) * float(resolution_scale)
|
||||||
resolution = (int(orig_w / scale), int(orig_h / scale))
|
resolution = (int(orig_w / scale), int(orig_h / scale))
|
||||||
|
|
||||||
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
|
resized_image_rgb = PILtoTorch(image, resolution)
|
||||||
|
|
||||||
gt_image = resized_image_rgb[:3, ...]
|
gt_image = resized_image_rgb[:3, ...]
|
||||||
loaded_mask = None
|
loaded_mask = None
|
||||||
|
Loading…
Reference in New Issue
Block a user