Mirror of https://github.com/deepseek-ai/DreamCraft3D (synced 2024-12-04 18:15:11 +00:00)
111 lines
3.6 KiB
Python
Executable File
import abc
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from scipy import interpolate


class GuideModel(torch.nn.Module, abc.ABC):
    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        pass

    @abc.abstractmethod
    def compute_loss(self, inp):
        pass


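# A minimal, purely illustrative GuideModel: the class name and the brightness
# objective below are hypothetical, and only sketch the preprocess / compute_loss
# contract that Guider relies on (any differentiable image-space objective works).
class ExampleBrightnessGuide(GuideModel):
    def preprocess(self, x_img):
        # x_img is the decoded image batch (roughly in [-1, 1]); no preprocessing here.
        return x_img

    def compute_loss(self, inp):
        # Scalar loss; Guider sums it and differentiates it w.r.t. the latent.
        return inp.mean()

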
class Guider(torch.nn.Module):
    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance.

        Specify the guidance scale either as a scalar, or as a schedule
        given as a list of (t, scale) tuples with t running from 0 to 1,
        e.g. [(0, 10), (0.5, 20), (1, 50)].
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        if isinstance(scale, (Tuple, List)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps) * [self.scale_schedule]

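        # Interpolate the (t, scale) control points at each DDIM step's
        # fractional position within the full DDPM schedule.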
        interpolater = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if scale == 0:
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
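        # Estimate x0 from the current latent and noise prediction, decode it to
        # image space (1 / 0.18215 undoes the latent scaling before the VAE decode),
        # and take the gradient of the guide loss w.r.t. the latent.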
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

        if self.show:
            clear_output(wait=True)
            print(
                loss.item(),
                scale,
                correction.abs().max().item(),
                e_t.abs().max().item(),
            )
            self.history.append(
                [
                    loss.item(),
                    scale,
                    correction.min().item(),
                    correction.max().item(),
                ]
            )
            plt.imshow(
                (inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2
            )
            plt.axis("off")
            plt.show()
            plt.imshow(correction[0][0].detach().cpu())
            plt.axis("off")
            plt.show()

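        # Classifier-guidance style update: perturb the predicted noise with the
        # guide gradient, scaled by sqrt(1 - alpha_cumprod) for the current DDIM step.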
        e_t_mod = e_t - sqrt_1ma * correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
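

# Illustrative usage sketch (assumptions: `sampler` is a DDIM-style sampler exposing
# `ddim_timesteps`, `ddpm_num_timesteps` and `ddim_sqrt_one_minus_alphas`, and it
# invokes `modify_score(model, e_t, x, t, c)` on a score-corrector hook, as the
# CompVis latent-diffusion DDIM sampler does via its `score_corrector` argument):
#
#     guide = ExampleBrightnessGuide()
#     guider = Guider(sampler, guide, scale=[(0, 10), (0.5, 20), (1, 50)], verbose=False)
#     samples, _ = sampler.sample(..., score_corrector=guider, corrector_kwargs={})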