mirror of
https://github.com/deepseek-ai/Janus
synced 2024-12-27 06:11:54 +00:00
add janusflow
This commit is contained in:
parent
7a388c4f8b
commit
b01013575f
283
README.md
283
README.md
@ -6,6 +6,12 @@
|
||||
<img src="images/logo.svg" width="60%" alt="DeepSeek LLM" />
|
||||
</div>
|
||||
<hr>
|
||||
|
||||
<div align="center">
|
||||
<h1>🚀 Janus-Series: Unified Multimodal Understanding and Generation Models</h1>
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
<a href="https://www.deepseek.com/" target="_blank">
|
||||
@ -45,30 +51,44 @@
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="#3-model-download"><b>📥 Model Download</b></a> |
|
||||
<a href="#4-quick-start"><b>⚡ Quick Start</b></a> |
|
||||
<a href="#5-license"><b>📜 License</b></a> |
|
||||
<a href="#6-citation"><b>📖 Citation</b></a> <br>
|
||||
<a href="https://arxiv.org/abs/2410.13848"><b>📄 Paper Link</b></a> |
|
||||
<a href="https://huggingface.co/spaces/deepseek-ai/Janus-1.3B"><b>🤗 Online Demo</b></a>
|
||||
<a href="#2-model-download"><b>📥 Model Download</b></a> |
|
||||
<a href="#3-quick-start"><b>⚡ Quick Start</b></a> |
|
||||
<a href="#4-license"><b>📜 License</b></a> |
|
||||
<a href="#5-citation"><b>📖 Citation</b></a> <br>
|
||||
<!-- 📄 Paper Link (<a href="https://arxiv.org/abs/2410.13848"><b>Janus</b></a>, <a href="https://arxiv.org/abs/2410.13848"><b>JanusFlow</b></a>) | -->
|
||||
🤗 Online Demo (<a href="https://huggingface.co/spaces/deepseek-ai/Janus-1.3B"><b>Janus</b></a>, <a href="https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B"><b>JanusFlow</b></a>)
|
||||
</p>
|
||||
|
||||
|
||||
## News
|
||||
|
||||
**2024.11.13**: JanusFlow is released, a new unified model with rectified flow for image generation. See [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow).
|
||||
|
||||
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
|
||||
|
||||
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
|
||||
|
||||
|
||||
## 1. Introduction
|
||||
|
||||
Janus is a novel autoregressive framework that unifies multimodal understanding and generation. It addresses the limitations of previous approaches by decoupling visual encoding into separate pathways, while still utilizing a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder’s roles in understanding and generation, but also enhances the framework’s flexibility. Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.
|
||||
<a href="https://arxiv.org/abs/2410.13848"><b>Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation</b></a>
|
||||
|
||||
**Janus** is a novel autoregressive framework that unifies multimodal understanding and generation. It addresses the limitations of previous approaches by decoupling visual encoding into separate pathways, while still utilizing a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder’s roles in understanding and generation, but also enhances the framework’s flexibility. Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.
|
||||
|
||||
<div align="center">
|
||||
<img alt="image" src="images/teaser.png" style="width:90%;">
|
||||
</div>
|
||||
|
||||
## 2. News
|
||||
<a href="https://arxiv.org/abs/2411.07975"><b>JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation</b></a>
|
||||
|
||||
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
|
||||
**JanusFlow** introduces a minimalist architecture that integrates autoregressive language models with rectified flow, a state-of-the-art method in generative modeling. Our key finding demonstrates that rectified flow can be straightforwardly trained within the large language model framework, eliminating the need for complex architectural modifications. Extensive experiments show that JanusFlow achieves comparable or superior performance to specialized models in their respective domains, while significantly outperforming existing unified approaches across standard benchmarks. This work represents a step toward more efficient and versatile vision-language models.
|
||||
|
||||
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
|
||||
<div align="center">
|
||||
<img alt="image" src="images/teaser_janusflow.png" style="width:90%;">
|
||||
</div>
|
||||
|
||||
|
||||
## 3. Model Download
|
||||
## 2. Model Download
|
||||
|
||||
We release Janus to the public to support a broader and more diverse range of research within both academic and commercial communities.
|
||||
Please note that the use of this model is subject to the terms outlined in [License section](#5-license). Commercial usage is
|
||||
@ -79,11 +99,15 @@ permitted under these terms.
|
||||
| Model | Sequence Length | Download |
|
||||
|-----------------------|-----------------|-----------------------------------------------------------------------------|
|
||||
| Janus-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/Janus-1.3B) |
|
||||
| JanusFlow-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/JanusFlow-1.3B) |
|
||||
|
||||
|
||||
|
||||
|
||||
## 4. Quick Start
|
||||
## 3. Quick Start
|
||||
|
||||
<details>
|
||||
<summary><h3>Janus</h3></summary>
|
||||
|
||||
### Installation
|
||||
|
||||
@ -278,26 +302,237 @@ To test the server, you can open another terminal and run:
|
||||
```
|
||||
python demo/fastapi_client.py
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><h3>JanusFlow</h3></summary>
|
||||
### Installation
|
||||
|
||||
On the basis of `Python >= 3.8` environment, install the necessary dependencies by running the following command:
|
||||
|
||||
```shell
|
||||
pip install -e .
|
||||
pip install diffusers[torch]
|
||||
```
|
||||
|
||||
### 🤗 Huggingface Online Demo
|
||||
Check out the demo in [this link](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B).
|
||||
|
||||
### Simple Inference Example
|
||||
|
||||
#### Multimodal Understanding
|
||||
```python
|
||||
|
||||
import torch
|
||||
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
|
||||
from janus.utils.io import load_pil_images
|
||||
|
||||
# specify the path to the model
|
||||
model_path = "deepseek-ai/JanusFlow-1.3B"
|
||||
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
|
||||
vl_gpt = MultiModalityCausalLM.from_pretrained(
|
||||
model_path, trust_remote_code=True
|
||||
)
|
||||
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "User",
|
||||
"content": "<image_placeholder>\nConvert the formula into latex code.",
|
||||
"images": ["images/equation.png"],
|
||||
},
|
||||
{"role": "Assistant", "content": ""},
|
||||
]
|
||||
|
||||
# load images and prepare for inputs
|
||||
pil_images = load_pil_images(conversation)
|
||||
prepare_inputs = vl_chat_processor(
|
||||
conversations=conversation, images=pil_images, force_batchify=True
|
||||
).to(vl_gpt.device)
|
||||
|
||||
# # run image encoder to get the image embeddings
|
||||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
||||
|
||||
# # run the model to get the response
|
||||
outputs = vl_gpt.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=prepare_inputs.attention_mask,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=512,
|
||||
do_sample=False,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
|
||||
print(f"{prepare_inputs['sft_format'][0]}", answer)
|
||||
|
||||
```
|
||||
|
||||
#### Text-to-Image Generation
|
||||
```python
|
||||
import os
|
||||
import PIL.Image
|
||||
import torch
|
||||
import numpy as np
|
||||
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
|
||||
import torchvision
|
||||
|
||||
|
||||
## 5. License
|
||||
# specify the path to the model
|
||||
model_path = "deepseek-ai/JanusFlow-1.3B"
|
||||
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
|
||||
vl_gpt = MultiModalityCausalLM.from_pretrained(
|
||||
model_path, trust_remote_code=True
|
||||
)
|
||||
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
|
||||
|
||||
from diffusers.models import AutoencoderKL
|
||||
# remember to use bfloat16 dtype, this vae doesn't work with fp16
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
|
||||
vae = vae.to(torch.bfloat16).cuda().eval()
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "User",
|
||||
"content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
|
||||
},
|
||||
{"role": "Assistant", "content": ""},
|
||||
]
|
||||
|
||||
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
||||
conversations=conversation,
|
||||
sft_format=vl_chat_processor.sft_format,
|
||||
system_prompt="",
|
||||
)
|
||||
prompt = sft_format + vl_chat_processor.image_gen_tag
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
mmgpt: MultiModalityCausalLM,
|
||||
vl_chat_processor: VLChatProcessor,
|
||||
prompt: str,
|
||||
cfg_weight: float = 5.0,
|
||||
num_inference_steps: int = 30,
|
||||
batchsize: int = 5
|
||||
):
|
||||
input_ids = vl_chat_processor.tokenizer.encode(prompt)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
|
||||
tokens = torch.stack([input_ids] * 2 * batchsize).cuda()
|
||||
tokens[batchsize:, 1:] = vl_chat_processor.pad_id
|
||||
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
# we remove the last <bog> token and replace it with t_emb later
|
||||
inputs_embeds = inputs_embeds[:, :-1, :]
|
||||
|
||||
# generate with rectified flow ode
|
||||
# step 1: encode with vision_gen_enc
|
||||
z = torch.randn((batchsize, 4, 48, 48), dtype=torch.bfloat16).cuda()
|
||||
|
||||
dt = 1.0 / num_inference_steps
|
||||
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
|
||||
|
||||
# step 2: run ode
|
||||
attention_mask = torch.ones((2*batchsize, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
|
||||
attention_mask[batchsize:, 1:inputs_embeds.shape[1]] = 0
|
||||
attention_mask = attention_mask.int()
|
||||
for step in range(num_inference_steps):
|
||||
# prepare inputs for the llm
|
||||
z_input = torch.cat([z, z], dim=0) # for cfg
|
||||
t = step / num_inference_steps * 1000.
|
||||
t = torch.tensor([t] * z_input.shape[0]).to(dt)
|
||||
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
|
||||
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
|
||||
z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
|
||||
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
|
||||
llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
|
||||
|
||||
# input to the llm
|
||||
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
|
||||
if step == 0:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=None)
|
||||
past_key_values = []
|
||||
for kv_cache in past_key_values:
|
||||
k, v = kv_cache[0], kv_cache[1]
|
||||
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
|
||||
past_key_values = tuple(past_key_values)
|
||||
else:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# transform hidden_states back to v
|
||||
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
|
||||
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
|
||||
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
|
||||
v_cond, v_uncond = torch.chunk(v, 2)
|
||||
v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
|
||||
z = z + dt * v
|
||||
|
||||
# step 3: decode with vision_gen_dec and sdxl vae
|
||||
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
|
||||
|
||||
os.makedirs('generated_samples', exist_ok=True)
|
||||
save_path = os.path.join('generated_samples', "img.jpg")
|
||||
torchvision.utils.save_image(decoded_image.clip_(-1.0, 1.0)*0.5+0.5, save_path)
|
||||
|
||||
generate(
|
||||
vl_gpt,
|
||||
vl_chat_processor,
|
||||
prompt,
|
||||
cfg_weight=2.0,
|
||||
num_inference_steps=30,
|
||||
batchsize=5
|
||||
)
|
||||
```
|
||||
|
||||
### Gradio Demo
|
||||
For the local gradio demo, you can run with the following command:
|
||||
|
||||
```
|
||||
pip install -e .[gradio]
|
||||
|
||||
python demo/app_janusflow.py
|
||||
```
|
||||
|
||||
Have Fun!
|
||||
|
||||
</details>
|
||||
|
||||
## 4. License
|
||||
|
||||
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).
|
||||
|
||||
## 6. Citation
|
||||
## 5. Citation
|
||||
|
||||
```
|
||||
@misc{wu2024janus,
|
||||
title={Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation},
|
||||
author={Chengyue Wu and Xiaokang Chen and Zhiyu Wu and Yiyang Ma and Xingchao Liu and Zizheng Pan and Wen Liu and Zhenda Xie and Xingkai Yu and Chong Ruan and Ping Luo},
|
||||
year={2024},
|
||||
eprint={2410.13848},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2410.13848},
|
||||
```bibtex
|
||||
@article{wu2024janus,
|
||||
title={Janus: Decoupling visual encoding for unified multimodal understanding and generation},
|
||||
author={Wu, Chengyue and Chen, Xiaokang and Wu, Zhiyu and Ma, Yiyang and Liu, Xingchao and Pan, Zizheng and Liu, Wen and Xie, Zhenda and Yu, Xingkai and Ruan, Chong and others},
|
||||
journal={arXiv preprint arXiv:2410.13848},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@misc{ma2024janusflow,
|
||||
title={JanusFlow: Harmonizing Autoregression and Rectified Flow for Unified Multimodal Understanding and Generation},
|
||||
author={Yiyang Ma and Xingchao Liu and Xiaokang Chen and Wen Liu and Chengyue Wu and Zhiyu Wu and Zizheng Pan and Zhenda Xie and Haowei Zhang and Xingkai yu and Liang Zhao and Yisong Wang and Jiaying Liu and Chong Ruan},
|
||||
journal={arXiv preprint arXiv:2411.07975},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
## 7. Contact
|
||||
## 6. Contact
|
||||
|
||||
If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).
|
||||
|
248
demo/app_janusflow.py
Normal file
248
demo/app_janusflow.py
Normal file
@ -0,0 +1,248 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
|
||||
from PIL import Image
|
||||
from diffusers.models import AutoencoderKL
|
||||
import numpy as np
|
||||
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# Load model and processor
|
||||
model_path = "deepseek-ai/JanusFlow-1.3B"
|
||||
model_path = "/weka-jd/prod/jupyter/liuxingchao/notebooks/janus/final_converted_ckpt_ema"
|
||||
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
|
||||
vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
|
||||
vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
|
||||
|
||||
# remember to use bfloat16 dtype, this vae doesn't work with fp16
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
|
||||
vae = vae.to(torch.bfloat16).to(cuda_device).eval()
|
||||
|
||||
# Multimodal Understanding function
|
||||
@torch.inference_mode()
|
||||
# Multimodal Understanding function
|
||||
def multimodal_understanding(image, question, seed, top_p, temperature):
|
||||
# Clear CUDA cache before generating
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# set seed
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "User",
|
||||
"content": f"<image_placeholder>\n{question}",
|
||||
"images": [image],
|
||||
},
|
||||
{"role": "Assistant", "content": ""},
|
||||
]
|
||||
|
||||
pil_images = [Image.fromarray(image)]
|
||||
prepare_inputs = vl_chat_processor(
|
||||
conversations=conversation, images=pil_images, force_batchify=True
|
||||
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
|
||||
|
||||
|
||||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
||||
|
||||
outputs = vl_gpt.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=prepare_inputs.attention_mask,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_new_tokens=512,
|
||||
do_sample=False if temperature == 0 else True,
|
||||
use_cache=True,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
input_ids,
|
||||
cfg_weight: float = 2.0,
|
||||
num_inference_steps: int = 30
|
||||
):
|
||||
# we generate 5 images at a time, *2 for CFG
|
||||
tokens = torch.stack([input_ids] * 10).cuda()
|
||||
tokens[5:, 1:] = vl_chat_processor.pad_id
|
||||
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
||||
print(inputs_embeds.shape)
|
||||
|
||||
# we remove the last <bog> token and replace it with t_emb later
|
||||
inputs_embeds = inputs_embeds[:, :-1, :]
|
||||
|
||||
# generate with rectified flow ode
|
||||
# step 1: encode with vision_gen_enc
|
||||
z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
|
||||
|
||||
dt = 1.0 / num_inference_steps
|
||||
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
|
||||
|
||||
# step 2: run ode
|
||||
attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
|
||||
attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
|
||||
attention_mask = attention_mask.int()
|
||||
for step in range(num_inference_steps):
|
||||
# prepare inputs for the llm
|
||||
z_input = torch.cat([z, z], dim=0) # for cfg
|
||||
t = step / num_inference_steps * 1000.
|
||||
t = torch.tensor([t] * z_input.shape[0]).to(dt)
|
||||
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
|
||||
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
|
||||
z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
|
||||
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
|
||||
llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
|
||||
|
||||
# input to the llm
|
||||
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
|
||||
if step == 0:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=None)
|
||||
past_key_values = []
|
||||
for kv_cache in past_key_values:
|
||||
k, v = kv_cache[0], kv_cache[1]
|
||||
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
|
||||
past_key_values = tuple(past_key_values)
|
||||
else:
|
||||
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
||||
use_cache=True,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# transform hidden_states back to v
|
||||
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
|
||||
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
|
||||
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
|
||||
v_cond, v_uncond = torch.chunk(v, 2)
|
||||
v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
|
||||
z = z + dt * v
|
||||
|
||||
# step 3: decode with vision_gen_dec and sdxl vae
|
||||
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
|
||||
|
||||
images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
|
||||
images = ((images+1) / 2. * 255).astype(np.uint8)
|
||||
|
||||
return images
|
||||
|
||||
def unpack(dec, width, height, parallel_size=5):
|
||||
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
||||
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
||||
|
||||
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
|
||||
visual_img[:, :, :] = dec
|
||||
|
||||
return visual_img
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_image(prompt,
|
||||
seed=None,
|
||||
guidance=5,
|
||||
num_inference_steps=30):
|
||||
# Clear CUDA cache and avoid tracking gradients
|
||||
torch.cuda.empty_cache()
|
||||
# Set the seed for reproducible results
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
with torch.no_grad():
|
||||
messages = [{'role': 'User', 'content': prompt},
|
||||
{'role': 'Assistant', 'content': ''}]
|
||||
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
|
||||
sft_format=vl_chat_processor.sft_format,
|
||||
system_prompt='')
|
||||
text = text + vl_chat_processor.image_start_tag
|
||||
input_ids = torch.LongTensor(tokenizer.encode(text))
|
||||
images = generate(input_ids,
|
||||
cfg_weight=guidance,
|
||||
num_inference_steps=num_inference_steps)
|
||||
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
|
||||
|
||||
|
||||
|
||||
# Gradio interface
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(value="# Multimodal Understanding")
|
||||
# with gr.Row():
|
||||
with gr.Row():
|
||||
image_input = gr.Image()
|
||||
with gr.Column():
|
||||
question_input = gr.Textbox(label="Question")
|
||||
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
|
||||
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
|
||||
|
||||
understanding_button = gr.Button("Chat")
|
||||
understanding_output = gr.Textbox(label="Response")
|
||||
|
||||
examples_inpainting = gr.Examples(
|
||||
label="Multimodal Understanding examples",
|
||||
examples=[
|
||||
[
|
||||
"explain this meme",
|
||||
"./images/doge.png",
|
||||
],
|
||||
[
|
||||
"Convert the formula into latex code.",
|
||||
"./images/equation.png",
|
||||
],
|
||||
],
|
||||
inputs=[question_input, image_input],
|
||||
)
|
||||
|
||||
|
||||
gr.Markdown(value="# Text-to-Image Generation")
|
||||
|
||||
|
||||
|
||||
with gr.Row():
|
||||
cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
|
||||
step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
|
||||
|
||||
prompt_input = gr.Textbox(label="Prompt")
|
||||
seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
|
||||
|
||||
generation_button = gr.Button("Generate Images")
|
||||
|
||||
image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
|
||||
|
||||
examples_t2i = gr.Examples(
|
||||
label="Text to image generation examples.",
|
||||
examples=[
|
||||
"Master shifu racoon wearing drip attire as a street gangster.",
|
||||
"A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
|
||||
"The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
|
||||
],
|
||||
inputs=prompt_input,
|
||||
)
|
||||
|
||||
understanding_button.click(
|
||||
multimodal_understanding,
|
||||
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
|
||||
outputs=understanding_output
|
||||
)
|
||||
|
||||
generation_button.click(
|
||||
fn=generate_image,
|
||||
inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
|
||||
outputs=image_output
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
31
janus/janusflow/__init__.py
Normal file
31
janus/janusflow/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
# check if python version is above 3.10
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
print("Python version is above 3.10, patching the collections module.")
|
||||
# Monkey patch collections
|
||||
import collections
|
||||
import collections.abc
|
||||
|
||||
for type_name in collections.abc.__all__:
|
||||
setattr(collections, type_name, getattr(collections.abc, type_name))
|
28
janus/janusflow/models/__init__.py
Normal file
28
janus/janusflow/models/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from .image_processing_vlm import VLMImageProcessor
|
||||
from .modeling_vlm import MultiModalityCausalLM
|
||||
from .processing_vlm import VLChatProcessor
|
||||
|
||||
__all__ = [
|
||||
"VLMImageProcessor",
|
||||
"VLChatProcessor",
|
||||
"MultiModalityCausalLM",
|
||||
]
|
122
janus/janusflow/models/clip_encoder.py
Normal file
122
janus/janusflow/models/clip_encoder.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms
|
||||
from einops import rearrange
|
||||
|
||||
from janus.janusflow.models.siglip_vit import create_siglip_vit
|
||||
|
||||
|
||||
class CLIPVisionTower(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "siglip_large_patch16_384",
|
||||
image_size: Union[Tuple[int, int], int] = 336,
|
||||
select_feature: str = "patch",
|
||||
select_layer: int = -2,
|
||||
select_layers: list = None,
|
||||
ckpt_path: str = "",
|
||||
pixel_mean: Optional[List[float]] = None,
|
||||
pixel_std: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model_name = model_name
|
||||
self.select_feature = select_feature
|
||||
self.select_layer = select_layer
|
||||
self.select_layers = select_layers
|
||||
|
||||
vision_tower_params = {
|
||||
"model_name": model_name,
|
||||
"image_size": image_size,
|
||||
"ckpt_path": ckpt_path,
|
||||
"select_layer": select_layer,
|
||||
}
|
||||
vision_tower_params.update(kwargs)
|
||||
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
|
||||
vision_tower_params
|
||||
)
|
||||
|
||||
if pixel_mean is not None and pixel_std is not None:
|
||||
image_norm = torchvision.transforms.Normalize(
|
||||
mean=pixel_mean, std=pixel_std
|
||||
)
|
||||
else:
|
||||
image_norm = None
|
||||
|
||||
self.image_norm = image_norm
|
||||
|
||||
def build_vision_tower(self, vision_tower_params):
|
||||
if self.model_name.startswith("siglip"):
|
||||
self.select_feature = "same"
|
||||
vision_tower = create_siglip_vit(**vision_tower_params)
|
||||
forward_kwargs = dict()
|
||||
|
||||
elif self.model_name.startswith("sam"):
|
||||
vision_tower = create_sam_vit(**vision_tower_params)
|
||||
forward_kwargs = dict()
|
||||
|
||||
else: # huggingface
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
|
||||
forward_kwargs = dict(output_hidden_states=True)
|
||||
|
||||
return vision_tower, forward_kwargs
|
||||
|
||||
def feature_select(self, image_forward_outs):
|
||||
if isinstance(image_forward_outs, torch.Tensor):
|
||||
# the output has been the self.select_layer"s features
|
||||
image_features = image_forward_outs
|
||||
else:
|
||||
image_features = image_forward_outs.hidden_states[self.select_layer]
|
||||
|
||||
if self.select_feature == "patch":
|
||||
# if the output has cls_token
|
||||
image_features = image_features[:, 1:]
|
||||
elif self.select_feature == "cls_patch":
|
||||
image_features = image_features
|
||||
elif self.select_feature == "same":
|
||||
image_features = image_features
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||
return image_features
|
||||
|
||||
def forward(self, images):
|
||||
"""
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): [b, 3, H, W]
|
||||
|
||||
Returns:
|
||||
image_features (torch.Tensor): [b, n_patch, d]
|
||||
"""
|
||||
|
||||
if self.image_norm is not None:
|
||||
images = self.image_norm(images)
|
||||
|
||||
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
|
||||
image_features = self.feature_select(image_forward_outs)
|
||||
return image_features
|
208
janus/janusflow/models/image_processing_vlm.py
Normal file
208
janus/janusflow/models/image_processing_vlm.py
Normal file
@ -0,0 +1,208 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms.functional
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, PretrainedConfig
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
||||
from transformers.image_utils import to_numpy_array
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
|
||||
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
||||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
||||
|
||||
|
||||
def expand2square(pil_img, background_color):
|
||||
width, height = pil_img.size
|
||||
if width == height:
|
||||
return pil_img
|
||||
elif width > height:
|
||||
result = Image.new(pil_img.mode, (width, width), background_color)
|
||||
result.paste(pil_img, (0, (width - height) // 2))
|
||||
return result
|
||||
else:
|
||||
result = Image.new(pil_img.mode, (height, height), background_color)
|
||||
result.paste(pil_img, ((height - width) // 2, 0))
|
||||
return result
|
||||
|
||||
|
||||
class VLMImageProcessorConfig(PretrainedConfig):
|
||||
model_type = "deepseek_vlm"
|
||||
image_size: int
|
||||
min_size: int
|
||||
image_mean: Union[Tuple[float, float, float], List[float]]
|
||||
image_std: Union[Tuple[float, float, float], List[float]]
|
||||
rescale_factor: float
|
||||
do_normalize: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int,
|
||||
min_size: int = 14,
|
||||
image_mean: Union[Tuple[float, float, float], List[float]] = (
|
||||
0.48145466,
|
||||
0.4578275,
|
||||
0.40821073,
|
||||
),
|
||||
image_std: Union[Tuple[float, float, float], List[float]] = (
|
||||
0.26862954,
|
||||
0.26130258,
|
||||
0.27577711,
|
||||
),
|
||||
rescale_factor: float = 1.0 / 255.0,
|
||||
do_normalize: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_size = image_size
|
||||
self.min_size = min_size
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class VLMImageProcessor(BaseImageProcessor):
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int,
|
||||
min_size: int = 14,
|
||||
image_mean: Union[Tuple[float, float, float], List[float]] = (
|
||||
0.48145466,
|
||||
0.4578275,
|
||||
0.40821073,
|
||||
),
|
||||
image_std: Union[Tuple[float, float, float], List[float]] = (
|
||||
0.26862954,
|
||||
0.26130258,
|
||||
0.27577711,
|
||||
),
|
||||
rescale_factor: float = 1.0 / 255.0,
|
||||
do_normalize: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.rescale_factor = rescale_factor
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.min_size = min_size
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
if image_mean is None:
|
||||
self.background_color = (127, 127, 127)
|
||||
else:
|
||||
self.background_color = tuple([int(x * 255) for x in image_mean])
|
||||
|
||||
def resize(self, pil_img: Image) -> np.ndarray:
|
||||
"""
|
||||
|
||||
Args:
|
||||
pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
|
||||
|
||||
Returns:
|
||||
x (np.ndarray): [3, self.image_size, self.image_size]
|
||||
"""
|
||||
|
||||
width, height = pil_img.size
|
||||
max_size = max(width, height)
|
||||
|
||||
size = [
|
||||
max(int(height / max_size * self.image_size), self.min_size),
|
||||
max(int(width / max_size * self.image_size), self.min_size),
|
||||
]
|
||||
|
||||
if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
|
||||
print(f"orig size = {pil_img.size}, new size = {size}")
|
||||
raise ValueError("Invalid size!")
|
||||
|
||||
pil_img = torchvision.transforms.functional.resize(
|
||||
pil_img,
|
||||
size,
|
||||
interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
pil_img = expand2square(pil_img, self.background_color)
|
||||
x = to_numpy_array(pil_img)
|
||||
|
||||
# [H, W, 3] -> [3, H, W]
|
||||
x = np.transpose(x, (2, 0, 1))
|
||||
|
||||
return x
|
||||
|
||||
def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
|
||||
# resize and pad to [self.image_size, self.image_size]
|
||||
# then convert from [H, W, 3] to [3, H, W]
|
||||
images: List[np.ndarray] = [self.resize(image) for image in images]
|
||||
|
||||
# resacle from [0, 255] -> [0, 1]
|
||||
images = [
|
||||
self.rescale(
|
||||
image=image,
|
||||
scale=self.rescale_factor,
|
||||
input_data_format="channels_first",
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
# normalize
|
||||
if self.do_normalize:
|
||||
images = [
|
||||
self.normalize(
|
||||
image=image,
|
||||
mean=self.image_mean,
|
||||
std=self.image_std,
|
||||
input_data_format="channels_first",
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
@property
|
||||
def default_shape(self):
|
||||
return [3, self.image_size, self.image_size]
|
||||
|
||||
|
||||
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
image_processor = VLMImageProcessor(
|
||||
image_size=1024,
|
||||
image_mean=IMAGENET_INCEPTION_MEAN,
|
||||
image_std=IMAGENET_INCEPTION_STD,
|
||||
do_normalize=True,
|
||||
)
|
226
janus/janusflow/models/modeling_vlm.py
Normal file
226
janus/janusflow/models/modeling_vlm.py
Normal file
@ -0,0 +1,226 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from attrdict import AttrDict
|
||||
from einops import rearrange
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
PreTrainedModel,
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from janus.janusflow.models.clip_encoder import CLIPVisionTower
|
||||
from janus.janusflow.models.uvit import ShallowUViTEncoder, ShallowUViTDecoder
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def model_name_to_cls(cls_name):
|
||||
|
||||
if "CLIPVisionTower" in cls_name:
|
||||
cls = CLIPVisionTower
|
||||
elif "ShallowUViTEncoder" in cls_name:
|
||||
cls = ShallowUViTEncoder
|
||||
elif "ShallowUViTDecoder" in cls_name:
|
||||
cls = ShallowUViTDecoder
|
||||
else:
|
||||
raise ValueError(f"class_name {cls_name} is invalid.")
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
class VisionUnderstandEncoderConfig(PretrainedConfig):
|
||||
model_type = "vision_und_enc"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class VisionGenerationEncoderConfig(PretrainedConfig):
|
||||
model_type = "vision_gen_enc"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class VisionGenerationDecoderConfig(PretrainedConfig):
|
||||
model_type = "vision_gen_dec"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class MultiModalityConfig(PretrainedConfig):
|
||||
model_type = "multi_modality"
|
||||
vision_und_enc_config: VisionUnderstandEncoderConfig
|
||||
language_config: LlamaConfig
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
vision_und_enc_config = kwargs.get("vision_und_enc_config", {})
|
||||
self.vision_und_enc_config = VisionUnderstandEncoderConfig(
|
||||
**vision_und_enc_config
|
||||
)
|
||||
|
||||
vision_gen_enc_config = kwargs.get("vision_gen_enc_config", {})
|
||||
self.vision_gen_enc_config = VisionGenerationEncoderConfig(
|
||||
**vision_gen_enc_config
|
||||
)
|
||||
|
||||
vision_gen_dec_config = kwargs.get("vision_gen_dec_config", {})
|
||||
self.vision_gen_dec_config = VisionGenerationDecoderConfig(
|
||||
**vision_gen_dec_config
|
||||
)
|
||||
|
||||
language_config = kwargs.get("language_config", {})
|
||||
if isinstance(language_config, LlamaConfig):
|
||||
self.language_config = language_config
|
||||
else:
|
||||
self.language_config = LlamaConfig(**language_config)
|
||||
|
||||
|
||||
class MultiModalityPreTrainedModel(PreTrainedModel):
|
||||
config_class = MultiModalityConfig
|
||||
base_model_prefix = "multi_modality"
|
||||
_no_split_modules = []
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
|
||||
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
|
||||
def __init__(self, config: MultiModalityConfig):
|
||||
super().__init__(config)
|
||||
|
||||
# vision understanding encoder
|
||||
vision_und_enc_config = config.vision_und_enc_config
|
||||
vision_und_enc_cls = model_name_to_cls(vision_und_enc_config.cls)
|
||||
self.vision_und_enc_model = vision_und_enc_cls(**vision_und_enc_config.params)
|
||||
|
||||
# vision understanding aligner
|
||||
self.vision_und_enc_aligner = nn.Linear(1024, 2048, bias=True)
|
||||
|
||||
# begin of understanding embedding
|
||||
self.beg_of_und_embed = nn.Parameter(torch.zeros(1, 2048))
|
||||
|
||||
# vision generation encoder
|
||||
vision_gen_enc_config = config.vision_gen_enc_config
|
||||
vision_gen_enc_cls = model_name_to_cls(vision_gen_enc_config.cls)
|
||||
self.vision_gen_enc_model = vision_gen_enc_cls(**vision_gen_enc_config.params)
|
||||
|
||||
# vision generation encoder aligner
|
||||
self.vision_gen_enc_aligner = nn.Linear(768, 2048, bias=True)
|
||||
|
||||
# vision generation decoder
|
||||
vision_gen_dec_config = config.vision_gen_dec_config
|
||||
vision_gen_dec_cls = model_name_to_cls(vision_gen_dec_config.cls)
|
||||
self.vision_gen_dec_model = vision_gen_dec_cls(**vision_gen_dec_config.params)
|
||||
|
||||
# language model
|
||||
language_config = config.language_config
|
||||
self.language_model = LlamaForCausalLM(language_config)
|
||||
|
||||
# vision generation decoder aligner
|
||||
self.vision_gen_dec_aligner_norm = LlamaRMSNorm(
|
||||
2048, eps=language_config.rms_norm_eps
|
||||
)
|
||||
self.vision_gen_dec_aligner = nn.Linear(2048, 768, bias=True)
|
||||
|
||||
def prepare_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
images_seq_mask: torch.LongTensor,
|
||||
images_emb_mask: torch.LongTensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
input_ids (torch.LongTensor): [b, T]
|
||||
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
||||
images_seq_mask (torch.BoolTensor): [b, T]
|
||||
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
||||
|
||||
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
||||
|
||||
Returns:
|
||||
input_embeds (torch.Tensor): [b, T, D]
|
||||
"""
|
||||
|
||||
bs, n = pixel_values.shape[0:2]
|
||||
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
||||
# [b x n, T2, D]
|
||||
images_embeds = self.vision_und_enc_model(images)
|
||||
images_embeds = self.vision_und_enc_aligner(images_embeds)
|
||||
# print(images_embeds.shape, self.beg_of_und_embed.shape, images_seq_mask.shape, input_ids.shape)
|
||||
beg_of_und_embed = self.beg_of_und_embed[0].detach().clone()
|
||||
images_embeds = torch.cat(
|
||||
[
|
||||
beg_of_und_embed.view(1, 1, -1).repeat(images_embeds.shape[0], 1, 1),
|
||||
images_embeds,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
# [b x n, T2, D] -> [b, n x T2, D]
|
||||
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
||||
# [b, n, T2] -> [b, n x T2]
|
||||
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
||||
|
||||
# [b, T, D]
|
||||
input_ids[input_ids < 0] = 0 # ignore the image embeddings
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
# replace with the image embeddings
|
||||
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
AutoConfig.register("vision_und_enc", VisionUnderstandEncoderConfig)
|
||||
AutoConfig.register("vision_gen_enc", VisionGenerationEncoderConfig)
|
||||
AutoConfig.register("vision_gen_dec", VisionGenerationDecoderConfig)
|
||||
AutoConfig.register("multi_modality", MultiModalityConfig)
|
||||
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
455
janus/janusflow/models/processing_vlm.py
Normal file
455
janus/janusflow/models/processing_vlm.py
Normal file
@ -0,0 +1,455 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import LlamaTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from janus.janusflow.models.image_processing_vlm import VLMImageProcessor
|
||||
from janus.utils.conversation import get_conv_template
|
||||
|
||||
|
||||
class DictOutput(object):
|
||||
def keys(self):
|
||||
return self.__dict__.keys()
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__dict__[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__[key] = value
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLChatProcessorOutput(DictOutput):
|
||||
sft_format: str
|
||||
input_ids: torch.Tensor
|
||||
pixel_values: torch.Tensor
|
||||
num_und_image_tokens: torch.IntTensor
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_ids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedVLChatProcessorOutput(DictOutput):
|
||||
sft_format: List[str]
|
||||
input_ids: torch.Tensor
|
||||
pixel_values: torch.Tensor
|
||||
attention_mask: torch.Tensor
|
||||
images_seq_mask: torch.BoolTensor
|
||||
images_emb_mask: torch.BoolTensor
|
||||
|
||||
def to(self, device, dtype=torch.bfloat16):
|
||||
self.input_ids = self.input_ids.to(device)
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
self.images_seq_mask = self.images_seq_mask.to(device)
|
||||
self.images_emb_mask = self.images_emb_mask.to(device)
|
||||
self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
class VLChatProcessor(ProcessorMixin):
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
||||
system_prompt = (
|
||||
"You are a helpful language and vision assistant. "
|
||||
"You are able to understand the visual content that the user provides, "
|
||||
"and assist the user with a variety of tasks using natural language."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor: VLMImageProcessor,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
image_tag: str = "<image_placeholder>",
|
||||
image_start_tag: str = "<begin_of_image>",
|
||||
image_end_tag: str = "<end_of_image>",
|
||||
image_gen_tag: str = "<|begin▁of▁generation|>",
|
||||
num_image_tokens: int = 576,
|
||||
add_special_token: bool = False,
|
||||
sft_format: str = "deepseek",
|
||||
mask_prompt: bool = True,
|
||||
ignore_id: int = -100,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_processor = image_processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
image_id = self.tokenizer.vocab.get(image_tag)
|
||||
if image_id is None:
|
||||
special_tokens = [image_tag]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Add image tag = {image_tag} to the tokenizer")
|
||||
|
||||
image_gen_id = self.tokenizer.vocab.get(image_gen_tag)
|
||||
if image_gen_id is None:
|
||||
special_tokens = [image_gen_tag]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Add generation tag = {image_gen_tag} to the tokenizer")
|
||||
|
||||
assert image_start_tag is not None and image_end_tag is not None
|
||||
boi_id = self.tokenizer.vocab.get(image_start_tag)
|
||||
eoi_id = self.tokenizer.vocab.get(image_end_tag)
|
||||
if boi_id is None:
|
||||
special_tokens = [image_start_tag]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Add boi tag = {image_start_tag} to the tokenizer")
|
||||
if eoi_id is None:
|
||||
special_tokens = [image_end_tag]
|
||||
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
||||
self.tokenizer.add_special_tokens(special_tokens_dict)
|
||||
print(f"Add eoi tag = {image_end_tag} to the tokenizer")
|
||||
|
||||
self.image_tag = image_tag
|
||||
self.image_gen_tag = image_gen_tag
|
||||
self.image_start_tag = image_start_tag
|
||||
self.image_end_tag = image_end_tag
|
||||
|
||||
self.num_image_tokens = num_image_tokens
|
||||
self.add_special_token = add_special_token
|
||||
self.sft_format = sft_format
|
||||
self.mask_prompt = mask_prompt
|
||||
self.ignore_id = ignore_id
|
||||
self.tokenizer.pad_token_id = self.tokenizer.vocab.get("<|▁pad▁|>")
|
||||
|
||||
super().__init__(
|
||||
image_processor,
|
||||
tokenizer,
|
||||
image_tag,
|
||||
num_image_tokens,
|
||||
add_special_token,
|
||||
sft_format,
|
||||
mask_prompt,
|
||||
ignore_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def new_chat_template(self):
|
||||
conv = get_conv_template(self.sft_format)
|
||||
conv.set_system_message(self.system_prompt)
|
||||
return conv
|
||||
|
||||
def apply_sft_template_for_multi_turn_prompts(
|
||||
self,
|
||||
conversations: List[Dict[str, str]],
|
||||
sft_format: str = "deepseek",
|
||||
system_prompt: str = "",
|
||||
):
|
||||
"""
|
||||
Applies the SFT template to conversation.
|
||||
|
||||
An example of conversation:
|
||||
conversation = [
|
||||
{
|
||||
"role": "User",
|
||||
"content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
|
||||
"images": [
|
||||
"./multi-images/attribute_comparison_1.png",
|
||||
"./multi-images/attribute_comparison_2.png"
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "Assistant",
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
|
||||
Args:
|
||||
conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
|
||||
sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
|
||||
system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
|
||||
|
||||
Returns:
|
||||
sft_prompt (str): The formatted text.
|
||||
"""
|
||||
|
||||
conv = get_conv_template(sft_format)
|
||||
conv.set_system_message(system_prompt)
|
||||
for message in conversations:
|
||||
conv.append_message(message["role"], message["content"].strip())
|
||||
sft_prompt = conv.get_prompt().strip()
|
||||
|
||||
return sft_prompt
|
||||
|
||||
@property
|
||||
def image_token(self):
|
||||
return self.image_tag
|
||||
|
||||
@property
|
||||
def image_id(self):
|
||||
image_id = self.tokenizer.vocab.get(self.image_tag)
|
||||
return image_id
|
||||
|
||||
@property
|
||||
def image_start_id(self):
|
||||
image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
|
||||
return image_start_id
|
||||
|
||||
@property
|
||||
def image_end_id(self):
|
||||
image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
|
||||
return image_end_id
|
||||
|
||||
@property
|
||||
def image_start_token(self):
|
||||
return self.image_start_tag
|
||||
|
||||
@property
|
||||
def image_end_token(self):
|
||||
return self.image_end_tag
|
||||
|
||||
@property
|
||||
def pad_id(self):
|
||||
pad_id = self.tokenizer.pad_token_id
|
||||
if pad_id is None:
|
||||
pad_id = self.tokenizer.eos_token_id
|
||||
|
||||
return pad_id
|
||||
|
||||
@property
|
||||
def image_gen_id(self):
|
||||
image_gen_id = self.tokenizer.vocab.get(self.image_gen_tag)
|
||||
return image_gen_id
|
||||
|
||||
def add_image_token(
|
||||
self,
|
||||
image_indices: List[int],
|
||||
input_ids: torch.LongTensor,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
image_indices (List[int]): [index_0, index_1, ..., index_j]
|
||||
input_ids (torch.LongTensor): [N]
|
||||
|
||||
Returns:
|
||||
input_ids (torch.LongTensor): [N + image tokens]
|
||||
num_image_tokens (torch.IntTensor): [n_images]
|
||||
"""
|
||||
|
||||
input_slices = []
|
||||
|
||||
start = 0
|
||||
for index in image_indices:
|
||||
if self.add_special_token:
|
||||
end = index + 1
|
||||
else:
|
||||
end = index
|
||||
|
||||
# original text tokens
|
||||
input_slices.append(input_ids[start:end])
|
||||
|
||||
# add boi, image tokens, eoi and set the mask as False
|
||||
input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
|
||||
input_slices.append(
|
||||
self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
|
||||
)
|
||||
input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
|
||||
start = index + 1
|
||||
|
||||
# the left part
|
||||
input_slices.append(input_ids[start:])
|
||||
|
||||
# concat all slices
|
||||
input_ids = torch.cat(input_slices, dim=0)
|
||||
num_image_tokens = torch.IntTensor(
|
||||
[self.num_image_tokens + 1] * len(image_indices)
|
||||
)
|
||||
# we add 1 to fit generation
|
||||
|
||||
return input_ids, num_image_tokens
|
||||
|
||||
def process_one(
|
||||
self,
|
||||
prompt: str = None,
|
||||
conversations: List[Dict[str, str]] = None,
|
||||
images: List[Image] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt (str): the formatted prompt;
|
||||
conversations (List[Dict]): conversations with a list of messages;
|
||||
images (List[ImageType]): the list of images;
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
outputs (BaseProcessorOutput): the output of the processor,
|
||||
- input_ids (torch.LongTensor): [N + image tokens]
|
||||
- target_ids (torch.LongTensor): [N + image tokens]
|
||||
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||
- image_id (int): the id of the image token
|
||||
- num_image_tokens (List[int]): the number of image tokens
|
||||
"""
|
||||
|
||||
assert (
|
||||
prompt is None or conversations is None
|
||||
), "prompt and conversations cannot be used at the same time."
|
||||
|
||||
if prompt is None:
|
||||
# apply sft format
|
||||
sft_format = self.apply_sft_template_for_multi_turn_prompts(
|
||||
conversations=conversations,
|
||||
sft_format=self.sft_format,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
else:
|
||||
sft_format = prompt
|
||||
|
||||
# tokenize
|
||||
input_ids = self.tokenizer.encode(sft_format)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
|
||||
# add image tokens to the input_ids
|
||||
image_token_mask: torch.BoolTensor = input_ids == self.image_id
|
||||
image_indices = image_token_mask.nonzero()
|
||||
|
||||
input_ids, num_und_image_tokens = self.add_image_token(
|
||||
image_indices=image_indices,
|
||||
input_ids=input_ids,
|
||||
)
|
||||
|
||||
# load images
|
||||
images_outputs = self.image_processor(images, return_tensors="pt")
|
||||
|
||||
prepare = VLChatProcessorOutput(
|
||||
sft_format=sft_format,
|
||||
input_ids=input_ids,
|
||||
pixel_values=images_outputs.pixel_values,
|
||||
num_und_image_tokens=num_und_image_tokens,
|
||||
)
|
||||
|
||||
return prepare
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
prompt: str = None,
|
||||
conversations: List[Dict[str, str]] = None,
|
||||
images: List[Image] = None,
|
||||
force_batchify: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt (str): the formatted prompt;
|
||||
conversations (List[Dict]): conversations with a list of messages;
|
||||
images (List[ImageType]): the list of images;
|
||||
force_batchify (bool): force batchify the inputs;
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
outputs (BaseProcessorOutput): the output of the processor,
|
||||
- input_ids (torch.LongTensor): [N + image tokens]
|
||||
- images (torch.FloatTensor): [n_images, 3, H, W]
|
||||
- image_id (int): the id of the image token
|
||||
- num_image_tokens (List[int]): the number of image tokens
|
||||
"""
|
||||
|
||||
prepare = self.process_one(
|
||||
prompt=prompt, conversations=conversations, images=images
|
||||
)
|
||||
|
||||
if force_batchify:
|
||||
prepare = self.batchify([prepare])
|
||||
|
||||
return prepare
|
||||
|
||||
def batchify(
|
||||
self, prepare_list: List[VLChatProcessorOutput]
|
||||
) -> BatchedVLChatProcessorOutput:
|
||||
"""
|
||||
Preprocesses the inputs for multimodal inference.
|
||||
|
||||
Args:
|
||||
prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
|
||||
|
||||
Returns:
|
||||
BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
|
||||
"""
|
||||
|
||||
batch_size = len(prepare_list)
|
||||
sft_format = []
|
||||
n_images = []
|
||||
seq_lens = []
|
||||
for prepare in prepare_list:
|
||||
# we only fill the images for understanding tasks into the mask
|
||||
n_images.append(len(prepare.num_und_image_tokens))
|
||||
seq_lens.append(len(prepare))
|
||||
|
||||
input_token_max_len = max(seq_lens)
|
||||
max_n_images = max(1, max(n_images))
|
||||
|
||||
batched_input_ids = torch.full(
|
||||
(batch_size, input_token_max_len), self.pad_id
|
||||
).long() # FIXME
|
||||
batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
|
||||
batched_pixel_values = torch.zeros(
|
||||
(batch_size, max_n_images, *self.image_processor.default_shape)
|
||||
).float()
|
||||
batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
|
||||
batched_images_emb_mask = torch.zeros(
|
||||
(
|
||||
batch_size,
|
||||
max_n_images,
|
||||
self.num_image_tokens + 1,
|
||||
) # add 1 to account for <image_beg>
|
||||
).bool()
|
||||
|
||||
for i, prepare in enumerate(prepare_list):
|
||||
input_ids = prepare.input_ids
|
||||
seq_len = len(prepare)
|
||||
n_image = len(prepare.num_und_image_tokens)
|
||||
# left-padding
|
||||
batched_attention_mask[i, -seq_len:] = 1
|
||||
batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
|
||||
batched_images_seq_mask[i, -seq_len:] = (input_ids == self.image_id) | (
|
||||
input_ids == self.image_start_id
|
||||
)
|
||||
|
||||
if n_image > 0:
|
||||
batched_pixel_values[i, :n_image] = prepare.pixel_values
|
||||
for j, n_image_tokens in enumerate(prepare.num_und_image_tokens):
|
||||
batched_images_emb_mask[i, j, :n_image_tokens] = True
|
||||
|
||||
sft_format.append(prepare.sft_format)
|
||||
|
||||
batched_prepares = BatchedVLChatProcessorOutput(
|
||||
input_ids=batched_input_ids,
|
||||
attention_mask=batched_attention_mask,
|
||||
pixel_values=batched_pixel_values,
|
||||
images_seq_mask=batched_images_seq_mask,
|
||||
images_emb_mask=batched_images_emb_mask,
|
||||
sft_format=sft_format,
|
||||
)
|
||||
|
||||
return batched_prepares
|
691
janus/janusflow/models/siglip_vit.py
Normal file
691
janus/janusflow/models/siglip_vit.py
Normal file
@ -0,0 +1,691 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Final,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.layers import (
|
||||
AttentionPoolLatent,
|
||||
DropPath,
|
||||
LayerType,
|
||||
Mlp,
|
||||
PatchDropout,
|
||||
PatchEmbed,
|
||||
resample_abs_pos_embed,
|
||||
)
|
||||
from timm.models._manipulate import checkpoint_seq, named_apply
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std) # noqa: E741
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
||||
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
|
||||
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
|
||||
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
|
||||
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
std: the standard deviation of the normal distribution
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.trunc_normal_(w)
|
||||
"""
|
||||
|
||||
with torch.no_grad():
|
||||
dtype = tensor.dtype
|
||||
tensor_fp32 = tensor.float()
|
||||
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
|
||||
tensor_dtype = tensor_fp32.to(dtype=dtype)
|
||||
tensor.copy_(tensor_dtype)
|
||||
|
||||
|
||||
def init_weights(self):
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
||||
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
||||
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
||||
if isinstance(module, nn.Linear):
|
||||
trunc_normal_(module.weight, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif hasattr(module, "init_weights"):
|
||||
module.init_weights()
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
# self.fused_attn = use_fused_attn()
|
||||
self.fused_attn = True
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
init_values: float = 1e-5,
|
||||
inplace: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
proj_drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
init_values: Optional[float] = None,
|
||||
drop_path: float = 0.0,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
mlp_layer: nn.Module = Mlp,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.ls1 = (
|
||||
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = mlp_layer(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.ls2 = (
|
||||
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
||||
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer
|
||||
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
|
||||
dynamic_img_size: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: Literal["", "avg", "token", "map"] = "token",
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
init_values: Optional[float] = None,
|
||||
class_token: bool = True,
|
||||
no_embed_class: bool = False,
|
||||
reg_tokens: int = 0,
|
||||
pre_norm: bool = False,
|
||||
fc_norm: Optional[bool] = None,
|
||||
dynamic_img_size: bool = False,
|
||||
dynamic_img_pad: bool = False,
|
||||
drop_rate: float = 0.0,
|
||||
pos_drop_rate: float = 0.0,
|
||||
patch_drop_rate: float = 0.0,
|
||||
proj_drop_rate: float = 0.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
||||
embed_layer: Callable = PatchEmbed,
|
||||
norm_layer: Optional[LayerType] = None,
|
||||
act_layer: Optional[LayerType] = None,
|
||||
block_fn: Type[nn.Module] = Block,
|
||||
mlp_layer: Type[nn.Module] = Mlp,
|
||||
ignore_head: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of image input channels.
|
||||
num_classes: Mumber of classes for classification head.
|
||||
global_pool: Type of global pooling for final sequence (default: 'token').
|
||||
embed_dim: Transformer embedding dimension.
|
||||
depth: Depth of transformer.
|
||||
num_heads: Number of attention heads.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: Enable bias for qkv projections if True.
|
||||
init_values: Layer-scale init values (layer-scale enabled if not None).
|
||||
class_token: Use class token.
|
||||
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
||||
reg_tokens: Number of register tokens.
|
||||
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
||||
drop_rate: Head dropout rate.
|
||||
pos_drop_rate: Position embedding dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
weight_init: Weight initialization scheme.
|
||||
embed_layer: Patch embedding layer.
|
||||
norm_layer: Normalization layer.
|
||||
act_layer: MLP activation layer.
|
||||
block_fn: Transformer block layer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool in ("", "avg", "token", "map")
|
||||
assert class_token or global_pool != "token"
|
||||
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
||||
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
||||
# act_layer = get_act_layer(act_layer) or nn.GELU
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
act_layer = nn.GELU
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = (
|
||||
embed_dim # num_features for consistency with other models
|
||||
)
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.num_prefix_tokens += reg_tokens
|
||||
self.num_reg_tokens = reg_tokens
|
||||
self.has_class_token = class_token
|
||||
self.no_embed_class = (
|
||||
no_embed_class # don't embed prefix positions (includes reg)
|
||||
)
|
||||
self.dynamic_img_size = dynamic_img_size
|
||||
self.grad_checkpointing = False
|
||||
self.ignore_head = ignore_head
|
||||
|
||||
embed_args = {}
|
||||
if dynamic_img_size:
|
||||
# flatten deferred until after pos embed
|
||||
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
dynamic_img_pad=dynamic_img_pad,
|
||||
**embed_args,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = (
|
||||
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
)
|
||||
self.reg_token = (
|
||||
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
||||
)
|
||||
embed_len = (
|
||||
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
)
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
if patch_drop_rate > 0:
|
||||
self.patch_drop = PatchDropout(
|
||||
patch_drop_rate,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
)
|
||||
else:
|
||||
self.patch_drop = nn.Identity()
|
||||
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.blocks = nn.Sequential(
|
||||
*[
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
init_values=init_values,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
mlp_layer=mlp_layer,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
||||
|
||||
# Classifier Head
|
||||
if global_pool == "map":
|
||||
AttentionPoolLatent.init_weights = init_weights
|
||||
self.attn_pool = AttentionPoolLatent(
|
||||
self.embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
self.attn_pool = None
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = (
|
||||
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
if weight_init != "skip":
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
||||
assert mode in ("jax", "jax_nlhb", "moco", "")
|
||||
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
if self.cls_token is not None:
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
named_apply(init_weights_vit_timm, self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self) -> Set:
|
||||
return {"pos_embed", "cls_token", "dist_token"}
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse: bool = False) -> Dict:
|
||||
return dict(
|
||||
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
||||
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable: bool = True) -> None:
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self) -> nn.Module:
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ("", "avg", "token", "map")
|
||||
if global_pool == "map" and self.attn_pool is None:
|
||||
assert (
|
||||
False
|
||||
), "Cannot currently add attention pooling in reset_classifier()."
|
||||
elif global_pool != "map " and self.attn_pool is not None:
|
||||
self.attn_pool = None # remove attention pooling
|
||||
self.global_pool = global_pool
|
||||
self.head = (
|
||||
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
)
|
||||
|
||||
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.dynamic_img_size:
|
||||
B, H, W, C = x.shape
|
||||
pos_embed = resample_abs_pos_embed(
|
||||
self.pos_embed,
|
||||
(H, W),
|
||||
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
||||
)
|
||||
x = x.view(B, -1, C)
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
|
||||
to_cat = []
|
||||
if self.cls_token is not None:
|
||||
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
||||
if self.reg_token is not None:
|
||||
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
||||
|
||||
if self.no_embed_class:
|
||||
# deit-3, updated JAX (big vision)
|
||||
# position embedding does not overlap with class token, add then concat
|
||||
x = x + pos_embed
|
||||
if to_cat:
|
||||
x = torch.cat(to_cat + [x], dim=1)
|
||||
else:
|
||||
# original timm, JAX, and deit vit impl
|
||||
# pos_embed has entry for class token, concat then add
|
||||
if to_cat:
|
||||
x = torch.cat(to_cat + [x], dim=1)
|
||||
x = x + pos_embed
|
||||
|
||||
return self.pos_drop(x)
|
||||
|
||||
def _intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1,
|
||||
) -> List[torch.Tensor]:
|
||||
outputs, num_blocks = [], len(self.blocks)
|
||||
take_indices = set(
|
||||
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
||||
)
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
def get_intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1,
|
||||
reshape: bool = False,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
||||
"""Intermediate layer accessor (NOTE: This is a WIP experiment).
|
||||
Inspired by DINO / DINOv2 interface
|
||||
"""
|
||||
# take last n blocks if n is an int, if in is a sequence, select by matching indices
|
||||
outputs = self._intermediate_layers(x, n)
|
||||
if norm:
|
||||
outputs = [self.norm(out) for out in outputs]
|
||||
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
|
||||
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
|
||||
|
||||
if reshape:
|
||||
grid_size = self.patch_embed.grid_size
|
||||
outputs = [
|
||||
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
for out in outputs
|
||||
]
|
||||
|
||||
if return_prefix_tokens:
|
||||
return tuple(zip(outputs, prefix_tokens))
|
||||
return tuple(outputs)
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
if self.attn_pool is not None:
|
||||
x = self.attn_pool(x)
|
||||
elif self.global_pool == "avg":
|
||||
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
||||
elif self.global_pool:
|
||||
x = x[:, 0] # class token
|
||||
x = self.fc_norm(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
if not self.ignore_head:
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class SigLIPVisionCfg:
|
||||
width: int = 1152
|
||||
layers: Union[Tuple[int, int, int, int], int] = 27
|
||||
heads: int = 16
|
||||
patch_size: int = 14
|
||||
image_size: Union[Tuple[int, int], int] = 336
|
||||
global_pool: str = "map"
|
||||
mlp_ratio: float = 3.7362
|
||||
class_token: bool = False
|
||||
num_classes: int = 0
|
||||
use_checkpoint: bool = False
|
||||
|
||||
|
||||
SigLIP_MODEL_CONFIG = {
|
||||
"siglip_so400m_patch14_384": {
|
||||
"image_size": 336,
|
||||
"patch_size": 14,
|
||||
"width": 1152,
|
||||
"layers": 27,
|
||||
"heads": 16,
|
||||
"mlp_ratio": 3.7362,
|
||||
"global_pool": "map",
|
||||
"use_checkpoint": False,
|
||||
},
|
||||
"siglip_so400m_patch14_224": {
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"width": 1152,
|
||||
"layers": 27,
|
||||
"heads": 16,
|
||||
"mlp_ratio": 3.7362,
|
||||
"global_pool": "map",
|
||||
"use_checkpoint": False,
|
||||
},
|
||||
"siglip_large_patch16_384": {
|
||||
"image_size": 384,
|
||||
"patch_size": 16,
|
||||
"width": 1024,
|
||||
"layers": 24,
|
||||
"heads": 16,
|
||||
"mlp_ratio": 4,
|
||||
"global_pool": "map",
|
||||
"use_checkpoint": False,
|
||||
},
|
||||
"siglip_large_patch16_256": {
|
||||
"image_size": 256,
|
||||
"patch_size": 16,
|
||||
"width": 1024,
|
||||
"layers": 24,
|
||||
"heads": 16,
|
||||
"mlp_ratio": 4,
|
||||
"global_pool": "map",
|
||||
"use_checkpoint": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_siglip_vit(
|
||||
model_name: str = "siglip_so400m_patch14_384",
|
||||
image_size: int = 384,
|
||||
select_layer: int = -1,
|
||||
ckpt_path: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
assert (
|
||||
model_name in SigLIP_MODEL_CONFIG.keys()
|
||||
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
||||
|
||||
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
||||
|
||||
if select_layer <= 0:
|
||||
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
||||
else:
|
||||
layers = min(vision_cfg.layers, select_layer)
|
||||
|
||||
model = VisionTransformer(
|
||||
img_size=image_size,
|
||||
patch_size=vision_cfg.patch_size,
|
||||
embed_dim=vision_cfg.width,
|
||||
depth=layers,
|
||||
num_heads=vision_cfg.heads,
|
||||
mlp_ratio=vision_cfg.mlp_ratio,
|
||||
class_token=vision_cfg.class_token,
|
||||
global_pool=vision_cfg.global_pool,
|
||||
ignore_head=kwargs.get("ignore_head", True),
|
||||
weight_init=kwargs.get("weight_init", "skip"),
|
||||
num_classes=0,
|
||||
)
|
||||
|
||||
if ckpt_path:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
||||
print(
|
||||
f"SigLIP-ViT restores from {ckpt_path},\n"
|
||||
f"\tincompatible_keys:', {incompatible_keys}."
|
||||
)
|
||||
|
||||
return model
|
714
janus/janusflow/models/uvit.py
Normal file
714
janus/janusflow/models/uvit.py
Normal file
@ -0,0 +1,714 @@
|
||||
# Copyright (c) 2023-2024 DeepSeek.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
# this software and associated documentation files (the "Software"), to deal in
|
||||
# the Software without restriction, including without limitation the rights to
|
||||
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
# the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
# modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torchvision
|
||||
import torchvision.utils
|
||||
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
|
||||
|
||||
|
||||
class ImageHead(nn.Module):
|
||||
|
||||
def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
cfg = (
|
||||
AttrDict(
|
||||
norm_type="layernorm",
|
||||
is_exp_norm=False,
|
||||
sequence_parallel=False,
|
||||
use_userbuffer=False,
|
||||
norm_eps=1e-5,
|
||||
norm_bias=True,
|
||||
gradient_accumulation_fusion=True,
|
||||
use_fp32_head_weight=False,
|
||||
)
|
||||
+ gpt_cfg
|
||||
)
|
||||
group = PG.tensor_parallel_group()
|
||||
assert cfg.norm_type in [
|
||||
"layernorm",
|
||||
"rmsnorm",
|
||||
], f"Norm type:{cfg.norm_type} not supported"
|
||||
if cfg.norm_type == "rmsnorm":
|
||||
self.norm = DropoutAddRMSNorm(
|
||||
cfg.n_embed,
|
||||
prenorm=False,
|
||||
eps=cfg.norm_eps,
|
||||
is_exp_norm=cfg.is_exp_norm,
|
||||
sequence_parallel=cfg.sequence_parallel,
|
||||
)
|
||||
else:
|
||||
self.norm = DropoutAddLayerNorm(
|
||||
cfg.n_embed,
|
||||
prenorm=False,
|
||||
eps=cfg.norm_eps,
|
||||
is_exp_norm=cfg.is_exp_norm,
|
||||
sequence_parallel=cfg.sequence_parallel,
|
||||
bias=cfg.norm_bias,
|
||||
)
|
||||
|
||||
multiple_of = 256
|
||||
if decoder_cfg.in_channels % multiple_of != 0:
|
||||
warnings.warn(
|
||||
f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
|
||||
)
|
||||
|
||||
dtype = default_dtype = torch.get_default_dtype()
|
||||
if cfg.use_fp32_head_weight:
|
||||
dtype = torch.float32
|
||||
print(
|
||||
"使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
torch.set_default_dtype(dtype)
|
||||
self.head = ColumnParallelLinear(
|
||||
cfg.n_embed,
|
||||
decoder_cfg.in_channels,
|
||||
bias=True,
|
||||
group=group,
|
||||
sequence_parallel=cfg.sequence_parallel,
|
||||
use_userbuffer=cfg.use_userbuffer,
|
||||
gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
|
||||
use_fp32_output=False,
|
||||
)
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
self.use_fp32_head_weight = cfg.use_fp32_head_weight
|
||||
|
||||
def forward(
|
||||
self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
|
||||
):
|
||||
residual = None
|
||||
if isinstance(input_args, tuple):
|
||||
x, residual = input_args
|
||||
else:
|
||||
x = input_args
|
||||
|
||||
x = self.norm(x, residual)
|
||||
|
||||
if self.use_fp32_head_weight:
|
||||
assert (
|
||||
self.head.weight.dtype == torch.float32
|
||||
), f"head.weight is {self.head.weight.dtype}"
|
||||
x = x.float()
|
||||
|
||||
if images_split_mask is None:
|
||||
logits = self.head(x)
|
||||
else:
|
||||
bs, n_images = images_split_mask.shape[:2]
|
||||
n_embed = x.shape[-1]
|
||||
|
||||
images_embed = torch.masked_select(
|
||||
x.unsqueeze(1), images_split_mask.unsqueeze(-1)
|
||||
)
|
||||
images_embed = images_embed.view((bs * n_images, -1, n_embed))
|
||||
logits = self.head(images_embed)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class GlobalResponseNorm(nn.Module):
|
||||
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
|
||||
return torch.addcmul(self.bias, (self.weight * nx + 1), x, value=1)
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
norm_type=None,
|
||||
eps=None,
|
||||
elementwise_affine=None,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
self.name = name
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(channels, eps)
|
||||
elif norm_type is None:
|
||||
self.norm = None
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type: {norm_type}")
|
||||
|
||||
if use_conv:
|
||||
conv = nn.Conv2d(
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
kernel_size: Optional[int] = None,
|
||||
padding=1,
|
||||
stride=2,
|
||||
norm_type=None,
|
||||
eps=None,
|
||||
elementwise_affine=None,
|
||||
bias=True,
|
||||
interpolate=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
self.interpolate = interpolate
|
||||
self.stride = stride
|
||||
|
||||
if norm_type == "ln_norm":
|
||||
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(channels, eps)
|
||||
elif norm_type is None:
|
||||
self.norm = None
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type: {norm_type}")
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
if kernel_size is None:
|
||||
kernel_size = 4
|
||||
conv = nn.ConvTranspose2d(
|
||||
channels,
|
||||
self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
elif use_conv:
|
||||
if kernel_size is None:
|
||||
kernel_size = 3
|
||||
conv = nn.Conv2d(
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_size: Optional[int] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.norm is not None:
|
||||
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if self.interpolate:
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(
|
||||
hidden_states, scale_factor=self.stride, mode="nearest"
|
||||
)
|
||||
else:
|
||||
hidden_states = F.interpolate(
|
||||
hidden_states, size=output_size, mode="nearest"
|
||||
)
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ConvNextBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
norm_eps,
|
||||
elementwise_affine,
|
||||
use_bias,
|
||||
hidden_dropout,
|
||||
hidden_size,
|
||||
res_ffn_factor: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.depthwise = nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=channels,
|
||||
bias=use_bias,
|
||||
)
|
||||
self.norm = RMSNorm(channels, norm_eps)
|
||||
self.channelwise_linear_1 = nn.Linear(
|
||||
channels, int(channels * res_ffn_factor), bias=use_bias
|
||||
)
|
||||
self.channelwise_act = nn.GELU()
|
||||
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
|
||||
self.channelwise_linear_2 = nn.Linear(
|
||||
int(channels * res_ffn_factor), channels, bias=use_bias
|
||||
)
|
||||
self.channelwise_dropout = nn.Dropout(hidden_dropout)
|
||||
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
|
||||
|
||||
def forward(self, x, cond_embeds):
|
||||
x_res = x
|
||||
|
||||
x = self.depthwise(x)
|
||||
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.norm(x)
|
||||
x = self.channelwise_linear_1(x)
|
||||
x = self.channelwise_act(x)
|
||||
x = self.channelwise_norm(x)
|
||||
x = self.channelwise_linear_2(x)
|
||||
x = self.channelwise_dropout(x)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
||||
x = x + x_res
|
||||
|
||||
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
|
||||
# x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
|
||||
x = torch.addcmul(
|
||||
shift[:, :, None, None], x, (1 + scale)[:, :, None, None], value=1
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Patchify(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
block_out_channels,
|
||||
patch_size,
|
||||
bias,
|
||||
elementwise_affine,
|
||||
eps,
|
||||
kernel_size=None,
|
||||
):
|
||||
super().__init__()
|
||||
if kernel_size is None:
|
||||
kernel_size = patch_size
|
||||
self.patch_conv = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = RMSNorm(block_out_channels, eps)
|
||||
|
||||
def forward(self, x):
|
||||
embeddings = self.patch_conv(x)
|
||||
embeddings = embeddings.permute(0, 2, 3, 1)
|
||||
embeddings = self.norm(embeddings)
|
||||
embeddings = embeddings.permute(0, 3, 1, 2)
|
||||
return embeddings
|
||||
|
||||
|
||||
class Unpatchify(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(in_channels, eps)
|
||||
self.unpatch_conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels * patch_size * patch_size,
|
||||
kernel_size=1,
|
||||
bias=bias,
|
||||
)
|
||||
self.pixel_shuffle = nn.PixelShuffle(patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
def forward(self, x):
|
||||
# [b, c, h, w]
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.unpatch_conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
return x
|
||||
|
||||
|
||||
class UVitBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
stride,
|
||||
hidden_size,
|
||||
hidden_dropout,
|
||||
elementwise_affine,
|
||||
norm_eps,
|
||||
use_bias,
|
||||
downsample: bool,
|
||||
upsample: bool,
|
||||
res_ffn_factor: int = 4,
|
||||
seq_len=None,
|
||||
concat_input=False,
|
||||
original_input_channels=None,
|
||||
use_zero=True,
|
||||
norm_type="RMS",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.res_blocks = nn.ModuleList()
|
||||
for i in range(num_res_blocks):
|
||||
conv_block = ConvNextBlock(
|
||||
channels,
|
||||
norm_eps,
|
||||
elementwise_affine,
|
||||
use_bias,
|
||||
hidden_dropout,
|
||||
hidden_size,
|
||||
res_ffn_factor=res_ffn_factor,
|
||||
)
|
||||
|
||||
self.res_blocks.append(conv_block)
|
||||
|
||||
if downsample:
|
||||
self.downsample = Downsample2D(
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
use_conv=True,
|
||||
name="Conv2d_0",
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
norm_type="rms_norm",
|
||||
eps=norm_eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
bias=use_bias,
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
if upsample:
|
||||
self.upsample = Upsample2D(
|
||||
channels=channels,
|
||||
out_channels=out_channels,
|
||||
use_conv_transpose=False,
|
||||
use_conv=True,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
name="conv",
|
||||
norm_type="rms_norm",
|
||||
eps=norm_eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
bias=use_bias,
|
||||
interpolate=True,
|
||||
)
|
||||
else:
|
||||
self.upsample = None
|
||||
|
||||
def forward(self, x, emb, recompute=False):
|
||||
for res_block in self.res_blocks:
|
||||
x = res_block(x, emb)
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.upsample is not None:
|
||||
x = self.upsample(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ShallowUViTEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channels=3,
|
||||
stride=4,
|
||||
kernel_size=7,
|
||||
padding=None,
|
||||
block_out_channels=(768,),
|
||||
layers_in_middle=2,
|
||||
hidden_size=2048,
|
||||
elementwise_affine=True,
|
||||
use_bias=True,
|
||||
norm_eps=1e-6,
|
||||
dropout=0.0,
|
||||
use_mid_block=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0
|
||||
)
|
||||
self.time_embed = TimestepEmbedding(
|
||||
block_out_channels[0], hidden_size, sample_proj_bias=use_bias
|
||||
)
|
||||
|
||||
if padding is None:
|
||||
padding = math.ceil(kernel_size - stride)
|
||||
self.in_conv = nn.Conv2d(
|
||||
in_channels=input_channels,
|
||||
out_channels=block_out_channels[0],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
if use_mid_block:
|
||||
self.mid_block = UVitBlock(
|
||||
block_out_channels[-1],
|
||||
block_out_channels[-1],
|
||||
num_res_blocks=layers_in_middle,
|
||||
hidden_size=hidden_size,
|
||||
hidden_dropout=dropout,
|
||||
elementwise_affine=elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
use_bias=use_bias,
|
||||
downsample=False,
|
||||
upsample=False,
|
||||
stride=1,
|
||||
res_ffn_factor=4,
|
||||
)
|
||||
else:
|
||||
self.mid_block = None
|
||||
|
||||
def get_num_extra_tensors(self):
|
||||
return 2
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
|
||||
bs = x.shape[0]
|
||||
dtype = x.dtype
|
||||
|
||||
t_emb = self.time_proj(timesteps.flatten()).view(bs, -1).to(dtype)
|
||||
t_emb = self.time_embed(t_emb)
|
||||
x_emb = self.in_conv(x)
|
||||
|
||||
if self.mid_block is not None:
|
||||
x_emb = self.mid_block(x_emb, t_emb)
|
||||
|
||||
hs = [x_emb]
|
||||
return x_emb, t_emb, hs
|
||||
|
||||
|
||||
class ShallowUViTDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=768,
|
||||
out_channels=3,
|
||||
block_out_channels: Tuple[int] = (768,),
|
||||
upsamples=2,
|
||||
layers_in_middle=2,
|
||||
hidden_size=2048,
|
||||
elementwise_affine=True,
|
||||
norm_eps=1e-6,
|
||||
use_bias=True,
|
||||
dropout=0.0,
|
||||
use_mid_block=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if use_mid_block:
|
||||
self.mid_block = UVitBlock(
|
||||
in_channels + block_out_channels[-1],
|
||||
block_out_channels[
|
||||
-1
|
||||
], # In fact, the parameter is not used because it has no effect when both downsample and upsample are set to false.
|
||||
num_res_blocks=layers_in_middle,
|
||||
hidden_size=hidden_size,
|
||||
hidden_dropout=dropout,
|
||||
elementwise_affine=elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
use_bias=use_bias,
|
||||
downsample=False,
|
||||
upsample=False,
|
||||
stride=1,
|
||||
res_ffn_factor=4,
|
||||
)
|
||||
else:
|
||||
self.mid_block = None
|
||||
self.out_convs = nn.ModuleList()
|
||||
for rank in range(upsamples):
|
||||
if rank == upsamples - 1:
|
||||
curr_out_channels = out_channels
|
||||
else:
|
||||
curr_out_channels = block_out_channels[-1]
|
||||
if rank == 0:
|
||||
curr_in_channels = block_out_channels[-1] + in_channels
|
||||
else:
|
||||
curr_in_channels = block_out_channels[-1]
|
||||
self.out_convs.append(
|
||||
Unpatchify(
|
||||
curr_in_channels,
|
||||
curr_out_channels,
|
||||
patch_size=2,
|
||||
bias=use_bias,
|
||||
elementwise_affine=elementwise_affine,
|
||||
eps=norm_eps,
|
||||
)
|
||||
)
|
||||
self.input_norm = RMSNorm(in_channels, norm_eps)
|
||||
|
||||
def forward(self, x, hs, t_emb):
|
||||
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.input_norm(x)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
|
||||
x = torch.cat([x, hs.pop()], dim=1)
|
||||
if self.mid_block is not None:
|
||||
x = self.mid_block(x, t_emb)
|
||||
for out_conv in self.out_convs:
|
||||
x = out_conv(x)
|
||||
assert len(hs) == 0
|
||||
return x
|
Loading…
Reference in New Issue
Block a user