Mirror of https://github.com/graphdeco-inria/gaussian-splatting
Merge 39fb001ef0 into 8a70a8cd6f

Commit 99a3ce85e3

@@ -206,6 +206,10 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
 Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
 #### --percent_dense
 Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
+#### --data_dtype
+The data type (```float32``` or ```float16```) in which images are stored when computing the loss, ```float32``` by default.
+#### --store_images_as_uint8
+Flag controlling how images are stored in memory. If set, images are stored as uint8 and converted to the target data type on demand.
 
 </details>
 <br>

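To put the two new options in perspective: a ground-truth view held as uint8 takes a quarter of the memory of the same view in float32. The sketch below is illustrative only (it is not part of the patch, and the 1080p resolution is just an assumed example):

```python
import torch

# Rough per-view memory cost of one 3 x 1080 x 1920 ground-truth image under the
# three storage choices (actual resolutions depend on the dataset and -r setting).
for dtype in (torch.uint8, torch.float16, torch.float32):
    t = torch.empty(3, 1080, 1920, dtype=dtype)
    print(dtype, f"{t.element_size() * t.nelement() / 2**20:.1f} MiB")
# torch.uint8 5.9 MiB, torch.float16 11.9 MiB, torch.float32 23.7 MiB
```
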
@@ -53,6 +53,8 @@ class ModelParams(ParamGroup):
         self._resolution = -1
         self._white_background = False
         self.data_device = "cuda"
+        self.data_dtype = "float32"
+        self.store_images_as_uint8 = False
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)

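For context, these attributes are assumed to surface on the command line through the repository's usual parameter-group machinery (not shown in this diff), roughly equivalent to:

```python
import argparse

# Approximate, hand-written equivalent of the options the two new ModelParams
# fields are expected to expose; the real code derives them from the attributes.
parser = argparse.ArgumentParser()
parser.add_argument("--data_dtype", type=str, default="float32")
parser.add_argument("--store_images_as_uint8", action="store_true", default=False)

args = parser.parse_args(["--data_dtype", "float16", "--store_images_as_uint8"])
print(args.data_dtype, args.store_images_as_uint8)  # float16 True
```
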
@@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
         rotations = rotations,
         cov3D_precomp = cov3D_precomp)
 
+    # after rasterization, we convert the resulting image to the target dtype
+    # The rasterizer expects parameters as float32, so the result is also float32.
+    rendered_image = rendered_image.to(viewpoint_camera.data_dtype)
+
     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
     # They will be excluded from value updates used in the splitting criteria.
     return {"render": rendered_image,

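With that cast in place, the loss is evaluated in the requested precision. A minimal standalone sketch, assuming the training loop compares the render against Camera.original_image (the training side is not part of this diff):

```python
import torch

data_dtype = torch.float16

rendered_image = torch.rand(3, 64, 64)           # stand-in for the float32 rasterizer output
rendered_image = rendered_image.to(data_dtype)   # the cast added above
gt_image = torch.rand(3, 64, 64).to(data_dtype)  # stand-in for Camera.original_image

l1 = (rendered_image - gt_image).abs().mean()
print(l1.dtype)  # torch.float16
```
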
@@ -17,7 +17,8 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32,
+                 store_images_as_uint8=True,
                  ):
         super(Camera, self).__init__()

@@ -28,6 +29,7 @@ class Camera(nn.Module):
         self.FoVx = FoVx
         self.FoVy = FoVy
         self.image_name = image_name
+        self.data_dtype = data_dtype
 
         try:
             self.data_device = torch.device(data_device)

@@ -36,14 +38,18 @@ class Camera(nn.Module):
             print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
             self.data_device = torch.device("cuda")
 
-        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
-        self.image_width = self.original_image.shape[2]
-        self.image_height = self.original_image.shape[1]
+        self.store_images_as_uint8 = store_images_as_uint8
 
-        if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.to(self.data_device)
-        else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
+        self._original_image = image.to(self.data_device)
+        self._gt_alpha_mask = gt_alpha_mask
+        if self._gt_alpha_mask is not None:
+            self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)
+
+        if not store_images_as_uint8:
+            self._original_image = self.convert_image(self._original_image)
+
+        self.image_width = self._original_image.shape[2]
+        self.image_height = self._original_image.shape[1]
 
         self.zfar = 100.0
         self.znear = 0.01

@@ -56,6 +62,23 @@ class Camera(nn.Module):
         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         self.camera_center = self.world_view_transform.inverse()[3, :3]
 
+    def convert_image(self, image):
+        image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
+        gt_alpha_mask = self._gt_alpha_mask
+
+        if gt_alpha_mask is not None:
+            gt_alpha_mask = gt_alpha_mask / 255.0
+            image *= gt_alpha_mask.to(self.data_dtype)
+
+        return image
+
+    @property
+    def original_image(self):
+        if self.store_images_as_uint8:
+            return self.convert_image(self._original_image)
+        else:
+            return self._original_image
+
 class MiniCam:
     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
         self.image_width = width

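The original_image property makes the uint8 path a lazy conversion: the raw bytes stay resident and the normalised tensor is rebuilt on every access, trading a small per-access cost for a roughly 4x smaller resident ground truth than float32. A stripped-down, standalone sketch of that pattern (illustrative only; the real Camera also carries poses, FoV, and the alpha mask):

```python
import torch

class LazyImage:
    """Keep the uint8 tensor (1 byte per channel) and convert on access."""

    def __init__(self, image_uint8, data_dtype=torch.float16):
        self._image = image_uint8
        self.data_dtype = data_dtype

    @property
    def image(self):
        # normalised to [0, 1] in the target dtype only when read
        return (self._image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)

raw = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
print(LazyImage(raw).image.dtype)  # torch.float16
```
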
@@ -11,7 +11,7 @@
 
 from scene.cameras import Camera
 import numpy as np
-from utils.general_utils import PILtoTorch
+from utils.general_utils import PILtoTorch, get_data_dtype
 from utils.graphics_utils import fov2focal
 
 WARNED = False

@@ -49,7 +49,9 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY,
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device,
+                  data_dtype=get_data_dtype(args.data_dtype),
+                  store_images_as_uint8=args.store_images_as_uint8)
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []

@@ -20,7 +20,7 @@ def inverse_sigmoid(x):
 
 def PILtoTorch(pil_image, resolution):
     resized_image_PIL = pil_image.resize(resolution)
-    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+    resized_image = torch.from_numpy(np.array(resized_image_PIL))
     if len(resized_image.shape) == 3:
         return resized_image.permute(2, 0, 1)
     else:

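So the loader now returns raw uint8 values in [0, 255]; the division by 255.0 happens later, in Camera.convert_image (immediately at construction when store_images_as_uint8 is off, on access otherwise). A small sketch of the changed return value (assumed toy image, not part of the patch):

```python
import numpy as np
import torch
from PIL import Image

# Mimics the new PILtoTorch behaviour on a tiny 4x4 RGB image: values stay uint8.
pil_image = Image.fromarray(np.full((4, 4, 3), 128, dtype=np.uint8))
tensor = torch.from_numpy(np.array(pil_image)).permute(2, 0, 1)
print(tensor.dtype, int(tensor.max()))   # torch.uint8 128
print(float((tensor / 255.0).max()))     # ~0.502 once normalised downstream
```
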
@@ -131,3 +131,12 @@ def safe_state(silent):
     np.random.seed(0)
     torch.manual_seed(0)
     torch.cuda.set_device(torch.device("cuda:0"))
+
+def get_data_dtype(dtype):
+    if dtype == "float32":
+        return torch.float32
+    elif dtype == "float64":
+        return torch.float64
+    elif dtype == "float16":
+        return torch.float16
+    return torch.float32

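A quick sanity check of the mapping (hypothetical snippet; the import path follows the diff context above, and unrecognised strings fall back to float32):

```python
import torch
from utils.general_utils import get_data_dtype

assert get_data_dtype("float16") is torch.float16
assert get_data_dtype("float64") is torch.float64
assert get_data_dtype("bfloat16") is torch.float32  # fallback for unknown names
```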