mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2024-11-22 08:18:17 +00:00
Add depth visualization
This commit is contained in:
parent
9f45a4f4e7
commit
13c6602e13
@ -81,7 +81,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
|
||||
colors_precomp = override_color
|
||||
|
||||
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
||||
rendered_image, radii = rasterizer(
|
||||
rendered_image, radii, depth = rasterizer(
|
||||
means3D = means3D,
|
||||
means2D = means2D,
|
||||
shs = shs,
|
||||
@ -96,4 +96,5 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
|
||||
return {"render": rendered_image,
|
||||
"viewspace_points": screenspace_points,
|
||||
"visibility_filter" : radii > 0,
|
||||
"radii": radii}
|
||||
"radii": radii,
|
||||
"depth": depth}
|
||||
|
@ -24,15 +24,22 @@ from gaussian_renderer import GaussianModel
|
||||
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
|
||||
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
|
||||
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
|
||||
depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth")
|
||||
|
||||
makedirs(render_path, exist_ok=True)
|
||||
makedirs(gts_path, exist_ok=True)
|
||||
makedirs(depth_path, exist_ok=True)
|
||||
|
||||
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
|
||||
rendering = render(view, gaussians, pipeline, background)["render"]
|
||||
results = render(view, gaussians, pipeline, background)
|
||||
rendering = results["render"]
|
||||
depth = results["depth"]
|
||||
depth = depth / (depth.max() + 1e-5)
|
||||
|
||||
gt = view.original_image[0:3, :, :]
|
||||
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
||||
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
|
||||
torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png"))
|
||||
|
||||
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
|
||||
with torch.no_grad():
|
||||
|
Loading…
Reference in New Issue
Block a user