add render depth and normal from gs

lzhi 2024-05-22 00:53:59 +08:00
parent 3e33ab3cda
commit 8b998dfba6


@@ -14,8 +14,12 @@ import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh
from utils.general_utils import build_rotation
import torch.nn.functional as F
import matplotlib.pyplot as plt
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None,
           return_depth=False, return_normal=False, return_opacity=False):
    """
    Render the scene: splat the 3D Gaussians onto the 2D screen to produce the rendered image.
    viewpoint_camera: the training camera to render from
@@ -111,7 +115,90 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
    # Return a dict containing the rendered image, the screen-space points, a visibility filter
    # (a Gaussian is visible if its on-screen radius > 0), and each Gaussian's screen radius
    return_dict = {"render": rendered_image,
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}
    if return_depth:
        # Take the 3rd column (the depth / camera-z axis) of the camera's world-view transform:
        # its first 3 elements and its last element are used to compute per-point depth
        projvect1 = viewpoint_camera.world_view_transform[:, 2][:3].detach()
        projvect2 = viewpoint_camera.world_view_transform[:, 2][-1].detach()
        # Per-point depth: multiply means3D (the 3D point coordinates) by projvect1,
        # sum the 3 components, then add projvect2 to get the camera-space depth
        means3D_depth = (means3D * projvect1.unsqueeze(0)).sum(dim=-1, keepdim=True) + projvect2
        # Replicate the depth into 3 channels so it matches the shape expected by colors_precomp
        means3D_depth = means3D_depth.repeat(1, 3)
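        # Note (assumption): this relies on the reference 3DGS convention that world_view_transform
        # stores the world-to-camera matrix transposed (points transform as p_hom @ world_view_transform),
        # so dot(means3D, projvect1) + projvect2 is exactly the camera-space z of each Gaussian center.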
        # Rasterize the per-Gaussian depth as if it were a precomputed RGB color
        render_depth, _ = rasterizer(
            means3D=means3D,
            means2D=means2D,
            shs=None,
            colors_precomp=means3D_depth,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=cov3D_precomp)
        render_depth = render_depth.mean(dim=0)  # the 3 channels are identical; average them to get the depth map
        return_dict.update({'render_depth': render_depth})
        # plt.figure(figsize=(8, 6))
        # plt.imshow(rendered_image.permute(1, 2, 0).detach().cpu().numpy(), cmap='viridis')
        # plt.colorbar()
        # plt.title('Rendered Image')
        # plt.show()
        # plt.figure(figsize=(8, 6))
        # plt.imshow(render_depth.detach().cpu().numpy(), cmap='viridis')
        # plt.colorbar()
        # plt.title('Rendered Depth Map')
        # plt.show()
    if return_normal:
        # The normal of each Gaussian is taken as its shortest axis, i.e. the column of its
        # rotation matrix that corresponds to the smallest scale
        rotations_mat = build_rotation(rotations)
        scales = pc.get_scaling
        min_scales = torch.argmin(scales, dim=1)
        indices = torch.arange(min_scales.shape[0])
        normal = rotations_mat[indices, :, min_scales]
        # Orient the normals toward the camera (flip those pointing away), then rotate them
        # into the camera coordinate frame
        view_dir = means3D - viewpoint_camera.camera_center
        normal = normal * ((((view_dir * normal).sum(dim=-1) < 0) * 1 - 0.5) * 2)[..., None]
        R_w2c = torch.tensor(viewpoint_camera.R.T).cuda().to(torch.float32)
        normal = (R_w2c @ normal.transpose(0, 1)).transpose(0, 1)
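        # Note (assumption): viewpoint_camera.R is assumed to follow the reference 3DGS convention
        # (camera-to-world rotation stored as a numpy array), so R.T is the world-to-camera rotation;
        # the same matrix is available on GPU as world_view_transform[:3, :3].transpose(0, 1).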
        # Rasterize the per-Gaussian normals as precomputed colors, then re-normalize per pixel
        render_normal, _ = rasterizer(
            means3D=means3D,
            means2D=means2D,
            shs=None,
            colors_precomp=normal,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=cov3D_precomp)
        render_normal = F.normalize(render_normal, dim=0)
        return_dict.update({'render_normal': render_normal})
        # plt.figure(figsize=(8, 6))
        # plt.imshow(render_normal.permute(1, 2, 0).detach().cpu().numpy(), cmap='viridis')
        # plt.colorbar()
        # plt.title('Rendered Normal Map')
        # plt.show()
    if return_opacity:
        # Splat a constant color of 1 for every Gaussian; the rendered result is the
        # per-pixel accumulated opacity (alpha) map
        density = torch.ones_like(means3D)
        render_opacity, _ = rasterizer(
            means3D=means3D,
            means2D=means2D,
            shs=None,
            colors_precomp=density,
            opacities=opacity,
            scales=scales,
            rotations=rotations,
            cov3D_precomp=cov3D_precomp)
        return_dict.update({'render_opacity': render_opacity.mean(dim=0)})
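        # Note (assumption): with the stock diff-gaussian-rasterization forward pass the background
        # color is still composited, so this map equals the accumulated alpha (1 - final transmittance)
        # only for a black background; with a white background, empty pixels read as 1.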
    return return_dict
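
A minimal usage sketch (not part of this commit) of how the new flags might be consumed in a training loop. viewpoint_cam, gaussians, pipe and background are assumed to be set up as in the reference 3DGS training script, and the depth/normal supervision targets are hypothetical placeholders.

from gaussian_renderer import render  # assumed module path for the render() above

# Render color, depth, normal and opacity in one call
out = render(viewpoint_cam, gaussians, pipe, background,
             return_depth=True, return_normal=True, return_opacity=True)

image = out["render"]            # (3, H, W) RGB image
depth = out["render_depth"]      # (H, W) camera-space depth map
normal = out["render_normal"]    # (3, H, W) unit normals in the camera frame
alpha = out["render_opacity"]    # (H, W) accumulated opacity

# Hypothetical supervision terms (gt_depth / gt_normal are placeholder targets):
# depth_loss = torch.abs(depth - gt_depth).mean()
# normal_loss = (1.0 - (normal * gt_normal).sum(dim=0)).mean()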