from typing import Callable, List, Optional, Tuple

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
from nerfacc.data_specs import RayIntervals
from nerfacc.estimators.base import AbstractEstimator
from nerfacc.pdf import importance_sampling, searchsorted
from nerfacc.volrend import render_transmittance_from_density
from torch import Tensor


class ImportanceEstimator(AbstractEstimator):
    def __init__(
        self,
    ) -> None:
        super().__init__()

    @torch.no_grad()
    def sampling(
        self,
        prop_sigma_fns: List[Callable],
        prop_samples: List[int],
        num_samples: int,
        # rendering options
        n_rays: int,
        near_plane: float,
        far_plane: float,
        sampling_type: Literal["uniform", "lindisp"] = "uniform",
        # training options
        stratified: bool = False,
        requires_grad: bool = False,
    ) -> Tuple[Tensor, Tensor]:
"""Sampling with CDFs from proposal networks.
|
|
|
|
Args:
|
|
prop_sigma_fns: Proposal network evaluate functions. It should be a list
|
|
of functions that take in samples {t_starts (n_rays, n_samples),
|
|
t_ends (n_rays, n_samples)} and returns the post-activation densities
|
|
(n_rays, n_samples).
|
|
prop_samples: Number of samples to draw from each proposal network. Should
|
|
be the same length as `prop_sigma_fns`.
|
|
num_samples: Number of samples to draw in the end.
|
|
n_rays: Number of rays.
|
|
near_plane: Near plane.
|
|
far_plane: Far plane.
|
|
sampling_type: Sampling type. Either "uniform" or "lindisp". Default to
|
|
"lindisp".
|
|
stratified: Whether to use stratified sampling. Default to `False`.
|
|
|
|
Returns:
|
|
A tuple of {Tensor, Tensor}:
|
|
|
|
- **t_starts**: The starts of the samples. Shape (n_rays, num_samples).
|
|
- **t_ends**: The ends of the samples. Shape (n_rays, num_samples).
|
|
|
|
"""
        assert len(prop_sigma_fns) == len(prop_samples), (
            "The number of proposal networks and the number of samples "
            "should be the same."
        )
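        # Initial CDF: a single [0, 1] interval per ray; each proposal level below
        # refines it by importance sampling against its predicted densities.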
        cdfs = torch.cat(
            [
                torch.zeros((n_rays, 1), device=self.device),
                torch.ones((n_rays, 1), device=self.device),
            ],
            dim=-1,
        )
        intervals = RayIntervals(vals=cdfs)

        for level_fn, level_samples in zip(prop_sigma_fns, prop_samples):
            # Resample intervals according to the current CDF estimate.
            intervals, _ = importance_sampling(
                intervals, cdfs, level_samples, stratified
            )
            # Map the normalized s-values to metric t-values along each ray.
            t_vals = _transform_stot(
                sampling_type, intervals.vals, near_plane, far_plane
            )
            t_starts = t_vals[..., :-1]
            t_ends = t_vals[..., 1:]

            with torch.set_grad_enabled(requires_grad):
                sigmas = level_fn(t_starts, t_ends)
                assert sigmas.shape == t_starts.shape
                trans, _ = render_transmittance_from_density(t_starts, t_ends, sigmas)
                # Turn transmittance into an opacity CDF for the next round.
                cdfs = 1.0 - torch.cat([trans, torch.zeros_like(trans[:, :1])], dim=-1)

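        # Final round: draw `num_samples` intervals from the last proposal CDF,
        # then merge them with the last proposal level's t-values and sort per ray.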
        intervals, _ = importance_sampling(intervals, cdfs, num_samples, stratified)
        t_vals_fine = _transform_stot(
            sampling_type, intervals.vals, near_plane, far_plane
        )

        t_vals = torch.cat([t_vals, t_vals_fine], dim=-1)
        t_vals, _ = torch.sort(t_vals, dim=-1)

        t_starts_ = t_vals[..., :-1]
        t_ends_ = t_vals[..., 1:]

        return t_starts_, t_ends_


def _transform_stot(
    transform_type: Literal["uniform", "lindisp"],
    s_vals: torch.Tensor,
    t_min: torch.Tensor,
    t_max: torch.Tensor,
) -> torch.Tensor:
    """Map normalized s-values in [0, 1] to t-values in [t_min, t_max].

    "uniform" interpolates linearly in t; "lindisp" interpolates linearly in
    disparity (1 / t).
    """
    if transform_type == "uniform":
        _contract_fn, _icontract_fn = lambda x: x, lambda x: x
    elif transform_type == "lindisp":
        _contract_fn, _icontract_fn = lambda x: 1 / x, lambda x: 1 / x
    else:
        raise ValueError(f"Unknown transform_type: {transform_type}")
    s_min, s_max = _contract_fn(t_min), _contract_fn(t_max)
    icontract_fn = lambda s: _icontract_fn(s * s_max + (1 - s) * s_min)
    return icontract_fn(s_vals)
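

if __name__ == "__main__":
    # Minimal usage sketch (editor-added, not part of the upstream module).
    # It assumes `nerfacc` is installed; `_toy_sigma_fn` is a hypothetical
    # stand-in for a trained proposal network.
    def _toy_sigma_fn(t_starts: Tensor, t_ends: Tensor) -> Tensor:
        # Constant post-activation density for every sample interval.
        return torch.full_like(t_starts, 0.5)

    estimator = ImportanceEstimator()
    t_starts, t_ends = estimator.sampling(
        prop_sigma_fns=[_toy_sigma_fn],
        prop_samples=[64],
        num_samples=32,
        n_rays=4,
        near_plane=0.1,
        far_plane=10.0,
        sampling_type="lindisp",
        stratified=False,
    )
    # The returned samples merge the proposal-level and final samples, so the
    # count along the last dimension exceeds `num_samples`.
    print(t_starts.shape, t_ends.shape)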