# DreamCraft3D/extern/ldm_zero123/guidance.py
# 2023-12-15 17:44:44 +08:00
#
# 111 lines
# 3.6 KiB
# Python
# Executable File

import abc
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from scipy import interpolate
class GuideModel(torch.nn.Module, abc.ABC):
    """Abstract base for models that score a decoded image for guidance.

    Concrete subclasses must implement both :meth:`preprocess` and
    :meth:`compute_loss`; the class cannot be instantiated otherwise.
    """

    def __init__(self) -> None:
        super(GuideModel, self).__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        """Convert the decoded image ``x_img`` into the guide model's input."""

    @abc.abstractmethod
    def compute_loss(self, inp):
        """Return the guidance loss for the preprocessed input ``inp``."""
class Guider(torch.nn.Module):
    """Apply classifier guidance to the noise prediction of a DDIM sampler.

    The guider keeps an internal step counter (``self.index``) that starts at
    0 and is advanced by one on every :meth:`modify_score` call.  The counter
    is never reset, so one instance serves a single sampling run.
    """

    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Set up the guider.

        Args:
            sampler: Object exposing ``ddim_timesteps``,
                ``ddpm_num_timesteps`` and ``ddim_sqrt_one_minus_alphas``.
            guide_model: Object exposing ``preprocess`` and ``compute_loss``.
            scale: Guidance scale. Either a scalar, or a schedule given as a
                list/tuple of ``(t, scale)`` pairs with ``t`` in 0->1, e.g.
                ``[(0, 10), (0.5, 20), (1, 50)]``.
            verbose: When true, every guided step prints statistics, records
                them in ``self.history`` and shows matplotlib figures.
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        # Builtin types instead of the typing aliases: same result for
        # isinstance, without relying on typing's instance-check support.
        if isinstance(scale, (tuple, list)):
            times = np.array([point[0] for point in scale])
            values = np.array([point[1] for point in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Return one guidance scale per entry of ``ddim_timesteps``.

        A scalar schedule yields a list repeating that value.  A tabulated
        schedule is interpolated (``scipy.interpolate.interp1d`` defaults) at
        ``ddim_timesteps / ddpm_num_timesteps`` and returned as an array.
        """
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps) * [self.scale_schedule]

        interpolater = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        """Return ``e_t`` shifted by the scaled gradient of the guidance loss.

        Args:
            model: Object exposing ``predict_start_from_noise(x, t=, noise=)``
                and ``first_stage_model.decode``.
            e_t: Noise prediction to be corrected.
            x: Current latent; the loss gradient is taken with respect to a
                detached copy of it.
            t: Timestep argument forwarded to ``predict_start_from_noise``.
            c: Unused by this method.

        Returns:
            ``e_t`` itself when the scale for the current step is 0, otherwise
            ``e_t - sqrt_1ma * scale * grad``.
        """
        # TODO look up index by t
        step = self.index
        scale = self.get_scales()[step]
        if scale == 0:
            # The step counter must advance even when guidance is skipped;
            # otherwise a zero in the schedule pins the guider to this step
            # for the rest of the run.
            self.index = step + 1
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[step].to(x.device)

        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # NOTE(review): 0.18215 is presumably the latent scaling factor of
            # the first-stage autoencoder -- confirm against the model config.
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            # Reduce to a scalar so autograd accepts per-sample losses too.
            total_loss = loss.sum()
            grads = torch.autograd.grad(total_loss, x_in)[0]

        correction = grads * scale
        e_t_mod = e_t - sqrt_1ma * correction

        if self.show:
            self._show_debug(total_loss, scale, correction, inp, e_t, e_t_mod)

        self.index = step + 1
        return e_t_mod

    def _show_debug(self, total_loss, scale, correction, inp, e_t, e_t_mod):
        """Print, record and plot diagnostics for one guided step."""
        clear_output(wait=True)
        # The summed loss is reported: ``.item()`` on a multi-element loss
        # would raise, while for a scalar loss the value is unchanged.
        loss_value = total_loss.item()
        print(
            loss_value,
            scale,
            correction.abs().max().item(),
            e_t.abs().max().item(),
        )
        self.history.append(
            [
                loss_value,
                scale,
                correction.min().item(),
                correction.max().item(),
            ]
        )

        plt.imshow((inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2)
        plt.axis("off")
        plt.show()

        plt.imshow(correction[0][0].detach().cpu())
        plt.axis("off")
        plt.show()

        # Side by side: original score, corrected score, applied correction.
        fig, axs = plt.subplots(1, 3)
        axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
        axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
        axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
        plt.show()