DreamCraft3D/threestudio/models/estimators.py
2023-12-15 17:44:44 +08:00

118 lines
4.2 KiB
Python

from typing import Callable, List, Optional, Tuple
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
from nerfacc.data_specs import RayIntervals
from nerfacc.estimators.base import AbstractEstimator
from nerfacc.pdf import importance_sampling, searchsorted
from nerfacc.volrend import render_transmittance_from_density
from torch import Tensor
class ImportanceEstimator(AbstractEstimator):
    """Proposal-network based ray sampler.

    Unlike a plain proposal sampler that returns only the final resampled
    intervals, ``sampling`` merges the final (fine) interval edges with the
    edges produced for the last proposal level and returns the sorted union.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling(
        self,
        prop_sigma_fns: List[Callable],
        prop_samples: List[int],
        num_samples: int,
        # rendering options
        n_rays: int,
        near_plane: float,
        far_plane: float,
        sampling_type: Literal["uniform", "lindisp"] = "uniform",
        # training options
        stratified: bool = False,
        requires_grad: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Sampling with CDFs from proposal networks.

        Args:
            prop_sigma_fns: Proposal network evaluate functions. It should be a
                list of functions that take in samples
                {t_starts (n_rays, n_samples), t_ends (n_rays, n_samples)} and
                return the post-activation densities (n_rays, n_samples).
            prop_samples: Number of samples to draw from each proposal network.
                Must be the same length as `prop_sigma_fns`.
            num_samples: Number of samples drawn in the final (fine)
                resampling step.
            n_rays: Number of rays.
            near_plane: Near plane.
            far_plane: Far plane.
            sampling_type: Sampling type. Either "uniform" or "lindisp".
                Defaults to "uniform".
            stratified: Whether to use stratified sampling. Defaults to
                `False`.
            requires_grad: Whether autograd is enabled while the proposal
                networks are evaluated. The method as a whole runs under
                `torch.no_grad()`; this flag re-enables gradients only for
                the proposal evaluation. Defaults to `False`.

        Returns:
            A tuple of {Tensor, Tensor}:

            - **t_starts**: The starts of the merged, sorted intervals.
              Shape (n_rays, n_intervals).
            - **t_ends**: The ends of the merged, sorted intervals.
              Shape (n_rays, n_intervals).

            `n_intervals` is one less than the number of merged edges. When at
            least one proposal level is given, the edges of the last proposal
            level are concatenated with the fine edges, so `n_intervals` is
            larger than `num_samples` (presumably
            `prop_samples[-1] + num_samples + 1`, assuming
            `importance_sampling` returns `n + 1` edges for `n` intervals --
            TODO confirm against nerfacc). With no proposal level, only the
            fine edges are returned.

        Raises:
            AssertionError: If `prop_sigma_fns` and `prop_samples` differ in
                length, or a proposal function returns densities whose shape
                does not match `t_starts`.
        """
        assert len(prop_sigma_fns) == len(prop_samples), (
            "The number of proposal networks and the number of samples "
            "should be the same."
        )

        # Start from a single interval covering the whole normalized range
        # [0, 1] on every ray.
        cdfs = torch.cat(
            [
                torch.zeros((n_rays, 1), device=self.device),
                torch.ones((n_rays, 1), device=self.device),
            ],
            dim=-1,
        )
        intervals = RayIntervals(vals=cdfs)

        # Edges (in t-space) of the most recent proposal level. Stays `None`
        # when no proposal network is supplied; previously that case crashed
        # with an UnboundLocalError at the final merge.
        t_vals_coarse: Optional[Tensor] = None

        for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
            intervals, _ = importance_sampling(
                intervals, cdfs, level_samples, stratified
            )
            t_vals_coarse = _transform_stot(
                sampling_type, intervals.vals, near_plane, far_plane
            )
            t_starts = t_vals_coarse[..., :-1]
            t_ends = t_vals_coarse[..., 1:]

            # Gradients are optionally re-enabled here so the proposal
            # networks can be trained from their own densities.
            with torch.set_grad_enabled(requires_grad):
                sigmas = level_fn(t_starts, t_ends)
                assert sigmas.shape == t_starts.shape
                trans, _ = render_transmittance_from_density(
                    t_starts, t_ends, sigmas
                )
                # CDF at each edge is 1 - transmittance; the trailing zero
                # transmittance closes the CDF at 1 on the last edge.
                cdfs = 1.0 - torch.cat(
                    [trans, torch.zeros_like(trans[:, :1])], dim=-1
                )

        # Final resampling driven by the last proposal level's CDF.
        intervals, _ = importance_sampling(
            intervals, cdfs, num_samples, stratified
        )
        t_vals_fine = _transform_stot(
            sampling_type, intervals.vals, near_plane, far_plane
        )

        # Merge the last proposal level's edges with the fine edges and sort
        # them along the ray so consecutive edges form valid intervals.
        if t_vals_coarse is None:
            edge_sets = [t_vals_fine]
        else:
            edge_sets = [t_vals_coarse, t_vals_fine]
        t_vals, _ = torch.sort(torch.cat(edge_sets, dim=-1), dim=-1)

        t_starts_ = t_vals[..., :-1]
        t_ends_ = t_vals[..., 1:]
        return t_starts_, t_ends_
def _transform_stot(
transform_type: Literal["uniform", "lindisp"],
s_vals: torch.Tensor,
t_min: torch.Tensor,
t_max: torch.Tensor,
) -> torch.Tensor:
if transform_type == "uniform":
_contract_fn, _icontract_fn = lambda x: x, lambda x: x
elif transform_type == "lindisp":
_contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
else:
raise ValueError(f"Unknown transform_type: {transform_type}")
s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
return icontract_fn(s_vals)