# Mirror of https://github.com/deepseek-ai/Janus (synced 2024-12-28 14:52:12 +00:00); original file: 715 lines, 22 KiB, Python.
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.utils
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm


class ImageHead(nn.Module):
    """Output norm followed by a linear head that produces per-token logits.

    NOTE(review): this class references ``AttrDict``, ``PG``,
    ``DropoutAddRMSNorm``, ``DropoutAddLayerNorm`` and
    ``ColumnParallelLinear``. None of them is imported or defined in this
    module, so constructing it raises ``NameError`` unless those names are
    provided elsewhere. TODO confirm whether this class is still used.
    """

    def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
        """Build the output normalization and the projection head.

        Args:
            decoder_cfg: config object; ``decoder_cfg.in_channels`` is the
                output width of the head.
            gpt_cfg: config combined (via ``+``) with the defaults below; it
                must supply ``n_embed``.
            layer_id: stored as ``self.layer_id``; not otherwise used here.

        Raises:
            AssertionError: if ``norm_type`` is not "layernorm" or "rmsnorm".
        """
        super().__init__()
        self.layer_id = layer_id
        # Defaults combined with gpt_cfg through AttrDict's ``+`` operator
        # (presumably gpt_cfg values take precedence -- TODO confirm).
        cfg = (
            AttrDict(
                norm_type="layernorm",
                is_exp_norm=False,
                sequence_parallel=False,
                use_userbuffer=False,
                norm_eps=1e-5,
                norm_bias=True,
                gradient_accumulation_fusion=True,
                use_fp32_head_weight=False,
            )
            + gpt_cfg
        )
        group = PG.tensor_parallel_group()
        assert cfg.norm_type in [
            "layernorm",
            "rmsnorm",
        ], f"Norm type:{cfg.norm_type} not supported"
        if cfg.norm_type == "rmsnorm":
            self.norm = DropoutAddRMSNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
            )
        else:
            self.norm = DropoutAddLayerNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
                bias=cfg.norm_bias,
            )

        multiple_of = 256
        if decoder_cfg.in_channels % multiple_of != 0:
            # Message (Chinese): "Recommend setting vocab_size to a multiple of
            # {multiple_of}; otherwise matrix-multiplication performance suffers."
            warnings.warn(
                f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
            )

        default_dtype = torch.get_default_dtype()
        dtype = default_dtype
        if cfg.use_fp32_head_weight:
            dtype = torch.float32
            # Message (Chinese): "Using fp32 head weight!!!! Not compatible with
            # the original bf16 head weight."
            print(
                "使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
                end="",
                flush=True,
            )
        # The head's parameters are created under `dtype`. Always restore the
        # process-wide default afterwards -- even if construction raises --
        # so a failure here cannot leave fp32 as the global default dtype.
        torch.set_default_dtype(dtype)
        try:
            self.head = ColumnParallelLinear(
                cfg.n_embed,
                decoder_cfg.in_channels,
                bias=True,
                group=group,
                sequence_parallel=cfg.sequence_parallel,
                use_userbuffer=cfg.use_userbuffer,
                gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
                use_fp32_output=False,
            )
        finally:
            torch.set_default_dtype(default_dtype)

        self.use_fp32_head_weight = cfg.use_fp32_head_weight

    def forward(
        self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
    ):
        """Normalize the hidden states and project them to logits.

        Args:
            input_args: either a tensor ``x`` or a tuple ``(x, residual)``;
                ``residual`` is passed to the norm together with ``x``.
            images_split_mask: optional boolean mask. When given, the positions
                it selects are gathered with ``masked_select`` and reshaped to
                ``(bs * n_images, -1, n_embed)`` before the head is applied.
                Presumably ``x`` is (bs, seq, n_embed) and the mask is
                (bs, n_images, seq) -- TODO confirm with callers.
            **kwargs: ignored.

        Returns:
            The output of ``self.head``.
        """
        residual = None
        if isinstance(input_args, tuple):
            x, residual = input_args
        else:
            x = input_args

        x = self.norm(x, residual)

        if self.use_fp32_head_weight:
            assert (
                self.head.weight.dtype == torch.float32
            ), f"head.weight is {self.head.weight.dtype}"
            x = x.float()

        if images_split_mask is None:
            logits = self.head(x)
        else:
            bs, n_images = images_split_mask.shape[:2]
            n_embed = x.shape[-1]

            # Broadcast x across the image axis, then keep only masked positions.
            images_embed = torch.masked_select(
                x.unsqueeze(1), images_split_mask.unsqueeze(-1)
            )
            images_embed = images_embed.view((bs * n_images, -1, n_embed))
            logits = self.head(images_embed)

        return logits


