load speedup: refactored image loading

Images are now loaded onto the target device as uint8 and only then converted to the target data type (e.g. fp32 or fp16). This speeds up loading.

Users can also choose whether images are kept in memory as uint8 or as the target data type; storing them as uint8 further reduces memory usage.
Stefan Saraev 2024-05-02 14:48:56 +03:00 committed by Matei Barbu
parent 18eb6d6a0c
commit b5a5f72eda
7 changed files with 41 additions and 17 deletions
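
To make the saving concrete, here is a minimal sketch (illustrative only, not code from this commit; assumes a CUDA device) of holding a 3x1080x1920 image as uint8 and converting it on demand:

```python
import torch

# Resident uint8 copy: 3 * 1080 * 1920 bytes, roughly 6 MB,
# versus roughly 25 MB if the same image were held as float32.
u8 = torch.randint(0, 256, (3, 1080, 1920), dtype=torch.uint8, device="cuda")

# Conversion to the target dtype happens only when the image is used,
# mirroring the on-demand path introduced here:
f16 = (u8 / 255.0).clamp(0.0, 1.0).to(torch.float16)
```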

```diff
@@ -196,6 +196,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
 Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
 #### --data_dtype
 The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
+#### --store_images_as_uint8
+Flag that describes how to store images in memory. If set, the images will be stored as uint8, and will be converted to the target data type on demand.
 </details>
 <br>
```
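
With both options documented above, a run that minimizes image memory could plausibly combine them like this (the dataset path is a placeholder):

```
python train.py -s <path to COLMAP or NeRF Synthetic dataset> --data_dtype float16 --store_images_as_uint8
```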

```diff
@@ -54,6 +54,7 @@ class ModelParams(ParamGroup):
         self._white_background = False
         self.data_device = "cuda"
         self.data_dtype = "float32"
+        self.store_images_as_uint8 = False
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)
```

```diff
@@ -94,7 +94,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
     # after rasterization, we convert the resulting image to the target dtype
     # The rasterizer expects parameters as float32, so the result is also float32.
-    rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)
+    rendered_image = rendered_image.to(viewpoint_camera.data_dtype)
     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
     # They will be excluded from value updates used in the splitting criteria.
```
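
Since `original_image` becomes a property (see the cameras.py diff below), reading `viewpoint_camera.original_image.dtype` would now trigger a full uint8-to-float conversion just to inspect a dtype; using the stored `data_dtype` avoids that. A rough sketch of the loss-side pattern this enables (the function and `camera` argument are illustrative, not from this repo):

```python
import torch

def l1_loss_in_target_dtype(raster_out: torch.Tensor, camera) -> torch.Tensor:
    # The rasterizer returns float32; cast to the camera's target dtype,
    # then compare against the (possibly lazily converted) ground truth.
    rendered = raster_out.to(camera.data_dtype)
    gt = camera.original_image  # uint8 -> data_dtype conversion happens here if enabled
    return torch.abs(rendered - gt).mean()
```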

```diff
@@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
         torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
         torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
-def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
+def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
     with torch.no_grad():
-        gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
+        gaussians = GaussianModel(dataset.sh_degree)
         scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
         bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
-        background = torch.tensor(bg_color, dtype=dtype, device="cuda")
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
         if not skip_train:
             render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
@@ -63,6 +63,4 @@ if __name__ == "__main__":
     # Initialize system state (RNG)
     safe_state(args.quiet)
-    dtype = torch.float32
-    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)
+    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
```

```diff
@@ -17,7 +17,8 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32,
+                 store_images_as_uint8=True,
                  ):
         super(Camera, self).__init__()
@@ -37,14 +38,18 @@ class Camera(nn.Module):
             print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
             self.data_device = torch.device("cuda")
-        self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
-        self.image_width = self.original_image.shape[2]
-        self.image_height = self.original_image.shape[1]
-        if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
-        else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
+        self.store_images_as_uint8 = store_images_as_uint8
+        self._original_image = image.to(self.data_device)
+        self._gt_alpha_mask = gt_alpha_mask
+        if self._gt_alpha_mask is not None:
+            self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)
+        if not store_images_as_uint8:
+            self._original_image = self.convert_image(self._original_image)
+        self.image_width = self._original_image.shape[2]
+        self.image_height = self._original_image.shape[1]
         self.zfar = 100.0
         self.znear = 0.01
@@ -57,6 +62,23 @@ class Camera(nn.Module):
         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         self.camera_center = self.world_view_transform.inverse()[3, :3]
+    def convert_image(self, image):
+        image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
+        gt_alpha_mask = self._gt_alpha_mask
+        if gt_alpha_mask is not None:
+            gt_alpha_mask = gt_alpha_mask / 255.0
+            image *= gt_alpha_mask.to(self.data_dtype)
+        return image
+
+    @property
+    def original_image(self):
+        if self.store_images_as_uint8:
+            return self.convert_image(self._original_image)
+        else:
+            return self._original_image
 class MiniCam:
     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
         self.image_width = width
```
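
The property above trades compute for memory: with `store_images_as_uint8=True`, every `original_image` access re-runs `convert_image`, while the resident copy stays at one byte per channel. A simplified, self-contained sketch of the same pattern (class and parameter names are illustrative, not from this repo):

```python
import torch

class LazyImage:
    """Simplified stand-in for the lazy conversion in Camera above."""
    def __init__(self, u8_image, data_dtype=torch.float32, store_as_uint8=True):
        self.data_dtype = data_dtype
        self.store_as_uint8 = store_as_uint8
        # Convert once up front only when not storing as uint8.
        self._image = u8_image if store_as_uint8 else self._convert(u8_image)

    def _convert(self, image):
        return (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)

    @property
    def original_image(self):
        # Either pay a conversion on every access, or return the cached float copy.
        return self._convert(self._image) if self.store_as_uint8 else self._image

img = LazyImage(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
assert img.original_image.dtype == torch.float32
```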

```diff
@@ -52,7 +52,8 @@ def loadCam(args, id, cam_info, resolution_scale):
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                   image=gt_image, gt_alpha_mask=loaded_mask,
                   image_name=cam_info.image_name, uid=id, data_device=args.data_device,
-                  data_dtype=get_data_dtype(args.data_dtype))
+                  data_dtype=get_data_dtype(args.data_dtype),
+                  store_images_as_uint8=args.store_images_as_uint8)
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []
```

```diff
@@ -20,7 +20,7 @@ def inverse_sigmoid(x):
 def PILtoTorch(pil_image, resolution):
     resized_image_PIL = pil_image.resize(resolution)
-    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+    resized_image = torch.from_numpy(np.array(resized_image_PIL))  # / 255.0
     if len(resized_image.shape) == 3:
         return resized_image.permute(2, 0, 1)
     else:
```
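
Note the changed contract: `PILtoTorch` now returns raw values in [0, 255] (uint8 for typical 8-bit images), and the division by 255 moves into `Camera.convert_image`. A quick check of that assumption (dummy image; PIL and numpy assumed available):

```python
import numpy as np
import torch
from PIL import Image

# A dummy 2x2 RGB image: after this change, loading keeps the raw
# uint8 range and leaves normalization to Camera.convert_image.
pil_image = Image.fromarray(np.full((2, 2, 3), 128, dtype=np.uint8))
tensor = torch.from_numpy(np.array(pil_image.resize((2, 2)))).permute(2, 0, 1)
assert tensor.dtype == torch.uint8
assert int(tensor.max()) == 128
```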