diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index 7ca4cf3..34e3e95 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -81,7 +81,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, colors_precomp = override_color # Rasterize visible Gaussians to image, obtain their radii (on screen). - rendered_image, radii = rasterizer( + rendered_image, radii, depth = rasterizer( means3D = means3D, means2D = means2D, shs = shs, @@ -96,4 +96,5 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, return {"render": rendered_image, "viewspace_points": screenspace_points, "visibility_filter" : radii > 0, - "radii": radii} + "radii": radii, + "depth": depth} diff --git a/render.py b/render.py index fc6b82d..4ad81bf 100644 --- a/render.py +++ b/render.py @@ -24,15 +24,22 @@ from gaussian_renderer import GaussianModel def render_set(model_path, name, iteration, views, gaussians, pipeline, background): render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth") makedirs(render_path, exist_ok=True) makedirs(gts_path, exist_ok=True) + makedirs(depth_path, exist_ok=True) for idx, view in enumerate(tqdm(views, desc="Rendering progress")): - rendering = render(view, gaussians, pipeline, background)["render"] + results = render(view, gaussians, pipeline, background) + rendering = results["render"] + depth = results["depth"] + depth = depth / (depth.max() + 1e-5) + gt = view.original_image[0:3, :, :] torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png")) def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): with torch.no_grad():