diff --git a/README.md b/README.md index 08e7587..172382d 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,8 @@ python train.py -s Add this flag to use a MipNeRF360-style training/test split for evaluation. #### --resolution / -r Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** + #### --data_device + Specifies where to put the source image data; ```cuda``` by default. Using ```cpu``` is recommended when training on a large/high-resolution dataset: it reduces VRAM consumption but slightly slows down training. #### --white_background / -w Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. #### --sh_degree diff --git a/arguments/__init__.py b/arguments/__init__.py index 4ab86f4..eba1dba 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -52,6 +52,7 @@ class ModelParams(ParamGroup): self._images = "images" self._resolution = -1 self._white_background = False + self.data_device = "cuda" self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/scene/cameras.py b/scene/cameras.py index b57d351..abf6e52 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0 + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" ): super(Camera, self).__init__() @@ -29,14 +29,21 @@ class Camera(nn.Module): self.FoVy = FoVy self.image_name = image_name - self.original_image = image.clamp(0.0, 1.0).cuda() + try: + self.data_device = 
torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.original_image = image.clamp(0.0, 1.0).to(self.data_device) self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.cuda() + self.original_image *= gt_alpha_mask.to(self.data_device) else: - self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda") + self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) self.zfar = 100.0 self.znear = 0.01 diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 21d762e..1a54d0a 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale): return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id) + image_name=cam_info.image_name, uid=id, data_device=args.data_device) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = []