Provide --data_device option to put data on CPU to save VRAM for training (#14)

* Provide --data_on_cpu option to save VRAM for training

when there are many training images such as in large scene, most of the VRAM are used to store training data, use --data_on_cpu  can help reduce VRAM and make it possible to train on GPU with less VRAM

* Fix data_on_cpu  effect on default mask

* --data_on_cpu to --data_device

* update readme

* format warning infos
This commit is contained in:
Pythonix Huang 2023-07-13 02:30:45 +08:00 committed by GitHub
parent cebfc978f3
commit 989320fdf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 5 deletions

View File

@ -165,6 +165,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000 <iterations>``` by default.
#### --quiet
Flag to omit any text written to standard out pipe.
#### --data_device
Specify where to put the data on,```cuda``` by default, recommend use ```cpu``` if training on large scale/resolution dataset, will save a lot of VRAM required to train, but slightly slower the training
</details>
<br>

View File

@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
self._images = "images"
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

View File

@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
):
super(Camera, self).__init__()
@ -29,14 +29,21 @@ class Camera(nn.Module):
self.FoVy = FoVy
self.image_name = image_name
self.original_image = image.clamp(0.0, 1.0).cuda()
try:
self.data_device = torch.device(data_device)
except Exception as e:
print(e)
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.cuda()
self.original_image *= gt_alpha_mask.to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
self.zfar = 100.0
self.znear = 0.01

View File

@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id)
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []