add janusflow

This commit is contained in:
Xingchao Liu 2024-11-13 11:39:22 +08:00
parent 7a388c4f8b
commit b01013575f
10 changed files with 2982 additions and 24 deletions

283
README.md
View File

@ -6,6 +6,12 @@
<img src="images/logo.svg" width="60%" alt="DeepSeek LLM" />
</div>
<hr>
<div align="center">
<h1>🚀 Janus-Series: Unified Multimodal Understanding and Generation Models</h1>
</div>
<div align="center">
<a href="https://www.deepseek.com/" target="_blank">
@ -45,30 +51,44 @@
<p align="center">
<a href="#3-model-download"><b>📥 Model Download</b></a> |
<a href="#4-quick-start"><b>⚡ Quick Start</b></a> |
<a href="#5-license"><b>📜 License</b></a> |
<a href="#6-citation"><b>📖 Citation</b></a> <br>
<a href="https://arxiv.org/abs/2410.13848"><b>📄 Paper Link</b></a> |
<a href="https://huggingface.co/spaces/deepseek-ai/Janus-1.3B"><b>🤗 Online Demo</b></a>
<a href="#2-model-download"><b>📥 Model Download</b></a> |
<a href="#3-quick-start"><b>⚡ Quick Start</b></a> |
<a href="#4-license"><b>📜 License</b></a> |
<a href="#5-citation"><b>📖 Citation</b></a> <br>
<!-- 📄 Paper Link (<a href="https://arxiv.org/abs/2410.13848"><b>Janus</b></a>, <a href="https://arxiv.org/abs/2410.13848"><b>JanusFlow</b></a>) | -->
🤗 Online Demo (<a href="https://huggingface.co/spaces/deepseek-ai/Janus-1.3B"><b>Janus</b></a>, <a href="https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B"><b>JanusFlow</b></a>)
</p>
## News
**2024.11.13**: We release JanusFlow, a new unified model that uses rectified flow for image generation. See the [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow).
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
## 1. Introduction
Janus is a novel autoregressive framework that unifies multimodal understanding and generation. It addresses the limitations of previous approaches by decoupling visual encoding into separate pathways, while still utilizing a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also enhances the framework's flexibility. Janus surpasses previous unified models and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.
<a href="https://arxiv.org/abs/2410.13848"><b>Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation</b></a>
**Janus** is a novel autoregressive framework that unifies multimodal understanding and generation. It addresses the limitations of previous approaches by decoupling visual encoding into separate pathways, while still utilizing a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also enhances the framework's flexibility. Janus surpasses previous unified models and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.
<div align="center">
<img alt="image" src="images/teaser.png" style="width:90%;">
</div>
## 2. News
<a href="https://arxiv.org/abs/2411.07975"><b>JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation</b></a>
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
**JanusFlow** introduces a minimalist architecture that integrates autoregressive language models with rectified flow, a state-of-the-art method in generative modeling. Our key finding demonstrates that rectified flow can be straightforwardly trained within the large language model framework, eliminating the need for complex architectural modifications. Extensive experiments show that JanusFlow achieves comparable or superior performance to specialized models in their respective domains, while significantly outperforming existing unified approaches across standard benchmarks. This work represents a step toward more efficient and versatile vision-language models.
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
<div align="center">
<img alt="image" src="images/teaser_janusflow.png" style="width:90%;">
</div>
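For readers new to rectified flow, the short sketch below is a generic, self-contained illustration (it is not JanusFlow's actual training code): a velocity network is trained to predict the straight-line direction from noise to data along linear interpolations, and sampling integrates that velocity field with a plain Euler loop, the same pattern used in the text-to-image example later in this README.

```python
# Toy rectified-flow example in 2D (illustration only; unrelated to JanusFlow's real training code).
import torch
import torch.nn as nn

velocity_net = nn.Sequential(nn.Linear(3, 64), nn.SiLU(), nn.Linear(64, 2))
opt = torch.optim.Adam(velocity_net.parameters(), lr=1e-3)

# Training: regress the constant velocity (x1 - x0) along straight paths between noise and data.
for _ in range(1000):
    x0 = torch.randn(256, 2)              # noise samples
    x1 = torch.randn(256, 2) * 0.5 + 2.0  # toy "data" samples
    t = torch.rand(256, 1)
    xt = (1 - t) * x0 + t * x1            # point on the straight path at time t
    v_pred = velocity_net(torch.cat([xt, t], dim=1))
    loss = ((v_pred - (x1 - x0)) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Sampling: integrate dz/dt = v(z, t) from t=0 (noise) to t=1 with explicit Euler steps.
with torch.no_grad():
    z = torch.randn(16, 2)
    steps = 30
    for i in range(steps):
        t = torch.full((16, 1), i / steps)
        z = z + velocity_net(torch.cat([z, t], dim=1)) / steps
```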
## 3. Model Download
## 2. Model Download
We release Janus to the public to support a broader and more diverse range of research within both academic and commercial communities.
Please note that the use of this model is subject to the terms outlined in the [License section](#5-license). Commercial usage is
@ -79,11 +99,15 @@ permitted under these terms.
| Model | Sequence Length | Download |
|-----------------------|-----------------|-----------------------------------------------------------------------------|
| Janus-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/Janus-1.3B) |
| JanusFlow-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/JanusFlow-1.3B) |
## 4. Quick Start
## 3. Quick Start
<details>
<summary><h3>Janus</h3></summary>
### Installation
@ -278,26 +302,237 @@ To test the server, you can open another terminal and run:
```
python demo/fastapi_client.py
```
</details>
<details>
<summary><h3>JanusFlow</h3></summary>
### Installation
In a `Python >= 3.8` environment, install the necessary dependencies by running the following command:
```shell
pip install -e .
pip install diffusers[torch]
```
### 🤗 Huggingface Online Demo
Check out the demo in [this link](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B).
### Simple Inference Example
#### Multimodal Understanding
```python
import torch
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
# specify the path to the model
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = MultiModalityCausalLM.from_pretrained(
model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
conversation = [
{
"role": "User",
"content": "<image_placeholder>\nConvert the formula into latex code.",
"images": ["images/equation.png"],
},
{"role": "Assistant", "content": ""},
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)
# # run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# # run the model to get the response
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False,
use_cache=True,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(f"{prepare_inputs['sft_format'][0]}", answer)
```
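The chat processor also accepts conversations that reference several images by repeating the `<image_placeholder>` tag (the same format documented in `apply_sft_template_for_multi_turn_prompts`). A hypothetical example is shown below; the image paths are placeholders, and the remaining steps are identical to the single-image case above.

```python
# Hypothetical multi-image conversation; replace the paths with your own files.
conversation = [
    {
        "role": "User",
        "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
        "images": [
            "./multi-images/attribute_comparison_1.png",
            "./multi-images/attribute_comparison_2.png",
        ],
    },
    {"role": "Assistant", "content": ""},
]
```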
#### Text-to-Image Generation
```python
import os
import PIL.Image
import torch
import numpy as np
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
import torchvision
# specify the path to the model
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = MultiModalityCausalLM.from_pretrained(
model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
from diffusers.models import AutoencoderKL
# remember to use bfloat16 dtype, this vae doesn't work with fp16
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(torch.bfloat16).cuda().eval()
conversation = [
{
"role": "User",
"content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
},
{"role": "Assistant", "content": ""},
]
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
conversations=conversation,
sft_format=vl_chat_processor.sft_format,
system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_gen_tag
@torch.inference_mode()
def generate(
mmgpt: MultiModalityCausalLM,
vl_chat_processor: VLChatProcessor,
prompt: str,
cfg_weight: float = 5.0,
num_inference_steps: int = 30,
batchsize: int = 5
):
input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)
tokens = torch.stack([input_ids] * 2 * batchsize).cuda()
tokens[batchsize:, 1:] = vl_chat_processor.pad_id
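# the first batchsize rows keep the real prompt (conditional branch);
# the remaining rows are padded out and serve as the unconditional branch for classifier-free guidance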
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
# we remove the last <bog> token and replace it with t_emb later
inputs_embeds = inputs_embeds[:, :-1, :]
# generate with rectified flow ode
# step 1: encode with vision_gen_enc
z = torch.randn((batchsize, 4, 48, 48), dtype=torch.bfloat16).cuda()
dt = 1.0 / num_inference_steps
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
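# z is the initial Gaussian latent (4 x 48 x 48; the SDXL VAE's 8x decoder maps it to a 384 x 384 image)
# dt is the Euler step size of the rectified-flow ODE, broadcast to z's shape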
# step 2: run ode
attention_mask = torch.ones((2*batchsize, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
attention_mask[batchsize:, 1:inputs_embeds.shape[1]] = 0
attention_mask = attention_mask.int()
for step in range(num_inference_steps):
# prepare inputs for the llm
z_input = torch.cat([z, z], dim=0) # for cfg
t = step / num_inference_steps * 1000.
t = torch.tensor([t] * z_input.shape[0]).to(dt)
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
# input to the llm
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
if step == 0:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=None)
# keep only the prompt-prefix part of the KV cache so later steps can reuse it
past_key_values = []
for kv_cache in outputs.past_key_values:
k, v = kv_cache[0], kv_cache[1]
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
past_key_values = tuple(past_key_values)
else:
# reuse the cached prompt prefix and feed only the new time/latent embeddings
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb[:, inputs_embeds.shape[1]:, :],
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
hidden_states = outputs.last_hidden_state
# transform hidden_states back to v
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
v_cond, v_uncond = torch.chunk(v, 2)
v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
z = z + dt * v
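# each iteration is one explicit Euler step of dz/dt = v, moving z from noise (t=0) toward an image latent (t=1)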
# step 3: decode with vision_gen_dec and sdxl vae
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
os.makedirs('generated_samples', exist_ok=True)
save_path = os.path.join('generated_samples', "img.jpg")
torchvision.utils.save_image(decoded_image.clip_(-1.0, 1.0)*0.5+0.5, save_path)
generate(
vl_gpt,
vl_chat_processor,
prompt,
cfg_weight=2.0,
num_inference_steps=30,
batchsize=5
)
```
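Generation starts from random noise, so the outputs differ between runs. If you need reproducible samples, you can seed the random number generators before calling `generate`, as the Gradio demo does:

```python
import numpy as np
import torch

seed = 12345  # any fixed value
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
```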
### Gradio Demo
For the local gradio demo, you can run with the following command:
```
pip install -e .[gradio]
python demo/app_janusflow.py
```
Have Fun!
</details>
## 4. License
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).
## 6. Citation
## 5. Citation
```bibtex
@article{wu2024janus,
title={Janus: Decoupling visual encoding for unified multimodal understanding and generation},
author={Wu, Chengyue and Chen, Xiaokang and Wu, Zhiyu and Ma, Yiyang and Liu, Xingchao and Pan, Zizheng and Liu, Wen and Xie, Zhenda and Yu, Xingkai and Ruan, Chong and others},
journal={arXiv preprint arXiv:2410.13848},
year={2024}
}
@article{ma2024janusflow,
title={JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation},
author={Yiyang Ma and Xingchao Liu and Xiaokang Chen and Wen Liu and Chengyue Wu and Zhiyu Wu and Zizheng Pan and Zhenda Xie and Haowei Zhang and Xingkai Yu and Liang Zhao and Yisong Wang and Jiaying Liu and Chong Ruan},
journal={arXiv preprint arXiv:2411.07975},
year={2024}
}
```
## 7. Contact
## 6. Contact
If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).

248
demo/app_janusflow.py Normal file
View File

@ -0,0 +1,248 @@
import gradio as gr
import torch
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers.models import AutoencoderKL
import numpy as np
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
# remember to use bfloat16 dtype, this vae doesn't work with fp16
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(torch.bfloat16).to(cuda_device).eval()
# Multimodal Understanding function
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
torch.cuda.empty_cache()
# set seed
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
conversation = [
{
"role": "User",
"content": f"<image_placeholder>\n{question}",
"images": [image],
},
{"role": "Assistant", "content": ""},
]
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False if temperature == 0 else True,
use_cache=True,
temperature=temperature,
top_p=top_p,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
return answer
@torch.inference_mode()
def generate(
input_ids,
cfg_weight: float = 2.0,
num_inference_steps: int = 30
):
# we generate 5 images at a time, *2 for CFG
tokens = torch.stack([input_ids] * 10).cuda()
tokens[5:, 1:] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
print(inputs_embeds.shape)
# we remove the last <bog> token and replace it with t_emb later
inputs_embeds = inputs_embeds[:, :-1, :]
# generate with rectified flow ode
# step 1: encode with vision_gen_enc
z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
dt = 1.0 / num_inference_steps
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
# step 2: run ode
attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
attention_mask = attention_mask.int()
for step in range(num_inference_steps):
# prepare inputs for the llm
z_input = torch.cat([z, z], dim=0) # for cfg
t = step / num_inference_steps * 1000.
t = torch.tensor([t] * z_input.shape[0]).to(dt)
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
# input to the llm
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
if step == 0:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=None)
# keep only the prompt-prefix part of the KV cache so later steps can reuse it
past_key_values = []
for kv_cache in outputs.past_key_values:
k, v = kv_cache[0], kv_cache[1]
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
past_key_values = tuple(past_key_values)
else:
# reuse the cached prompt prefix and feed only the new time/latent embeddings
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb[:, inputs_embeds.shape[1]:, :],
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
hidden_states = outputs.last_hidden_state
# transform hidden_states back to v
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
v_cond, v_uncond = torch.chunk(v, 2)
v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
z = z + dt * v
# step 3: decode with vision_gen_dec and sdxl vae
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
images = ((images+1) / 2. * 255).astype(np.uint8)
return images
def unpack(dec, width, height, parallel_size=5):
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
visual_img[:, :, :] = dec
return visual_img
@torch.inference_mode()
def generate_image(prompt,
seed=None,
guidance=5,
num_inference_steps=30):
# Clear CUDA cache and avoid tracking gradients
torch.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
with torch.no_grad():
messages = [{'role': 'User', 'content': prompt},
{'role': 'Assistant', 'content': ''}]
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
sft_format=vl_chat_processor.sft_format,
system_prompt='')
text = text + vl_chat_processor.image_start_tag
input_ids = torch.LongTensor(tokenizer.encode(text))
images = generate(input_ids,
cfg_weight=guidance,
num_inference_steps=num_inference_steps)
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
# Gradio interface
with gr.Blocks() as demo:
gr.Markdown(value="# Multimodal Understanding")
# with gr.Row():
with gr.Row():
image_input = gr.Image()
with gr.Column():
question_input = gr.Textbox(label="Question")
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
understanding_button = gr.Button("Chat")
understanding_output = gr.Textbox(label="Response")
examples_inpainting = gr.Examples(
label="Multimodal Understanding examples",
examples=[
[
"explain this meme",
"./images/doge.png",
],
[
"Convert the formula into latex code.",
"./images/equation.png",
],
],
inputs=[question_input, image_input],
)
gr.Markdown(value="# Text-to-Image Generation")
with gr.Row():
cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
prompt_input = gr.Textbox(label="Prompt")
seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
generation_button = gr.Button("Generate Images")
image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
examples_t2i = gr.Examples(
label="Text to image generation examples.",
examples=[
"Master shifu racoon wearing drip attire as a street gangster.",
"A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
"The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
],
inputs=prompt_input,
)
understanding_button.click(
multimodal_understanding,
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
outputs=understanding_output
)
generation_button.click(
fn=generate_image,
inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
outputs=image_output
)
demo.launch(share=True)

31
janus/janusflow/__init__.py Normal file
View File

@ -0,0 +1,31 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# check if python version is above 3.10
import sys
if sys.version_info >= (3, 10):
print("Python version is above 3.10, patching the collections module.")
# Monkey patch collections
import collections
import collections.abc
for type_name in collections.abc.__all__:
setattr(collections, type_name, getattr(collections.abc, type_name))

28
janus/janusflow/models/__init__.py Normal file
View File

@ -0,0 +1,28 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from .image_processing_vlm import VLMImageProcessor
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm import VLChatProcessor
__all__ = [
"VLMImageProcessor",
"VLChatProcessor",
"MultiModalityCausalLM",
]

122
janus/janusflow/models/clip_encoder.py Normal file
View File

@ -0,0 +1,122 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange
from janus.janusflow.models.siglip_vit import create_siglip_vit
class CLIPVisionTower(nn.Module):
def __init__(
self,
model_name: str = "siglip_large_patch16_384",
image_size: Union[Tuple[int, int], int] = 336,
select_feature: str = "patch",
select_layer: int = -2,
select_layers: list = None,
ckpt_path: str = "",
pixel_mean: Optional[List[float]] = None,
pixel_std: Optional[List[float]] = None,
**kwargs,
):
super().__init__()
self.model_name = model_name
self.select_feature = select_feature
self.select_layer = select_layer
self.select_layers = select_layers
vision_tower_params = {
"model_name": model_name,
"image_size": image_size,
"ckpt_path": ckpt_path,
"select_layer": select_layer,
}
vision_tower_params.update(kwargs)
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
vision_tower_params
)
if pixel_mean is not None and pixel_std is not None:
image_norm = torchvision.transforms.Normalize(
mean=pixel_mean, std=pixel_std
)
else:
image_norm = None
self.image_norm = image_norm
def build_vision_tower(self, vision_tower_params):
if self.model_name.startswith("siglip"):
self.select_feature = "same"
vision_tower = create_siglip_vit(**vision_tower_params)
forward_kwargs = dict()
elif self.model_name.startswith("sam"):
vision_tower = create_sam_vit(**vision_tower_params)
forward_kwargs = dict()
else: # huggingface
from transformers import CLIPVisionModel
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
forward_kwargs = dict(output_hidden_states=True)
return vision_tower, forward_kwargs
def feature_select(self, image_forward_outs):
if isinstance(image_forward_outs, torch.Tensor):
# the output is already the features from self.select_layer
image_features = image_forward_outs
else:
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
# if the output has cls_token
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
elif self.select_feature == "same":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def forward(self, images):
"""
Args:
images (torch.Tensor): [b, 3, H, W]
Returns:
image_features (torch.Tensor): [b, n_patch, d]
"""
if self.image_norm is not None:
images = self.image_norm(images)
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
image_features = self.feature_select(image_forward_outs)
return image_features
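# note: with the default siglip_large_patch16_384 tower on a 384x384 input this returns
# (384 / 16)^2 = 576 patch tokens of width 1024, which the understanding aligner in
# modeling_vlm.py (nn.Linear(1024, 2048)) projects to the LLM hidden size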

208
janus/janusflow/models/image_processing_vlm.py Normal file
View File

@ -0,0 +1,208 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from typing import List, Tuple, Union
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional
from PIL import Image
from transformers import AutoImageProcessor, PretrainedConfig
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array
from transformers.utils import logging
logger = logging.get_logger(__name__)
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
class VLMImageProcessorConfig(PretrainedConfig):
model_type = "deepseek_vlm"
image_size: int
min_size: int
image_mean: Union[Tuple[float, float, float], List[float]]
image_std: Union[Tuple[float, float, float], List[float]]
rescale_factor: float
do_normalize: bool
def __init__(
self,
image_size: int,
min_size: int = 14,
image_mean: Union[Tuple[float, float, float], List[float]] = (
0.48145466,
0.4578275,
0.40821073,
),
image_std: Union[Tuple[float, float, float], List[float]] = (
0.26862954,
0.26130258,
0.27577711,
),
rescale_factor: float = 1.0 / 255.0,
do_normalize: bool = True,
**kwargs,
):
self.image_size = image_size
self.min_size = min_size
self.image_mean = image_mean
self.image_std = image_std
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
super().__init__(**kwargs)
class VLMImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __init__(
self,
image_size: int,
min_size: int = 14,
image_mean: Union[Tuple[float, float, float], List[float]] = (
0.48145466,
0.4578275,
0.40821073,
),
image_std: Union[Tuple[float, float, float], List[float]] = (
0.26862954,
0.26130258,
0.27577711,
),
rescale_factor: float = 1.0 / 255.0,
do_normalize: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.rescale_factor = rescale_factor
self.image_mean = image_mean
self.image_std = image_std
self.min_size = min_size
self.do_normalize = do_normalize
if image_mean is None:
self.background_color = (127, 127, 127)
else:
self.background_color = tuple([int(x * 255) for x in image_mean])
def resize(self, pil_img: Image) -> np.ndarray:
"""
Args:
pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
Returns:
x (np.ndarray): [3, self.image_size, self.image_size]
"""
width, height = pil_img.size
max_size = max(width, height)
size = [
max(int(height / max_size * self.image_size), self.min_size),
max(int(width / max_size * self.image_size), self.min_size),
]
if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
print(f"orig size = {pil_img.size}, new size = {size}")
raise ValueError("Invalid size!")
pil_img = torchvision.transforms.functional.resize(
pil_img,
size,
interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
antialias=True,
)
pil_img = expand2square(pil_img, self.background_color)
x = to_numpy_array(pil_img)
# [H, W, 3] -> [3, H, W]
x = np.transpose(x, (2, 0, 1))
return x
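# example: with image_size=384 and min_size=14, a 640x480 input is resized to 288x384
# (aspect ratio preserved), padded to a 384x384 square with the background color, and
# returned as a (3, 384, 384) channels-first array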
def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
# resize and pad to [self.image_size, self.image_size]
# then convert from [H, W, 3] to [3, H, W]
images: List[np.ndarray] = [self.resize(image) for image in images]
# rescale from [0, 255] -> [0, 1]
images = [
self.rescale(
image=image,
scale=self.rescale_factor,
input_data_format="channels_first",
)
for image in images
]
# normalize
if self.do_normalize:
images = [
self.normalize(
image=image,
mean=self.image_mean,
std=self.image_std,
input_data_format="channels_first",
)
for image in images
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
@property
def default_shape(self):
return [3, self.image_size, self.image_size]
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
if __name__ == "__main__":
image_processor = VLMImageProcessor(
image_size=1024,
image_mean=IMAGENET_INCEPTION_MEAN,
image_std=IMAGENET_INCEPTION_STD,
do_normalize=True,
)

226
janus/janusflow/models/modeling_vlm.py Normal file
View File

@ -0,0 +1,226 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from attrdict import AttrDict
from einops import rearrange
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedModel,
LlamaConfig,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from janus.janusflow.models.clip_encoder import CLIPVisionTower
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
import torch.nn as nn
def model_name_to_cls(cls_name):
if "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower
elif "ShallowUViTEncoder" in cls_name:
cls = ShallowUViTEncoder
elif "ShallowUViTDecoder" in cls_name:
cls = ShallowUViTDecoder
else:
raise ValueError(f"class_name {cls_name} is invalid.")
return cls
class VisionUnderstandEncoderConfig(PretrainedConfig):
model_type = "vision_und_enc"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class VisionGenerationEncoderConfig(PretrainedConfig):
model_type = "vision_gen_enc"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class VisionGenerationDecoderConfig(PretrainedConfig):
model_type = "vision_gen_dec"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_und_enc_config: VisionUnderstandEncoderConfig
language_config: LlamaConfig
def __init__(self, **kwargs):
super().__init__(**kwargs)
vision_und_enc_config = kwargs.get("vision_und_enc_config", {})
self.vision_und_enc_config = VisionUnderstandEncoderConfig(
**vision_und_enc_config
)
vision_gen_enc_config = kwargs.get("vision_gen_enc_config", {})
self.vision_gen_enc_config = VisionGenerationEncoderConfig(
**vision_gen_enc_config
)
vision_gen_dec_config = kwargs.get("vision_gen_dec_config", {})
self.vision_gen_dec_config = VisionGenerationDecoderConfig(
**vision_gen_dec_config
)
language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig):
self.language_config = language_config
else:
self.language_config = LlamaConfig(**language_config)
class MultiModalityPreTrainedModel(PreTrainedModel):
config_class = MultiModalityConfig
base_model_prefix = "multi_modality"
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(self, config: MultiModalityConfig):
super().__init__(config)
# vision understanding encoder
vision_und_enc_config = config.vision_und_enc_config
vision_und_enc_cls = model_name_to_cls(vision_und_enc_config.cls)
self.vision_und_enc_model = vision_und_enc_cls(**vision_und_enc_config.params)
# vision understanding aligner
self.vision_und_enc_aligner = nn.Linear(1024, 2048, bias=True)
# begin of understanding embedding
self.beg_of_und_embed = nn.Parameter(torch.zeros(1, 2048))
# vision generation encoder
vision_gen_enc_config = config.vision_gen_enc_config
vision_gen_enc_cls = model_name_to_cls(vision_gen_enc_config.cls)
self.vision_gen_enc_model = vision_gen_enc_cls(**vision_gen_enc_config.params)
# vision generation encoder aligner
self.vision_gen_enc_aligner = nn.Linear(768, 2048, bias=True)
# vision generation decoder
vision_gen_dec_config = config.vision_gen_dec_config
vision_gen_dec_cls = model_name_to_cls(vision_gen_dec_config.cls)
self.vision_gen_dec_model = vision_gen_dec_cls(**vision_gen_dec_config.params)
# language model
language_config = config.language_config
self.language_model = LlamaForCausalLM(language_config)
# vision generation decoder aligner
self.vision_gen_dec_aligner_norm = LlamaRMSNorm(
2048, eps=language_config.rms_norm_eps
)
self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)
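# generation-side data flow: the UViT encoder embeds the noisy latent into 768-d tokens,
# vision_gen_enc_aligner lifts them to the 2048-d LLM width, and after the LLM the
# dec_aligner_norm + vision_gen_dec_aligner map the 2048-d hidden states back to 768-d
# inputs for the UViT decoder, which predicts the rectified-flow velocity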
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.LongTensor,
**kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
bs, n = pixel_values.shape[0:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.vision_und_enc_model(images)
images_embeds = self.vision_und_enc_aligner(images_embeds)
# print(images_embeds.shape, self.beg_of_und_embed.shape, images_seq_mask.shape, input_ids.shape)
beg_of_und_embed = self.beg_of_und_embed[0].detach().clone()
images_embeds = torch.cat(
[
beg_of_und_embed.view(1, 1, -1).repeat(images_embeds.shape[0], 1, 1),
images_embeds,
],
dim=1,
)
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
input_ids[input_ids < 0] = 0 # ignore the image embeddings
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
# replace with the image embeddings
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
AutoConfig.register("vision_und_enc", VisionUnderstandEncoderConfig)
AutoConfig.register("vision_gen_enc", VisionGenerationEncoderConfig)
AutoConfig.register("vision_gen_dec", VisionGenerationDecoderConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)

455
janus/janusflow/models/processing_vlm.py Normal file
View File

@ -0,0 +1,455 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from dataclasses import dataclass
from typing import Dict, List
import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
from janus.utils.conversation import get_conv_template
class DictOutput(object):
def keys(self):
return self.__dict__.keys()
def __getitem__(self, item):
return self.__dict__[item]
def __setitem__(self, key, value):
self.__dict__[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
sft_format: str
input_ids: torch.Tensor
pixel_values: torch.Tensor
num_und_image_tokens: torch.IntTensor
def __len__(self):
return len(self.input_ids)
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
sft_format: List[str]
input_ids: torch.Tensor
pixel_values: torch.Tensor
attention_mask: torch.Tensor
images_seq_mask: torch.BoolTensor
images_emb_mask: torch.BoolTensor
def to(self, device, dtype=torch.bfloat16):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.images_seq_mask = self.images_seq_mask.to(device)
self.images_emb_mask = self.images_emb_mask.to(device)
self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
return self
class VLChatProcessor(ProcessorMixin):
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["image_processor", "tokenizer"]
system_prompt = (
"You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language."
)
def __init__(
self,
image_processor: VLMImageProcessor,
tokenizer: LlamaTokenizerFast,
image_tag: str = "<image_placeholder>",
image_start_tag: str = "<begin_of_image>",
image_end_tag: str = "<end_of_image>",
image_gen_tag: str = "<begin▁of▁generation>",
num_image_tokens: int = 576,
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
self.image_processor = image_processor
self.tokenizer = tokenizer
image_id = self.tokenizer.vocab.get(image_tag)
if image_id is None:
special_tokens = [image_tag]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
print(f"Add image tag = {image_tag} to the tokenizer")
image_gen_id = self.tokenizer.vocab.get(image_gen_tag)
if image_gen_id is None:
special_tokens = [image_gen_tag]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
print(f"Add generation tag = {image_gen_tag} to the tokenizer")
assert image_start_tag is not None and image_end_tag is not None
boi_id = self.tokenizer.vocab.get(image_start_tag)
eoi_id = self.tokenizer.vocab.get(image_end_tag)
if boi_id is None:
special_tokens = [image_start_tag]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
print(f"Add boi tag = {image_start_tag} to the tokenizer")
if eoi_id is None:
special_tokens = [image_end_tag]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
print(f"Add eoi tag = {image_end_tag} to the tokenizer")
self.image_tag = image_tag
self.image_gen_tag = image_gen_tag
self.image_start_tag = image_start_tag
self.image_end_tag = image_end_tag
self.num_image_tokens = num_image_tokens
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
self.tokenizer.pad_token_id = self.tokenizer.vocab.get("<▁pad▁>")
super().__init__(
image_processor,
tokenizer,
image_tag,
num_image_tokens,
add_special_token,
sft_format,
mask_prompt,
ignore_id,
**kwargs,
)
def new_chat_template(self):
conv = get_conv_template(self.sft_format)
conv.set_system_message(self.system_prompt)
return conv
def apply_sft_template_for_multi_turn_prompts(
self,
conversations: List[Dict[str, str]],
sft_format: str = "deepseek",
system_prompt: str = "",
):
"""
Applies the SFT template to conversation.
An example of conversation:
conversation = [
{
"role": "User",
"content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
"images": [
"./multi-images/attribute_comparison_1.png",
"./multi-images/attribute_comparison_2.png"
]
},
{
"role": "Assistant",
"content": ""
}
]
Args:
conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
Returns:
sft_prompt (str): The formatted text.
"""
conv = get_conv_template(sft_format)
conv.set_system_message(system_prompt)
for message in conversations:
conv.append_message(message["role"], message["content"].strip())
sft_prompt = conv.get_prompt().strip()
return sft_prompt
@property
def image_token(self):
return self.image_tag
@property
def image_id(self):
image_id = self.tokenizer.vocab.get(self.image_tag)
return image_id
@property
def image_start_id(self):
image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
return image_start_id
@property
def image_end_id(self):
image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
return image_end_id
@property
def image_start_token(self):
return self.image_start_tag
@property
def image_end_token(self):
return self.image_end_tag
@property
def pad_id(self):
pad_id = self.tokenizer.pad_token_id
if pad_id is None:
pad_id = self.tokenizer.eos_token_id
return pad_id
@property
def image_gen_id(self):
image_gen_id = self.tokenizer.vocab.get(self.image_gen_tag)
return image_gen_id
def add_image_token(
self,
image_indices: List[int],
input_ids: torch.LongTensor,
):
"""
Args:
image_indices (List[int]): [index_0, index_1, ..., index_j]
input_ids (torch.LongTensor): [N]
Returns:
input_ids (torch.LongTensor): [N + image tokens]
num_image_tokens (torch.IntTensor): [n_images]
"""
input_slices = []
start = 0
for index in image_indices:
if self.add_special_token:
end = index + 1
else:
end = index
# original text tokens
input_slices.append(input_ids[start:end])
# add boi, image tokens, eoi and set the mask as False
input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
input_slices.append(
self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
)
input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
start = index + 1
# the left part
input_slices.append(input_ids[start:])
# concat all slices
input_ids = torch.cat(input_slices, dim=0)
num_image_tokens = torch.IntTensor(
[self.num_image_tokens + 1] * len(image_indices)
)
# we add 1 to fit generation
return input_ids, num_image_tokens
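# example: each "<image_placeholder>" in the prompt expands to <begin_of_image>,
# 576 image tokens and <end_of_image>; the corresponding num_image_tokens entry is
# 577 (576 image tokens + the <begin_of_image> token), matching the images_emb_mask
# width used in batchify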
def process_one(
self,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image] = None,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert (
prompt is None or conversations is None
), "prompt and conversations cannot be used at the same time."
if prompt is None:
# apply sft format
sft_format = self.apply_sft_template_for_multi_turn_prompts(
conversations=conversations,
sft_format=self.sft_format,
system_prompt=self.system_prompt,
)
else:
sft_format = prompt
# tokenize
input_ids = self.tokenizer.encode(sft_format)
input_ids = torch.LongTensor(input_ids)
# add image tokens to the input_ids
image_token_mask: torch.BoolTensor = input_ids == self.image_id
image_indices = image_token_mask.nonzero()
input_ids, num_und_image_tokens = self.add_image_token(
image_indices=image_indices,
input_ids=input_ids,
)
# load images
images_outputs = self.image_processor(images, return_tensors="pt")
prepare = VLChatProcessorOutput(
sft_format=sft_format,
input_ids=input_ids,
pixel_values=images_outputs.pixel_values,
num_und_image_tokens=num_und_image_tokens,
)
return prepare
def __call__(
self,
*,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image] = None,
force_batchify: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
force_batchify (bool): force batchify the inputs;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(
prompt=prompt, conversations=conversations, images=images
)
if force_batchify:
prepare = self.batchify([prepare])
return prepare
def batchify(
self, prepare_list: List[VLChatProcessorOutput]
) -> BatchedVLChatProcessorOutput:
"""
Preprocesses the inputs for multimodal inference.
Args:
prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
Returns:
BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
"""
batch_size = len(prepare_list)
sft_format = []
n_images = []
seq_lens = []
for prepare in prepare_list:
# we only fill the images for understanding tasks into the mask
n_images.append(len(prepare.num_und_image_tokens))
seq_lens.append(len(prepare))
input_token_max_len = max(seq_lens)
max_n_images = max(1, max(n_images))
batched_input_ids = torch.full(
(batch_size, input_token_max_len), self.pad_id
).long() # FIXME
batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
batched_pixel_values = torch.zeros(
(batch_size, max_n_images, *self.image_processor.default_shape)
).float()
batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
batched_images_emb_mask = torch.zeros(
(
batch_size,
max_n_images,
self.num_image_tokens + 1,
) # add 1 to account for <image_beg>
).bool()
for i, prepare in enumerate(prepare_list):
input_ids = prepare.input_ids
seq_len = len(prepare)
n_image = len(prepare.num_und_image_tokens)
# left-padding
batched_attention_mask[i, -seq_len:] = 1
batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
batched_images_seq_mask[i, -seq_len:] = (input_ids == self.image_id) | (
input_ids == self.image_start_id
)
if n_image > 0:
batched_pixel_values[i, :n_image] = prepare.pixel_values
for j, n_image_tokens in enumerate(prepare.num_und_image_tokens):
batched_images_emb_mask[i, j, :n_image_tokens] = True
sft_format.append(prepare.sft_format)
batched_prepares = BatchedVLChatProcessorOutput(
input_ids=batched_input_ids,
attention_mask=batched_attention_mask,
pixel_values=batched_pixel_values,
images_seq_mask=batched_images_seq_mask,
images_emb_mask=batched_images_emb_mask,
sft_format=sft_format,
)
return batched_prepares

691
janus/janusflow/models/siglip_vit.py Normal file
View File

@ -0,0 +1,691 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
Callable,
Dict,
Final,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import (
AttentionPoolLatent,
DropPath,
LayerType,
Mlp,
PatchDropout,
PatchEmbed,
resample_abs_pos_embed,
)
from timm.models._manipulate import checkpoint_seq, named_apply
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std) # noqa: E741
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
with torch.no_grad():
dtype = tensor.dtype
tensor_fp32 = tensor.float()
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
tensor_dtype = tensor_fp32.to(dtype=dtype)
tensor.copy_(tensor_dtype)
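# Illustrative sketch: thanks to the float32 round-trip above, bfloat16 tensors can be initialized directly.
w = torch.empty(3, 5, dtype=torch.bfloat16)
trunc_normal_(w, std=0.02)  # drawn in fp32, truncated to [a, b] = [-2, 2], then copied back as bfloat16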
def init_weights(self):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, "init_weights"):
module.init_weights()
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
# self.fused_attn = use_fused_attn()
self.fused_attn = True
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
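# Illustrative sketch (toy sizes): the fused scaled_dot_product_attention path used above.
attn = Attention(dim=64, num_heads=8, qkv_bias=True)
out = attn(torch.randn(2, 16, 64))  # -> (2, 16, 64)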
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
init_values: Optional[float] = None,
drop_path: float = 0.0,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
mlp_layer: nn.Module = Mlp,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class VisionTransformer(nn.Module):
"""Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
dynamic_img_size: Final[bool]
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal["", "avg", "token", "map"] = "token",
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0.0,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
ignore_head: bool = False,
) -> None:
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
num_classes: Number of classes for classification head.
global_pool: Type of global pooling for final sequence (default: 'token').
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
no_embed_class: Don't include position embeddings for class (or reg) tokens.
reg_tokens: Number of register tokens.
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ("", "avg", "token", "map")
assert class_token or global_pool != "token"
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
# act_layer = get_act_layer(act_layer) or nn.GELU
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens
self.has_class_token = class_token
self.no_embed_class = (
no_embed_class # don't embed prefix positions (includes reg)
)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.ignore_head = ignore_head
embed_args = {}
if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
self.cls_token = (
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
)
self.reg_token = (
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
)
embed_len = (
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
)
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
if global_pool == "map":
AttentionPoolLatent.init_weights = init_weights
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
if weight_init != "skip":
self.init_weights(weight_init)
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
assert mode in ("jax", "jax_nlhb", "moco", "")
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
trunc_normal_(self.pos_embed, std=0.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
@torch.jit.ignore
def no_weight_decay(self) -> Set:
return {"pos_embed", "cls_token", "dist_token"}
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
return self.head
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ("", "avg", "token", "map")
if global_pool == "map" and self.attn_pool is None:
assert (
False
), "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != "map " and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(H, W),
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(B, -1, C)
else:
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
x = x + pos_embed
return self.pos_drop(x)
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
return outputs
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
"""Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
"""
# take last n blocks if n is an int; if n is a sequence, select blocks by matching indices
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
.permute(0, 3, 1, 2)
.contiguous()
for out in outputs
]
if return_prefix_tokens:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
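# Illustrative sketch (toy configuration, not the SigLIP settings): DINOv2-style feature extraction
# with the accessor above; reshape=True returns (B, C, H/patch, W/patch) feature maps.
vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=12, num_heads=3, weight_init="skip")
feats = vit.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=4, reshape=True, norm=True)
# feats is a tuple of 4 tensors, each shaped (1, 192, 14, 14), one per selected block.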
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
elif self.global_pool == "avg":
x = x[:, self.num_prefix_tokens :].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
if not self.ignore_head:
x = self.forward_head(x)
return x
@dataclass
class SigLIPVisionCfg:
width: int = 1152
layers: Union[Tuple[int, int, int, int], int] = 27
heads: int = 16
patch_size: int = 14
image_size: Union[Tuple[int, int], int] = 336
global_pool: str = "map"
mlp_ratio: float = 3.7362
class_token: bool = False
num_classes: int = 0
use_checkpoint: bool = False
SigLIP_MODEL_CONFIG = {
"siglip_so400m_patch14_384": {
"image_size": 336,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False,
},
"siglip_so400m_patch14_224": {
"image_size": 224,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False,
},
"siglip_large_patch16_384": {
"image_size": 384,
"patch_size": 16,
"width": 1024,
"layers": 24,
"heads": 16,
"mlp_ratio": 4,
"global_pool": "map",
"use_checkpoint": False,
},
"siglip_large_patch16_256": {
"image_size": 256,
"patch_size": 16,
"width": 1024,
"layers": 24,
"heads": 16,
"mlp_ratio": 4,
"global_pool": "map",
"use_checkpoint": False,
},
}
def create_siglip_vit(
model_name: str = "siglip_so400m_patch14_384",
image_size: int = 384,
select_layer: int = -1,
ckpt_path: str = "",
**kwargs,
):
assert (
model_name in SigLIP_MODEL_CONFIG.keys()
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
if select_layer <= 0:
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
else:
layers = min(vision_cfg.layers, select_layer)
model = VisionTransformer(
img_size=image_size,
patch_size=vision_cfg.patch_size,
embed_dim=vision_cfg.width,
depth=layers,
num_heads=vision_cfg.heads,
mlp_ratio=vision_cfg.mlp_ratio,
class_token=vision_cfg.class_token,
global_pool=vision_cfg.global_pool,
ignore_head=kwargs.get("ignore_head", True),
weight_init=kwargs.get("weight_init", "skip"),
num_classes=0,
)
if ckpt_path:
state_dict = torch.load(ckpt_path, map_location="cpu")
incompatible_keys = model.load_state_dict(state_dict, strict=False)
print(
f"SigLIP-ViT restores from {ckpt_path},\n"
f"\tincompatible_keys:', {incompatible_keys}."
)
return model
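# Illustrative usage sketch (assumption: pretrained weights are loaded separately; ckpt_path left empty):
vit = create_siglip_vit("siglip_so400m_patch14_384", image_size=384)
feats = vit(torch.randn(1, 3, 384, 384))  # ignore_head defaults to True here, so patch tokens are returned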

View File

@ -0,0 +1,714 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
import math
import warnings
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from typing import Optional, Tuple, Union
import numpy as np
import torchvision
import torchvision.utils
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
class ImageHead(nn.Module):
def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
super().__init__()
self.layer_id = layer_id
cfg = (
AttrDict(
norm_type="layernorm",
is_exp_norm=False,
sequence_parallel=False,
use_userbuffer=False,
norm_eps=1e-5,
norm_bias=True,
gradient_accumulation_fusion=True,
use_fp32_head_weight=False,
)
+ gpt_cfg
)
group = PG.tensor_parallel_group()
assert cfg.norm_type in [
"layernorm",
"rmsnorm",
], f"Norm type:{cfg.norm_type} not supported"
if cfg.norm_type == "rmsnorm":
self.norm = DropoutAddRMSNorm(
cfg.n_embed,
prenorm=False,
eps=cfg.norm_eps,
is_exp_norm=cfg.is_exp_norm,
sequence_parallel=cfg.sequence_parallel,
)
else:
self.norm = DropoutAddLayerNorm(
cfg.n_embed,
prenorm=False,
eps=cfg.norm_eps,
is_exp_norm=cfg.is_exp_norm,
sequence_parallel=cfg.sequence_parallel,
bias=cfg.norm_bias,
)
multiple_of = 256
if decoder_cfg.in_channels % multiple_of != 0:
warnings.warn(
f"It is recommended to set vocab_size to a multiple of {multiple_of}, "
"otherwise matrix multiplication performance will be affected."
)
dtype = default_dtype = torch.get_default_dtype()
if cfg.use_fp32_head_weight:
dtype = torch.float32
print(
"使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
end="",
flush=True,
)
torch.set_default_dtype(dtype)
self.head = ColumnParallelLinear(
cfg.n_embed,
decoder_cfg.in_channels,
bias=True,
group=group,
sequence_parallel=cfg.sequence_parallel,
use_userbuffer=cfg.use_userbuffer,
gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
use_fp32_output=False,
)
torch.set_default_dtype(default_dtype)
self.use_fp32_head_weight = cfg.use_fp32_head_weight
def forward(
self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
):
residual = None
if isinstance(input_args, tuple):
x, residual = input_args
else:
x = input_args
x = self.norm(x, residual)
if self.use_fp32_head_weight:
assert (
self.head.weight.dtype == torch.float32
), f"head.weight is {self.head.weight.dtype}"
x = x.float()
if images_split_mask is None:
logits = self.head(x)
else:
bs, n_images = images_split_mask.shape[:2]
n_embed = x.shape[-1]
images_embed = torch.masked_select(
x.unsqueeze(1), images_split_mask.unsqueeze(-1)
)
images_embed = images_embed.view((bs * n_images, -1, n_embed))
logits = self.head(images_embed)
return logits
class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, x):
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
return torch.addcmul(self.bias, (self.weight * nx + 1), x, value=1)
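# Illustrative sketch (channels-last input assumed): GRN rescales each channel by its global L2 norm
# relative to the per-sample mean norm; the addcmul above computes bias + (weight * nx + 1) * x.
grn = GlobalResponseNorm(8)
y = grn(torch.randn(2, 4, 4, 8))  # -> (2, 4, 4, 8)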
class Downsample2D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
name (`str`, default `conv`):
name of the downsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
padding: int = 1,
name: str = "conv",
kernel_size=3,
stride=2,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
self.name = name
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
if use_conv:
conv = nn.Conv2d(
self.channels,
self.out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
0, 3, 1, 2
)
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
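# Illustrative sketch (channel counts are assumptions): RMS-normed strided conv halving the spatial size.
down = Downsample2D(channels=64, use_conv=True, out_channels=128, norm_type="rms_norm", eps=1e-6)
y = down(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)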
class Upsample2D(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
stride=2,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
self.stride = stride
if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")
conv = None
if use_conv_transpose:
if kernel_size is None:
kernel_size = 4
conv = nn.ConvTranspose2d(
channels,
self.out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = nn.Conv2d(
self.channels,
self.out_channels,
kernel_size=kernel_size,
padding=padding,
bias=bias,
)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(
self,
hidden_states: torch.Tensor,
output_size: Optional[int] = None,
*args,
**kwargs,
) -> torch.Tensor:
assert hidden_states.shape[1] == self.channels
if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
0, 3, 1, 2
)
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
if output_size is None:
hidden_states = F.interpolate(
hidden_states, scale_factor=self.stride, mode="nearest"
)
else:
hidden_states = F.interpolate(
hidden_states, size=output_size, mode="nearest"
)
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
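# Illustrative sketch (channel counts are assumptions): nearest-neighbor 2x upsampling followed by a 3x3 conv.
up = Upsample2D(channels=128, out_channels=64, use_conv=True, norm_type="rms_norm", eps=1e-6)
y = up(torch.randn(1, 128, 16, 16))  # -> (1, 64, 32, 32)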
class ConvNextBlock(nn.Module):
def __init__(
self,
channels,
norm_eps,
elementwise_affine,
use_bias,
hidden_dropout,
hidden_size,
res_ffn_factor: int = 4,
):
super().__init__()
self.depthwise = nn.Conv2d(
channels,
channels,
kernel_size=7,
padding=3,
groups=channels,
bias=use_bias,
)
self.norm = RMSNorm(channels, norm_eps)
self.channelwise_linear_1 = nn.Linear(
channels, int(channels * res_ffn_factor), bias=use_bias
)
self.channelwise_act = nn.GELU()
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
self.channelwise_linear_2 = nn.Linear(
int(channels * res_ffn_factor), channels, bias=use_bias
)
self.channelwise_dropout = nn.Dropout(hidden_dropout)
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
def forward(self, x, cond_embeds):
x_res = x
x = self.depthwise(x)
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.channelwise_linear_1(x)
x = self.channelwise_act(x)
x = self.channelwise_norm(x)
x = self.channelwise_linear_2(x)
x = self.channelwise_dropout(x)
x = x.permute(0, 3, 1, 2)
x = x + x_res
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
# x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
x = torch.addcmul(
shift[:, :, None, None], x, (1 + scale)[:, :, None, None], value=1
)
return x
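# Illustrative sketch (sizes are assumptions): the block is modulated by a condition/timestep embedding
# through the scale/shift produced by cond_embeds_mapper.
blk = ConvNextBlock(channels=64, norm_eps=1e-6, elementwise_affine=True, use_bias=True,
                    hidden_dropout=0.0, hidden_size=128)
y = blk(torch.randn(2, 64, 16, 16), cond_embeds=torch.randn(2, 128))  # -> (2, 64, 16, 16)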
class Patchify(nn.Module):
def __init__(
self,
in_channels,
block_out_channels,
patch_size,
bias,
elementwise_affine,
eps,
kernel_size=None,
):
super().__init__()
if kernel_size is None:
kernel_size = patch_size
self.patch_conv = nn.Conv2d(
in_channels,
block_out_channels,
kernel_size=kernel_size,
stride=patch_size,
bias=bias,
)
self.norm = RMSNorm(block_out_channels, eps)
def forward(self, x):
embeddings = self.patch_conv(x)
embeddings = embeddings.permute(0, 2, 3, 1)
embeddings = self.norm(embeddings)
embeddings = embeddings.permute(0, 3, 1, 2)
return embeddings
class Unpatchify(nn.Module):
def __init__(
self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
):
super().__init__()
self.norm = RMSNorm(in_channels, eps)
self.unpatch_conv = nn.Conv2d(
in_channels,
out_channels * patch_size * patch_size,
kernel_size=1,
bias=bias,
)
self.pixel_shuffle = nn.PixelShuffle(patch_size)
self.patch_size = patch_size
def forward(self, x):
# [b, c, h, w]
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
x = self.unpatch_conv(x)
x = self.pixel_shuffle(x)
return x
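# Illustrative sketch (channel counts are assumptions): a 2x patchify followed by unpatchify restores
# the input resolution via PixelShuffle.
patch = Patchify(3, 64, patch_size=2, bias=True, elementwise_affine=True, eps=1e-6)
unpatch = Unpatchify(64, 3, patch_size=2, bias=True, elementwise_affine=True, eps=1e-6)
z = patch(torch.randn(1, 3, 32, 32))  # -> (1, 64, 16, 16)
img = unpatch(z)  # -> (1, 3, 32, 32)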
class UVitBlock(nn.Module):
def __init__(
self,
channels,
out_channels,
num_res_blocks,
stride,
hidden_size,
hidden_dropout,
elementwise_affine,
norm_eps,
use_bias,
downsample: bool,
upsample: bool,
res_ffn_factor: int = 4,
seq_len=None,
concat_input=False,
original_input_channels=None,
use_zero=True,
norm_type="RMS",
):
super().__init__()
self.res_blocks = nn.ModuleList()
for i in range(num_res_blocks):
conv_block = ConvNextBlock(
channels,
norm_eps,
elementwise_affine,
use_bias,
hidden_dropout,
hidden_size,
res_ffn_factor=res_ffn_factor,
)
self.res_blocks.append(conv_block)
if downsample:
self.downsample = Downsample2D(
channels=channels,
out_channels=out_channels,
use_conv=True,
name="Conv2d_0",
kernel_size=3,
padding=1,
stride=stride,
norm_type="rms_norm",
eps=norm_eps,
elementwise_affine=elementwise_affine,
bias=use_bias,
)
else:
self.downsample = None
if upsample:
self.upsample = Upsample2D(
channels=channels,
out_channels=out_channels,
use_conv_transpose=False,
use_conv=True,
kernel_size=3,
padding=1,
stride=stride,
name="conv",
norm_type="rms_norm",
eps=norm_eps,
elementwise_affine=elementwise_affine,
bias=use_bias,
interpolate=True,
)
else:
self.upsample = None
def forward(self, x, emb, recompute=False):
for res_block in self.res_blocks:
x = res_block(x, emb)
if self.downsample is not None:
x = self.downsample(x)
if self.upsample is not None:
x = self.upsample(x)
return x
class ShallowUViTEncoder(nn.Module):
def __init__(
self,
input_channels=3,
stride=4,
kernel_size=7,
padding=None,
block_out_channels=(768,),
layers_in_middle=2,
hidden_size=2048,
elementwise_affine=True,
use_bias=True,
norm_eps=1e-6,
dropout=0.0,
use_mid_block=True,
**kwargs,
):
super().__init__()
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
)
self.time_embed = TimestepEmbedding(
block_out_channels[0], hidden_size, sample_proj_bias=use_bias
)
if padding is None:
padding = math.ceil(kernel_size - stride)
self.in_conv = nn.Conv2d(
in_channels=input_channels,
out_channels=block_out_channels[0],
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
if use_mid_block:
self.mid_block = UVitBlock(
block_out_channels[-1],
block_out_channels[-1],
num_res_blocks=layers_in_middle,
hidden_size=hidden_size,
hidden_dropout=dropout,
elementwise_affine=elementwise_affine,
norm_eps=norm_eps,
use_bias=use_bias,
downsample=False,
upsample=False,
stride=1,
res_ffn_factor=4,
)
else:
self.mid_block = None
def get_num_extra_tensors(self):
return 2
def forward(self, x, timesteps):
bs = x.shape[0]
dtype = x.dtype
t_emb = self.time_proj(timesteps.flatten()).view(bs, -1).to(dtype)
t_emb = self.time_embed(t_emb)
x_emb = self.in_conv(x)
if self.mid_block is not None:
x_emb = self.mid_block(x_emb, t_emb)
hs = [x_emb]
return x_emb, t_emb, hs
class ShallowUViTDecoder(nn.Module):
def __init__(
self,
in_channels=768,
out_channels=3,
block_out_channels: Tuple[int] = (768,),
upsamples=2,
layers_in_middle=2,
hidden_size=2048,
elementwise_affine=True,
norm_eps=1e-6,
use_bias=True,
dropout=0.0,
use_mid_block=True,
**kwargs,
):
super().__init__()
if use_mid_block:
self.mid_block = UVitBlock(
in_channels + block_out_channels[-1],
block_out_channels[
-1
], # In fact, the parameter is not used because it has no effect when both downsample and upsample are set to false.
num_res_blocks=layers_in_middle,
hidden_size=hidden_size,
hidden_dropout=dropout,
elementwise_affine=elementwise_affine,
norm_eps=norm_eps,
use_bias=use_bias,
downsample=False,
upsample=False,
stride=1,
res_ffn_factor=4,
)
else:
self.mid_block = None
self.out_convs = nn.ModuleList()
for rank in range(upsamples):
if rank == upsamples - 1:
curr_out_channels = out_channels
else:
curr_out_channels = block_out_channels[-1]
if rank == 0:
curr_in_channels = block_out_channels[-1] + in_channels
else:
curr_in_channels = block_out_channels[-1]
self.out_convs.append(
Unpatchify(
curr_in_channels,
curr_out_channels,
patch_size=2,
bias=use_bias,
elementwise_affine=elementwise_affine,
eps=norm_eps,
)
)
self.input_norm = RMSNorm(in_channels, norm_eps)
def forward(self, x, hs, t_emb):
x = x.permute(0, 2, 3, 1)
x = self.input_norm(x)
x = x.permute(0, 3, 1, 2)
x = torch.cat([x, hs.pop()], dim=1)
if self.mid_block is not None:
x = self.mid_block(x, t_emb)
for out_conv in self.out_convs:
x = out_conv(x)
assert len(hs) == 0
return x
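# Illustrative end-to-end sketch (shapes and channel counts are assumptions): a stride-4 shallow encoding
# is undone by two 2x unpatchify steps in the decoder, restoring the input resolution.
enc = ShallowUViTEncoder(input_channels=3, stride=4, block_out_channels=(768,), hidden_size=2048)
dec = ShallowUViTDecoder(in_channels=768, out_channels=3, block_out_channels=(768,), upsamples=2, hidden_size=2048)
x_emb, t_emb, hs = enc(torch.randn(2, 3, 96, 96), timesteps=torch.rand(2))
img = dec(x_emb, hs, t_emb)  # -> (2, 3, 96, 96)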