mirror of
https://github.com/graphdeco-inria/gaussian-splatting
synced 2025-04-02 00:34:58 +00:00
add render depth and normal from gs
This commit is contained in:
parent
3e33ab3cda
commit
8b998dfba6
@ -14,8 +14,12 @@ import math
|
||||
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
||||
from scene.gaussian_model import GaussianModel
|
||||
from utils.sh_utils import eval_sh
|
||||
from utils.general_utils import build_rotation
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
|
||||
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None,
|
||||
return_depth=False, return_normal=False, return_opacity=False):
|
||||
"""
|
||||
渲染场景: 将高斯分布的点投影到2D屏幕上来生成渲染图像
|
||||
viewpoint_camera: 训练相机集合
|
||||
@ -111,7 +115,90 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
|
||||
|
||||
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
||||
# They will be excluded from value updates used in the splitting criteria.
|
||||
return {"render": rendered_image,
|
||||
# 返回一个字典,包含渲染的图像、屏幕空间坐标、可见性过滤器(根据半径判断是否可见)以及每个高斯分布在屏幕上的半径
|
||||
return_dict={"render": rendered_image,
|
||||
"viewspace_points": screenspace_points,
|
||||
"visibility_filter" : radii > 0,
|
||||
"radii": radii}
|
||||
|
||||
if return_depth:
|
||||
# 提取相机的世界视图变换矩阵的第3列(深度方向)的前3个元素和最后一个元素。这些值将用于计算深度信息
|
||||
projvect1 = viewpoint_camera.world_view_transform[:, 2][:3].detach()
|
||||
projvect2 = viewpoint_camera.world_view_transform[:, 2][-1].detach()
|
||||
# 计算每个3D点的深度值: means3D * projvect1.unsqueeze(0) 将 means3D (3D点坐标) 与 projvect1 (深度方向的前3个元素) 相乘,得到深度分量;
|
||||
# 深度分量求和,并加上 projvect2 (深度方向的最后一个元素),得到最终的深度值
|
||||
means3D_depth = (means3D * projvect1.unsqueeze(0)).sum(dim=-1, keepdim=True) + projvect2
|
||||
# 深度值复制成3个通道,与 colors_precomp 的尺寸匹配
|
||||
means3D_depth = means3D_depth.repeat(1, 3)
|
||||
render_depth, _ = rasterizer(
|
||||
means3D=means3D,
|
||||
means2D=means2D,
|
||||
shs=None,
|
||||
colors_precomp=means3D_depth,
|
||||
opacities=opacity,
|
||||
scales=scales,
|
||||
rotations=rotations,
|
||||
cov3D_precomp=cov3D_precomp)
|
||||
render_depth = render_depth.mean(dim=0) # 将批量维度上的深度值取平均,得到最终的深度图
|
||||
return_dict.update({'render_depth': render_depth})
|
||||
|
||||
# plt.figure(figsize=(8, 6))
|
||||
# plt.imshow(rendered_image.permute(1, 2, 0).detach().cpu().numpy(), cmap='viridis')
|
||||
# plt.colorbar()
|
||||
# plt.title('Rendered Normal Map')
|
||||
# plt.show()
|
||||
|
||||
# plt.figure(figsize=(8, 6))
|
||||
# plt.imshow(render_depth.detach().cpu().numpy(), cmap='viridis')
|
||||
# plt.colorbar()
|
||||
# plt.title('Rendered Depth Map')
|
||||
# plt.show()
|
||||
|
||||
|
||||
if return_normal:
|
||||
rotations_mat = build_rotation(rotations)
|
||||
scales = pc.get_scaling
|
||||
min_scales = torch.argmin(scales, dim=1)
|
||||
indices = torch.arange(min_scales.shape[0])
|
||||
normal = rotations_mat[indices, :, min_scales]
|
||||
|
||||
# convert normal direction to the camera; calculate the normal in the camera coordinate
|
||||
view_dir = means3D - viewpoint_camera.camera_center
|
||||
normal = normal * ((((view_dir * normal).sum(dim=-1) < 0) * 1 - 0.5) * 2)[..., None]
|
||||
|
||||
R_w2c = torch.tensor(viewpoint_camera.R.T).cuda().to(torch.float32)
|
||||
normal = (R_w2c @ normal.transpose(0, 1)).transpose(0, 1)
|
||||
|
||||
render_normal, _ = rasterizer(
|
||||
means3D=means3D,
|
||||
means2D=means2D,
|
||||
shs=None,
|
||||
colors_precomp=normal,
|
||||
opacities=opacity,
|
||||
scales=scales,
|
||||
rotations=rotations,
|
||||
cov3D_precomp=cov3D_precomp)
|
||||
render_normal = F.normalize(render_normal, dim=0)
|
||||
return_dict.update({'render_normal': render_normal})
|
||||
|
||||
# plt.figure(figsize=(8, 6))
|
||||
# plt.imshow(render_normal.permute(1, 2, 0).detach().cpu().numpy(), cmap='viridis')
|
||||
# plt.colorbar()
|
||||
# plt.title('Rendered Normal Map')
|
||||
# plt.show()
|
||||
|
||||
if return_opacity:
|
||||
density = torch.ones_like(means3D)
|
||||
|
||||
render_opacity, _ = rasterizer(
|
||||
means3D=means3D,
|
||||
means2D=means2D,
|
||||
shs=None,
|
||||
colors_precomp=density,
|
||||
opacities=opacity,
|
||||
scales=scales,
|
||||
rotations=rotations,
|
||||
cov3D_precomp=cov3D_precomp)
|
||||
return_dict.update({'render_opacity': render_opacity.mean(dim=0)})
|
||||
|
||||
return return_dict
|
Loading…
Reference in New Issue
Block a user