class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (GRN) layer from ConvNeXt-V2.

    Adapted from
    https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    The learnable parameters have shape ``(1, 1, 1, dim)`` and broadcast over
    the last axis, so the input is expected channels-last, e.g. ``(b, h, w, dim)``
    as produced inside ``ConvNextBlock``. Both parameters start at zero, which
    makes the layer an identity at initialization.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.weight = nn.Parameter(torch.zeros(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Return ``bias + (weight * nx + 1) * x``.

        ``nx`` is each channel's L2 norm over axes 1 and 2, divided by the
        mean of those norms across channels (plus 1e-6 for stability).
        """
        spatial_l2 = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative = spatial_l2 / (spatial_l2.mean(dim=-1, keepdim=True) + 1e-6)
        gain = self.weight * relative + 1
        return torch.addcmul(self.bias, gain, x, value=1)


class Downsample2D(nn.Module):
    """2D downsampling by a strided convolution or by average pooling.

    Parameters:
        channels (`int`):
            number of channels in the inputs (and outputs, unless overridden).
        use_conv (`bool`, default `False`):
            use a strided ``nn.Conv2d``; otherwise ``nn.AvgPool2d`` with
            ``kernel_size=stride`` (which requires ``out_channels == channels``).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution. With ``padding == 0`` and
            ``use_conv``, the input is zero-padded by one pixel on the right
            and bottom instead.
        name (`str`, default `conv`):
            legacy attribute naming. For ``"conv"`` the layer is reachable as
            both ``self.conv`` and ``self.Conv2d_0``; otherwise only ``self.conv``.
        kernel_size (`int`, default `3`), stride (`int`, default `2`):
            convolution geometry (``stride`` is also the pooling window).
        norm_type (`str`, optional):
            ``"ln_norm"``, ``"rms_norm"`` or ``None``; applied over the channel
            axis before downsampling.
        eps, elementwise_affine:
            forwarded to the normalization layer.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if norm_type is None:
            self.norm = None
        elif norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=stride, stride=stride)
        else:
            layer = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # The "conv" name additionally registers the legacy `Conv2d_0` alias
        # (registered first, so state-dict key order is unchanged).
        if name == "conv":
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Downsample an NCHW tensor; extra positional/keyword args are ignored."""
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # Normalize over channels: NCHW -> NHWC -> norm -> NCHW.
            channels_last = hidden_states.permute(0, 2, 3, 1)
            hidden_states = self.norm(channels_last).permute(0, 3, 1, 2)

        if self.use_conv and self.padding == 0:
            # Asymmetric (right/bottom) zero padding for the unpadded conv.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        return self.conv(hidden_states)


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution (applied after interpolation).
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose instead of interpolation.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            name of the upsampling 2D layer. The convolution is stored as
            ``self.conv`` when this is ``"conv"`` and as ``self.Conv2d_0``
            otherwise (legacy state-dict naming).
        kernel_size (`int`, optional):
            convolution kernel size; defaults to 4 for the transposed
            convolution and 3 for the regular convolution.
        padding (`int`, default `1`):
            padding for the convolution.
        stride (`int`, default `2`):
            stride of the transposed convolution, and the nearest-neighbour
            scale factor when ``output_size`` is not given.
        norm_type (`str`, optional):
            ``"ln_norm"``, ``"rms_norm"`` or ``None``; applied over the channel
            axis before upsampling.
        eps, elementwise_affine:
            forwarded to the normalization layer.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `True`):
            whether to nearest-neighbour upsample before the optional
            convolution; unused on the transposed-convolution path.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        self.stride = stride

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            conv = nn.ConvTranspose2d(
                channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        # The conv lives under `conv` or `Conv2d_0` depending on `name` (see
        # __init__); resolve it the same way on every forward path. Previously
        # the transposed-conv path always used `self.conv`, which does not
        # exist when `name != "conv"`.
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Upsample an NCHW tensor.

        Args:
            hidden_states: input with ``self.channels`` channels on axis 1.
            output_size: if given, the interpolation target size (passed to
                ``F.interpolate(size=...)``) instead of ``scale_factor=stride``.
            *args, **kwargs: ignored.

        Returns:
            The upsampled tensor, in the input dtype.
        """
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # Normalize over channels: NCHW -> NHWC -> norm -> NCHW.
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
                0, 3, 1, 2
            )

        if self.use_conv_transpose:
            return self._get_conv()(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=stride`
        if self.interpolate:
            if output_size is None:
                hidden_states = F.interpolate(
                    hidden_states, scale_factor=self.stride, mode="nearest"
                )
            else:
                hidden_states = F.interpolate(
                    hidden_states, size=output_size, mode="nearest"
                )

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            hidden_states = self._get_conv()(hidden_states)

        return hidden_states


class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block with condition-driven scale/shift.

    Pipeline: depthwise 7x7 conv -> RMSNorm -> Linear (expand by
    ``res_ffn_factor``) -> GELU -> GlobalResponseNorm -> Linear (project back)
    -> dropout, added to the block input. The sum is then modulated by
    ``(1 + scale)`` and ``shift``, both predicted from ``silu(cond_embeds)``.

    ``elementwise_affine`` is accepted for interface compatibility but is not
    used by this block.
    """

    def __init__(
        self,
        channels,
        norm_eps,
        elementwise_affine,
        use_bias,
        hidden_dropout,
        hidden_size,
        res_ffn_factor: int = 4,
    ):
        super().__init__()
        inner_dim = int(channels * res_ffn_factor)
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=7,
            padding=3,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, norm_eps)
        self.channelwise_linear_1 = nn.Linear(channels, inner_dim, bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(inner_dim)
        self.channelwise_linear_2 = nn.Linear(inner_dim, channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        # Maps the condition vector (width hidden_size) to [scale | shift].
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    def forward(self, x, cond_embeds):
        """Apply the block to NCHW ``x`` conditioned on ``cond_embeds``.

        ``cond_embeds`` is projected and split along dim 1 into scale and
        shift, which are broadcast over the spatial axes; presumably it is
        (batch, hidden_size) -- TODO confirm with callers.
        """
        shortcut = x

        # Per-channel MLP runs channels-last: NCHW -> NHWC.
        h = self.depthwise(x).permute(0, 2, 3, 1)
        for layer in (
            self.norm,
            self.channelwise_linear_1,
            self.channelwise_act,
            self.channelwise_norm,
            self.channelwise_linear_2,
            self.channelwise_dropout,
        ):
            h = layer(h)
        x = h.permute(0, 3, 1, 2) + shortcut

        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        # Equivalent to x * (1 + scale) + shift, broadcast over H and W.
        return torch.addcmul(
            shift[:, :, None, None], x, (1 + scale)[:, :, None, None], value=1
        )


class Patchify(nn.Module):
    """Patch embedding: a conv with ``stride=patch_size``, then RMSNorm over channels.

    ``kernel_size`` defaults to ``patch_size`` (non-overlapping patches).
    ``elementwise_affine`` is accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        in_channels,
        block_out_channels,
        patch_size,
        bias,
        elementwise_affine,
        eps,
        kernel_size=None,
    ):
        super().__init__()
        self.patch_conv = nn.Conv2d(
            in_channels,
            block_out_channels,
            kernel_size=patch_size if kernel_size is None else kernel_size,
            stride=patch_size,
            bias=bias,
        )
        self.norm = RMSNorm(block_out_channels, eps)

    def forward(self, x):
        """Embed NCHW ``x`` and return the normalized NCHW patch grid."""
        patches = self.patch_conv(x)
        # RMSNorm acts on the last axis, so normalize channels-last.
        normed = self.norm(patches.permute(0, 2, 3, 1))
        return normed.permute(0, 3, 1, 2)


class Unpatchify(nn.Module):
    """Inverse of patch embedding: RMSNorm, 1x1 conv, then PixelShuffle.

    The 1x1 conv produces ``out_channels * patch_size * patch_size`` channels,
    which ``nn.PixelShuffle(patch_size)`` rearranges into ``out_channels``
    channels at ``patch_size`` times the spatial resolution.
    ``elementwise_affine`` is accepted for interface compatibility but unused.
    """

    def __init__(
        self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
    ):
        super().__init__()
        self.norm = RMSNorm(in_channels, eps)
        self.unpatch_conv = nn.Conv2d(
            in_channels,
            out_channels * patch_size * patch_size,
            kernel_size=1,
            bias=bias,
        )
        self.pixel_shuffle = nn.PixelShuffle(patch_size)
        self.patch_size = patch_size

    def forward(self, x):
        """Map NCHW ``x`` [b, c, h, w] to the upsampled NCHW output."""
        # Normalize over channels (channels-last), then back to NCHW.
        normed = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.pixel_shuffle(self.unpatch_conv(normed))


class UVitBlock(nn.Module):
    """A stack of ``ConvNextBlock``s optionally followed by down/upsampling.

    Downsampling uses ``Downsample2D`` (strided 3x3 conv with RMSNorm) and
    upsampling uses ``Upsample2D`` (nearest interpolation + 3x3 conv with
    RMSNorm). When both flags are set, downsampling is applied first.
    ``seq_len``, ``concat_input``, ``original_input_channels``, ``use_zero``
    and ``norm_type`` are accepted but not used by this class.
    """

    def __init__(
        self,
        channels,
        out_channels,
        num_res_blocks,
        stride,
        hidden_size,
        hidden_dropout,
        elementwise_affine,
        norm_eps,
        use_bias,
        downsample: bool,
        upsample: bool,
        res_ffn_factor: int = 4,
        seq_len=None,
        concat_input=False,
        original_input_channels=None,
        use_zero=True,
        norm_type="RMS",
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList(
            ConvNextBlock(
                channels,
                norm_eps,
                elementwise_affine,
                use_bias,
                hidden_dropout,
                hidden_size,
                res_ffn_factor=res_ffn_factor,
            )
            for _ in range(num_res_blocks)
        )

        self.downsample = (
            Downsample2D(
                channels=channels,
                out_channels=out_channels,
                use_conv=True,
                name="Conv2d_0",
                kernel_size=3,
                padding=1,
                stride=stride,
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
            )
            if downsample
            else None
        )

        self.upsample = (
            Upsample2D(
                channels=channels,
                out_channels=out_channels,
                use_conv_transpose=False,
                use_conv=True,
                kernel_size=3,
                padding=1,
                stride=stride,
                name="conv",
                norm_type="rms_norm",
                eps=norm_eps,
                elementwise_affine=elementwise_affine,
                bias=use_bias,
                interpolate=True,
            )
            if upsample
            else None
        )

    def forward(self, x, emb, recompute=False):
        """Run every residual block with condition ``emb``, then resample.

        ``recompute`` is accepted but not used.
        """
        for block in self.res_blocks:
            x = block(x, emb)

        # Order matters: downsample (if any) before upsample (if any).
        for resample in (self.downsample, self.upsample):
            if resample is not None:
                x = resample(x)

        return x


class ShallowUViTEncoder(nn.Module):
    """Shallow image encoder: strided stem conv plus an optional UViT mid block.

    Timesteps are embedded with diffusers' ``Timesteps`` (sinusoidal, width
    ``block_out_channels[0]``) followed by ``TimestepEmbedding`` (width
    ``hidden_size``). The image goes through a single strided ``nn.Conv2d``
    stem and, if ``use_mid_block``, a ``UVitBlock`` conditioned on the
    timestep embedding. Extra ``**kwargs`` are ignored.
    """

    def __init__(
        self,
        input_channels=3,
        stride=4,
        kernel_size=7,
        padding=None,
        block_out_channels=(768,),
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        use_bias=True,
        norm_eps=1e-6,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()

        base_channels = block_out_channels[0]
        self.time_proj = Timesteps(
            base_channels, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_embed = TimestepEmbedding(
            base_channels, hidden_size, sample_proj_bias=use_bias
        )

        # Default padding is kernel_size - stride (e.g. 3 for the 7/4 defaults).
        stem_padding = math.ceil(kernel_size - stride) if padding is None else padding
        self.in_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=base_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=stem_padding,
        )

        self.mid_block = (
            UVitBlock(
                block_out_channels[-1],
                block_out_channels[-1],
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
            if use_mid_block
            else None
        )

    def get_num_extra_tensors(self):
        """Return 2 (the constant this encoder reports to its callers)."""
        return 2

    def forward(self, x, timesteps):
        """Encode ``x`` conditioned on ``timesteps``.

        Returns:
            ``(x_emb, t_emb, hs)`` where ``hs`` is a one-element list holding
            ``x_emb`` (the skip tensor consumed by the decoder).
        """
        batch = x.shape[0]

        # Sinusoidal projection is cast to the input dtype before the MLP.
        sinusoid = self.time_proj(timesteps.flatten()).view(batch, -1).to(x.dtype)
        t_emb = self.time_embed(sinusoid)
        x_emb = self.in_conv(x)

        if self.mid_block is not None:
            x_emb = self.mid_block(x_emb, t_emb)

        return x_emb, t_emb, [x_emb]


class ShallowUViTDecoder(nn.Module):
    """Shallow decoder: fuse input with the encoder skip, then unpatchify.

    The input is RMS-normalized, concatenated on the channel axis with the
    single skip tensor popped from ``hs``, optionally refined by a
    ``UVitBlock``, and upsampled by ``upsamples`` chained ``Unpatchify``
    stages (each ``patch_size=2``). Only the last stage emits
    ``out_channels``. Extra ``**kwargs`` are ignored.
    """

    def __init__(
        self,
        in_channels=768,
        out_channels=3,
        block_out_channels: Tuple[int] = (768,),
        upsamples=2,
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        norm_eps=1e-6,
        use_bias=True,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()
        last_channels = block_out_channels[-1]

        self.mid_block = (
            UVitBlock(
                in_channels + last_channels,
                # Unused by UVitBlock when both downsample and upsample are False.
                last_channels,
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
            if use_mid_block
            else None
        )

        self.out_convs = nn.ModuleList()
        for stage in range(upsamples):
            # The first stage sees the concatenated (skip + input) channels;
            # only the final stage produces the requested output channels.
            stage_in = last_channels + in_channels if stage == 0 else last_channels
            stage_out = out_channels if stage == upsamples - 1 else last_channels
            self.out_convs.append(
                Unpatchify(
                    stage_in,
                    stage_out,
                    patch_size=2,
                    bias=use_bias,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps,
                )
            )
        self.input_norm = RMSNorm(in_channels, norm_eps)

    def forward(self, x, hs, t_emb):
        """Decode NCHW ``x`` using skip list ``hs`` and condition ``t_emb``.

        ``hs`` is mutated: exactly one tensor is popped from it, and it must
        be empty afterwards (asserted).
        """
        # Normalize over channels (channels-last), then back to NCHW.
        x = self.input_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        skip = hs.pop()
        x = torch.cat([x, skip], dim=1)

        if self.mid_block is not None:
            x = self.mid_block(x, t_emb)

        for stage in self.out_convs:
            x = stage(x)

        assert len(hs) == 0
        return x