Merge branch 'release' into develop

This commit is contained in:
bkerbl 2023-07-13 13:19:34 +02:00
commit 7d8035ad10
4 changed files with 15 additions and 5 deletions

View File

@ -115,6 +115,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Add this flag to use a MipNeRF360-style training/test split for evaluation. Add this flag to use a MipNeRF360-style training/test split for evaluation.
#### --resolution / -r #### --resolution / -r
Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
#### --data_device
Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training.
#### --white_background / -w #### --white_background / -w
Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
#### --sh_degree #### --sh_degree

View File

@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
self._images = "images" self._images = "images"
self._resolution = -1 self._resolution = -1
self._white_background = False self._white_background = False
self.data_device = "cuda"
self.eval = False self.eval = False
super().__init__(parser, "Loading Parameters", sentinel) super().__init__(parser, "Loading Parameters", sentinel)

View File

@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module): class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid, image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
): ):
super(Camera, self).__init__() super(Camera, self).__init__()
@ -29,14 +29,21 @@ class Camera(nn.Module):
self.FoVy = FoVy self.FoVy = FoVy
self.image_name = image_name self.image_name = image_name
self.original_image = image.clamp(0.0, 1.0).cuda() try:
self.data_device = torch.device(data_device)
except Exception as e:
print(e)
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2] self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1] self.image_height = self.original_image.shape[1]
if gt_alpha_mask is not None: if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.cuda() self.original_image *= gt_alpha_mask.to(self.data_device)
else: else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda") self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
self.zfar = 100.0 self.zfar = 100.0
self.znear = 0.01 self.znear = 0.01

View File

@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY, FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask, image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id) image_name=cam_info.image_name, uid=id, data_device=args.data_device)
def cameraList_from_camInfos(cam_infos, resolution_scale, args): def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = [] camera_list = []