Lazy loading images to reduce CPU memory overhead

This commit is contained in:
liuxu 2023-07-25 13:21:15 +08:00
parent 39f54b758c
commit ec2d4093ba
2 changed files with 17 additions and 16 deletions

View File

@ -29,11 +29,11 @@ class CameraInfo(NamedTuple):
T: np.array
FovY: np.array
FovX: np.array
image: np.array
image_path: str
image_name: str
width: int
height: int
is_RGBA: bool
class SceneInfo(NamedTuple):
point_cloud: BasicPointCloud
@ -96,10 +96,9 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
image_path = os.path.join(images_folder, os.path.basename(extr.name))
image_name = os.path.basename(image_path).split(".")[0]
image = Image.open(image_path)
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
image_path=image_path, image_name=image_name, width=width, height=height)
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX,
image_path=image_path, image_name=image_name, width=width, height=height, is_RGBA=False)
cam_infos.append(cam_info)
sys.stdout.write('\n')
return cam_infos
@ -196,20 +195,12 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
image_name = Path(cam_name).stem
image = Image.open(image_path)
im_data = np.array(image.convert("RGBA"))
bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
norm_data = im_data / 255.0
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
FovY = fovy
FovX = fovx
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1], is_RGBA=True))
return cam_infos

View File

@ -10,6 +10,7 @@
#
from scene.cameras import Camera
from PIL import Image
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal
@ -17,7 +18,16 @@ from utils.graphics_utils import fov2focal
WARNED = False
def loadCam(args, id, cam_info, resolution_scale):
orig_w, orig_h = cam_info.image.size
image = Image.open(cam_info.image_path)
if cam_info.is_RGBA:
im_data = np.array(image.convert("RGBA"))
bg = np.array([1,1,1]) if args.white_background else np.array([0, 0, 0])
norm_data = im_data / 255.0
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
orig_w, orig_h = image.size
if args.resolution in [1, 2, 4, 8]:
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
@ -38,7 +48,7 @@ def loadCam(args, id, cam_info, resolution_scale):
scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
resized_image_rgb = PILtoTorch(image, resolution)
gt_image = resized_image_rgb[:3, ...]
loaded_mask = None