mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-25 13:26:47 +00:00
Merge branch 'release' into develop
This commit is contained in:
commit
7d8035ad10
@ -115,6 +115,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
|
|||||||
Add this flag to use a MipNeRF360-style training/test split for evaluation.
|
Add this flag to use a MipNeRF360-style training/test split for evaluation.
|
||||||
#### --resolution / -r
|
#### --resolution / -r
|
||||||
Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
|
Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
|
||||||
|
#### --data_device
|
||||||
|
Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training.
|
||||||
#### --white_background / -w
|
#### --white_background / -w
|
||||||
Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
|
Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
|
||||||
#### --sh_degree
|
#### --sh_degree
|
||||||
|
@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
|
|||||||
self._images = "images"
|
self._images = "images"
|
||||||
self._resolution = -1
|
self._resolution = -1
|
||||||
self._white_background = False
|
self._white_background = False
|
||||||
|
self.data_device = "cuda"
|
||||||
self.eval = False
|
self.eval = False
|
||||||
super().__init__(parser, "Loading Parameters", sentinel)
|
super().__init__(parser, "Loading Parameters", sentinel)
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
|
|||||||
class Camera(nn.Module):
|
class Camera(nn.Module):
|
||||||
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
||||||
image_name, uid,
|
image_name, uid,
|
||||||
trans=np.array([0.0, 0.0, 0.0]), scale=1.0
|
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
|
||||||
):
|
):
|
||||||
super(Camera, self).__init__()
|
super(Camera, self).__init__()
|
||||||
|
|
||||||
@ -29,14 +29,21 @@ class Camera(nn.Module):
|
|||||||
self.FoVy = FoVy
|
self.FoVy = FoVy
|
||||||
self.image_name = image_name
|
self.image_name = image_name
|
||||||
|
|
||||||
self.original_image = image.clamp(0.0, 1.0).cuda()
|
try:
|
||||||
|
self.data_device = torch.device(data_device)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
||||||
|
self.data_device = torch.device("cuda")
|
||||||
|
|
||||||
|
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
|
||||||
self.image_width = self.original_image.shape[2]
|
self.image_width = self.original_image.shape[2]
|
||||||
self.image_height = self.original_image.shape[1]
|
self.image_height = self.original_image.shape[1]
|
||||||
|
|
||||||
if gt_alpha_mask is not None:
|
if gt_alpha_mask is not None:
|
||||||
self.original_image *= gt_alpha_mask.cuda()
|
self.original_image *= gt_alpha_mask.to(self.data_device)
|
||||||
else:
|
else:
|
||||||
self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
|
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
|
||||||
|
|
||||||
self.zfar = 100.0
|
self.zfar = 100.0
|
||||||
self.znear = 0.01
|
self.znear = 0.01
|
||||||
|
@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
|
|||||||
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
|
||||||
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
|
||||||
image=gt_image, gt_alpha_mask=loaded_mask,
|
image=gt_image, gt_alpha_mask=loaded_mask,
|
||||||
image_name=cam_info.image_name, uid=id)
|
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
|
||||||
|
|
||||||
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
|
||||||
camera_list = []
|
camera_list = []
|
||||||
|
Loading…
Reference in New Issue
Block a user