mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 08:18:17 +00:00
load speedup: refactored image loading
Images are now loaded onto the target device as uint8 and then converted to the target data type (e.g., fp32 or fp16), which speeds up loading. Users can also choose whether to keep the images in memory as uint8 or in the target data type; keeping them as uint8 further reduces memory usage.
This commit is contained in:
parent
18eb6d6a0c
commit
b5a5f72eda
@ -196,6 +196,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
|
||||
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
|
||||
#### --data_dtype
|
||||
The data type (```float32``` or ```float16```) that images are converted to for computing the loss. ```float32``` by default.
|
||||
#### --store_images_as_uint8
|
||||
Flag that controls how images are stored in memory. If set, images are stored as uint8 and converted to the target data type (see ```--data_dtype```) on demand, which reduces memory usage; otherwise they are stored directly in the target data type.
|
||||
|
||||
</details>
|
||||
<br>
|
||||
|
@ -54,6 +54,7 @@ class ModelParams(ParamGroup):
|
||||
self._white_background = False
|
||||
self.data_device = "cuda"
|
||||
self.data_dtype = "float32"
|
||||
self.store_images_as_uint8 = False
|
||||
self.eval = False
|
||||
super().__init__(parser, "Loading Parameters", sentinel)
|
||||
|
||||
|
@ -94,7 +94,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
|
||||
|
||||
# after rasterization, we convert the resulting image to the target dtype
|
||||
# The rasterizer expects parameters as float32, so the result is also float32.
|
||||
rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)
|
||||
rendered_image = rendered_image.to(viewpoint_camera.data_dtype)
|
||||
|
||||
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
||||
# They will be excluded from value updates used in the splitting criteria.
|
||||
|
10
render.py
10
render.py
@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
|
||||
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
||||
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
|
||||
|
||||
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
|
||||
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
|
||||
with torch.no_grad():
|
||||
gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
|
||||
gaussians = GaussianModel(dataset.sh_degree)
|
||||
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
|
||||
|
||||
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
|
||||
background = torch.tensor(bg_color, dtype=dtype, device="cuda")
|
||||
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
||||
|
||||
if not skip_train:
|
||||
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
|
||||
@ -63,6 +63,4 @@ if __name__ == "__main__":
|
||||
# Initialize system state (RNG)
|
||||
safe_state(args.quiet)
|
||||
|
||||
dtype = torch.float32
|
||||
|
||||
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)
|
||||
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
|
@ -17,7 +17,8 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
|
||||
class Camera(nn.Module):
|
||||
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
||||
image_name, uid,
|
||||
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
|
||||
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32,
|
||||
store_images_as_uint8=True,
|
||||
):
|
||||
super(Camera, self).__init__()
|
||||
|
||||
@ -37,14 +38,18 @@ class Camera(nn.Module):
|
||||
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
||||
self.data_device = torch.device("cuda")
|
||||
|
||||
self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
|
||||
self.image_width = self.original_image.shape[2]
|
||||
self.image_height = self.original_image.shape[1]
|
||||
self.store_images_as_uint8 = store_images_as_uint8
|
||||
|
||||
if gt_alpha_mask is not None:
|
||||
self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
|
||||
else:
|
||||
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
|
||||
self._original_image = image.to(self.data_device)
|
||||
self._gt_alpha_mask = gt_alpha_mask
|
||||
if self._gt_alpha_mask is not None:
|
||||
self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)
|
||||
|
||||
if not store_images_as_uint8:
|
||||
self._original_image = self.convert_image(self._original_image)
|
||||
|
||||
self.image_width = self._original_image.shape[2]
|
||||
self.image_height = self._original_image.shape[1]
|
||||
|
||||
self.zfar = 100.0
|
||||
self.znear = 0.01
|
||||
@ -57,6 +62,23 @@ class Camera(nn.Module):
|
||||
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
||||
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
||||
|
||||
def convert_image(self, image):
|
||||
image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
|
||||
gt_alpha_mask = self._gt_alpha_mask
|
||||
|
||||
if gt_alpha_mask is not None:
|
||||
gt_alpha_mask = gt_alpha_mask / 255.0
|
||||
image *= gt_alpha_mask.to(self.data_dtype)
|
||||
|
||||
return image
|
||||
|
||||
@property
|
||||
def original_image(self):
|
||||
if self.store_images_as_uint8:
|
||||
return self.convert_image(self._original_image)
|
||||
else:
|
||||
return self._original_image
|
||||
|
||||
class MiniCam:
|
||||
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
|
||||
self.image_width = width
|
||||
|
@ -52,7 +52,8 @@ def loadCam(args, id, cam_info, resolution_scale):
|
||||
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
||||
image=gt_image, gt_alpha_mask=loaded_mask,
|
||||
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
|
||||
data_dtype=get_data_dtype(args.data_dtype))
|
||||
data_dtype=get_data_dtype(args.data_dtype),
|
||||
store_images_as_uint8=args.store_images_as_uint8)
|
||||
|
||||
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
||||
camera_list = []
|
||||
|
@ -20,7 +20,7 @@ def inverse_sigmoid(x):
|
||||
|
||||
def PILtoTorch(pil_image, resolution):
|
||||
resized_image_PIL = pil_image.resize(resolution)
|
||||
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
|
||||
resized_image = torch.from_numpy(np.array(resized_image_PIL))# / 255.0
|
||||
if len(resized_image.shape) == 3:
|
||||
return resized_image.permute(2, 0, 1)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user