# DreamCraft3D/threestudio/systems/zero123.py
# Last modified: 2023-12-15 17:44:44 +08:00
# 390 lines, 13 KiB, Python
import os
import random
import shutil
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw
from torchmetrics import PearsonCorrCoef
import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *
@threestudio.register("zero123-system")
class Zero123(BaseLift3DSystem):
    """Single-image-to-3D system supervised by Zero123 novel-view guidance.

    Training combines two kinds of supervision per step:

    - ``"ref"``: direct losses against the reference view (RGB, mask, and
      optionally depth / relative depth / normal losses).
    - ``"zero123"``: SDS-style guidance from the Zero123 model on randomly
      sampled novel-view cameras (``batch["random_camera"]``).

    Whether both are applied each step ("accumulate") or they alternate is
    controlled by ``cfg.freq["ref_or_zero123"]``.
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Frequency/scheduling options read at runtime, e.g. guidance_eval,
        # ref_or_zero123, ref_only_steps, n_ref.
        freq: dict = field(default_factory=dict)
        # When True, geometry is mesh-based: mesh regularizers (normal
        # consistency, Laplacian smoothness) replace the volumetric ones
        # (orientation, sparsity, opacity).
        refinement: bool = False
        # Lower bound of the randomly sampled ambient lighting ratio used
        # for novel-view (zero123) renders.
        ambient_ratio_min: float = 0.5

    cfg: Config

    def configure(self):
        # create geometry, material, background, renderer
        super().configure()

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Render ``batch`` with the configured renderer and return its outputs."""
        render_out = self.renderer(**batch)
        return {
            **render_out,
        }

    def on_fit_start(self) -> None:
        """Set up guidance, dump all training images, and create metrics."""
        super().on_fit_start()
        # no prompt processor: Zero123 is conditioned on the reference image,
        # not on text
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)
        # visualize all training images
        all_images = self.trainer.datamodule.train_dataloader().dataset.get_all_images()
        self.save_image_grid(
            "all_training_images.png",
            [
                {"type": "rgb", "img": image, "kwargs": {"data_format": "HWC"}}
                for image in all_images
            ],
            name="on_fit_start",
            step=self.true_global_step,
        )
        # Pearson correlation is used for the relative depth loss.
        self.pearson = PearsonCorrCoef().to(self.device)

    def training_substep(self, batch, batch_idx, guidance: str):
        """Run one supervision pass and return its weighted loss.

        Args:
            guidance: one of "ref" (reference image supervision), "zero123"

        Returns:
            dict with key ``"loss"``: the weighted sum of this pass's losses.
        """
        if guidance == "ref":
            # bg_color = torch.rand_like(batch['rays_o'])
            # Reference view: fully ambient, diffuse shading so the render
            # is directly comparable to the photo.
            ambient_ratio = 1.0
            shading = "diffuse"
            batch["shading"] = shading
        elif guidance == "zero123":
            # Novel view: use the randomly sampled camera and a random
            # ambient ratio in [ambient_ratio_min, 1).
            batch = batch["random_camera"]
            ambient_ratio = (
                self.cfg.ambient_ratio_min
                + (1 - self.cfg.ambient_ratio_min) * random.random()
            )

        batch["bg_color"] = None
        batch["ambient_ratio"] = ambient_ratio

        out = self(batch)
        loss_prefix = f"loss_{guidance}_"

        loss_terms = {}

        def set_loss(name, value):
            loss_terms[f"{loss_prefix}{name}"] = value

        # Periodically dump guidance diagnostics (zero123 passes only).
        guidance_eval = (
            guidance == "zero123"
            and self.cfg.freq.guidance_eval > 0
            and self.true_global_step % self.cfg.freq.guidance_eval == 0
        )

        if guidance == "ref":
            gt_mask = batch["mask"]
            gt_rgb = batch["rgb"]

            # color loss: composite the GT foreground over the rendered
            # background so the loss is well-defined outside the mask too
            gt_rgb = gt_rgb * gt_mask.float() + out["comp_rgb_bg"] * (
                1 - gt_mask.float()
            )
            set_loss("rgb", F.mse_loss(gt_rgb, out["comp_rgb"]))

            # mask loss
            set_loss("mask", F.mse_loss(gt_mask.float(), out["opacity"]))

            # depth loss
            if self.C(self.cfg.loss.lambda_depth) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)].unsqueeze(1)
                valid_pred_depth = out["depth"][gt_mask].unsqueeze(1)
                with torch.no_grad():
                    # Least-squares fit of a scale+shift aligning GT depth to
                    # predicted depth (monocular depth is scale-ambiguous);
                    # no gradient flows through the alignment itself.
                    A = torch.cat(
                        [valid_gt_depth, torch.ones_like(valid_gt_depth)], dim=-1
                    )  # [B, 2]
                    X = torch.linalg.lstsq(A, valid_pred_depth).solution  # [2, 1]
                    valid_gt_depth = A @ X  # [B, 1]
                set_loss("depth", F.mse_loss(valid_gt_depth, valid_pred_depth))

            # relative depth loss: scale-invariant via Pearson correlation
            if self.C(self.cfg.loss.lambda_depth_rel) > 0:
                valid_gt_depth = batch["ref_depth"][gt_mask.squeeze(-1)]  # [B,]
                valid_pred_depth = out["depth"][gt_mask]  # [B,]
                set_loss(
                    "depth_rel", 1 - self.pearson(valid_pred_depth, valid_gt_depth)
                )

            # normal loss: both maps are remapped from [0,1] image encoding
            # to [-1,1] vectors before the cosine similarity
            if self.C(self.cfg.loss.lambda_normal) > 0:
                valid_gt_normal = (
                    1 - 2 * batch["ref_normal"][gt_mask.squeeze(-1)]
                )  # [B, 3]
                valid_pred_normal = (
                    2 * out["comp_normal"][gt_mask.squeeze(-1)] - 1
                )  # [B, 3]
                set_loss(
                    "normal",
                    1 - F.cosine_similarity(valid_pred_normal, valid_gt_normal).mean(),
                )
        elif guidance == "zero123":
            # zero123 SDS guidance on the novel-view render
            guidance_out = self.guidance(
                out["comp_rgb"],
                **batch,
                rgb_as_latents=False,
                guidance_eval=guidance_eval,
            )
            # claforte: TODO: rename the loss_terms keys
            set_loss("sds", guidance_out["loss_sds"])

        # 2D total-variation smoothness on the rendered normal map
        if self.C(self.cfg.loss.lambda_normal_smooth) > 0:
            if "comp_normal" not in out:
                raise ValueError(
                    "comp_normal is required for 2D normal smooth loss, no comp_normal is found in the output."
                )
            normal = out["comp_normal"]
            set_loss(
                "normal_smooth",
                (normal[:, 1:, :, :] - normal[:, :-1, :, :]).square().mean()
                + (normal[:, :, 1:, :] - normal[:, :, :-1, :]).square().mean(),
            )

        # 3D smoothness: normals should match normals at perturbed points
        if self.C(self.cfg.loss.lambda_3d_normal_smooth) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for normal smooth loss, no normal is found in the output."
                )
            if "normal_perturb" not in out:
                raise ValueError(
                    "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output."
                )
            normals = out["normal"]
            normals_perturb = out["normal_perturb"]
            set_loss("3d_normal_smooth", (normals - normals_perturb).abs().mean())

        if not self.cfg.refinement:
            # Volumetric-stage regularizers.
            if self.C(self.cfg.loss.lambda_orient) > 0:
                if "normal" not in out:
                    raise ValueError(
                        "Normal is required for orientation loss, no normal is found in the output."
                    )
                # Penalize normals facing away from the viewing direction,
                # weighted by sample weights and normalized by visible pixels.
                set_loss(
                    "orient",
                    (
                        out["weights"].detach()
                        * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
                    ).sum()
                    / (out["opacity"] > 0).sum(),
                )

            # Sparsity is skipped for the reference view, where the mask loss
            # already constrains opacity.
            if guidance != "ref" and self.C(self.cfg.loss.lambda_sparsity) > 0:
                set_loss("sparsity", (out["opacity"] ** 2 + 0.01).sqrt().mean())

            if self.C(self.cfg.loss.lambda_opaque) > 0:
                opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
                # Binary entropy of opacity with itself pushes values toward
                # 0 or 1 (fully transparent or fully opaque).
                set_loss(
                    "opaque", binary_cross_entropy(opacity_clamped, opacity_clamped)
                )
        else:
            # Mesh-refinement-stage regularizers.
            if self.C(self.cfg.loss.lambda_normal_consistency) > 0:
                set_loss("normal_consistency", out["mesh"].normal_consistency())
            if self.C(self.cfg.loss.lambda_laplacian_smoothness) > 0:
                set_loss("laplacian_smoothness", out["mesh"].laplacian())

        # Weight each raw loss by its configured lambda and accumulate.
        loss = 0.0
        for name, value in loss_terms.items():
            self.log(f"train/{name}", value)
            if name.startswith(loss_prefix):
                loss_weighted = value * self.C(
                    self.cfg.loss[name.replace(loss_prefix, "lambda_")]
                )
                self.log(f"train/{name}_w", loss_weighted)
                loss += loss_weighted

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        self.log(f"train/loss_{guidance}", loss)

        if guidance_eval:
            # guidance_eval is only True for zero123 passes, so guidance_out
            # is guaranteed to be bound here.
            self.guidance_evaluation_save(
                out["comp_rgb"].detach()[: guidance_out["eval"]["bs"]],
                guidance_out["eval"],
            )

        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        """Dispatch reference / zero123 substeps according to cfg.freq."""
        mode = self.cfg.freq.get("ref_or_zero123", "accumulate")
        if mode == "accumulate":
            # Apply both supervisions every step.
            do_ref = True
            do_zero123 = True
        elif mode == "alternate":
            # Reference-only warmup, then reference every n_ref steps.
            do_ref = (
                self.true_global_step < self.cfg.freq.ref_only_steps
                or self.true_global_step % self.cfg.freq.n_ref == 0
            )
            do_zero123 = not do_ref
        else:
            # Previously an unknown value fell through to a NameError on
            # do_ref/do_zero123; fail with a clear message instead.
            raise ValueError(
                f"Unknown freq.ref_or_zero123 value: {mode!r} "
                "(expected 'accumulate' or 'alternate')"
            )

        total_loss = 0.0
        if do_zero123:
            out = self.training_substep(batch, batch_idx, guidance="zero123")
            total_loss += out["loss"]
        if do_ref:
            out = self.training_substep(batch, batch_idx, guidance="ref")
            total_loss += out["loss"]

        self.log("train/loss", total_loss, prog_bar=True)

        # sch = self.lr_schedulers()
        # sch.step()

        return {"loss": total_loss}

    def validation_step(self, batch, batch_idx):
        """Render a validation view and save a comparison image grid."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-val/{batch['index'][0]}.png",
            # GT RGB (if available) | rendered RGB | normal | depth | opacity
            (
                [
                    {
                        "type": "rgb",
                        "img": batch["rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    }
                ]
                if "rgb" in batch
                else []
            )
            + [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + (
                [
                    {
                        "type": "grayscale",
                        "img": out["depth"][0],
                        "kwargs": {},
                    }
                ]
                if "depth" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ],
            # claforte: TODO: don't hardcode the frame numbers to record... read them from cfg instead.
            name=f"validation_step_batchidx_{batch_idx}"
            if batch_idx in [0, 7, 15, 23, 29]
            else None,
            step=self.true_global_step,
        )

    def on_validation_epoch_end(self):
        """Assemble per-view validation frames into an mp4, then clean up."""
        filestem = f"it{self.true_global_step}-val"
        self.save_img_sequence(
            filestem,
            filestem,
            # raw string: "\d" in a plain string is an invalid escape sequence
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="validation_epoch_end",
            step=self.true_global_step,
        )
        # Remove per-frame images once the video is saved.
        shutil.rmtree(
            os.path.join(self.get_save_dir(), f"it{self.true_global_step}-val")
        )

    def test_step(self, batch, batch_idx):
        """Render a test view and save a comparison image grid."""
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            # GT RGB (if available) | rendered RGB | normal | depth | opacity
            (
                [
                    {
                        "type": "rgb",
                        "img": batch["rgb"][0],
                        "kwargs": {"data_format": "HWC"},
                    }
                ]
                if "rgb" in batch
                else []
            )
            + [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + (
                [
                    {
                        "type": "grayscale",
                        "img": out["depth"][0],
                        "kwargs": {},
                    }
                ]
                if "depth" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ],
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Assemble per-view test frames into an mp4, then clean up."""
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test",
            # raw string: "\d" in a plain string is an invalid escape sequence
            r"(\d+)\.png",
            save_format="mp4",
            fps=30,
            name="test",
            step=self.true_global_step,
        )
        # Remove per-frame images once the video is saved.
        shutil.rmtree(
            os.path.join(self.get_save_dir(), f"it{self.true_global_step}-test")
        )