# Mirror of https://github.com/deepseek-ai/Janus
# (synced 2024-12-26 13:52:57 +00:00; 116 lines, 4.4 KiB, Python)
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
|
|
import torch
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
from janus.models import MultiModalityCausalLM, VLChatProcessor
|
|
import numpy as np
|
|
import os
|
|
import PIL.Image
|
|
|
|
# Checkpoint to load: both the processor (tokenizer + chat template) and the
# multimodal model come from the same Hugging Face repo id.
model_path = "deepseek-ai/Janus-1.3B"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

# The model class is supplied by the checkpoint itself (trust_remote_code);
# cast to bfloat16, move to the current CUDA device, and switch to eval mode.
vl_gpt: MultiModalityCausalLM = (
    AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# A single user request followed by an empty assistant turn for the model
# to complete.
conversation = [
    {
        "role": "User",
        "content": "A close-up high-contrast photo of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of roiling energy, exploding yellow stars, and radiating swirls of blue.",
    },
    {"role": "Assistant", "content": ""},
]

# Render the chat template with an empty system prompt, then append the image
# start tag so that generation continues directly with image tokens.
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag
|
|
|
|
|
|
@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 16,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
    output_dir: "str | None" = "generated_samples",
):
    """Sample ``parallel_size`` images for ``prompt`` using classifier-free guidance.

    Each sample runs as a pair of sequences. The conditional row carries the
    full prompt. In the unconditional row, every prompt token except the first
    and last is replaced by ``vl_chat_processor.pad_id``. At every step the two
    rows' logits are mixed as ``uncond + cfg_weight * (cond - uncond)``, one
    image token is sampled, and its embedding is fed back into both rows while
    the KV cache carries the earlier positions.

    Args:
        mmgpt: Janus model exposing ``language_model``, ``gen_head``,
            ``prepare_gen_img_embeds`` and ``gen_vision_model``.
        vl_chat_processor: Processor providing ``tokenizer`` and ``pad_id``.
        prompt: Fully templated prompt text (in this script it ends with
            ``vl_chat_processor.image_start_tag``).
        temperature: Softmax temperature; must be > 0.
        parallel_size: Number of images sampled in one batch; must be >= 1.
        cfg_weight: Classifier-free guidance scale.
        image_token_num_per_image: Image tokens sampled per image; must equal
            ``(img_size // patch_size) ** 2``.
        img_size: Side length of the output images, in pixels.
        patch_size: Pixels covered by one image token along each side.
        output_dir: Directory receiving ``img_{i}.jpg`` files (created if
            missing). ``None`` skips writing files.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape
        ``(parallel_size, img_size, img_size, 3)``.

    Raises:
        ValueError: If ``temperature <= 0``, ``parallel_size < 1``, or the
            token count does not match the image grid.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if parallel_size < 1:
        raise ValueError(f"parallel_size must be >= 1, got {parallel_size}")
    grid = img_size // patch_size
    if image_token_num_per_image != grid * grid:
        raise ValueError(
            f"image_token_num_per_image ({image_token_num_per_image}) must equal "
            f"(img_size // patch_size) ** 2 = {grid * grid}"
        )

    # Put every tensor on the same device as the token-embedding table instead
    # of assuming the current CUDA device (presumably an nn.Embedding).
    embed_tokens = mmgpt.language_model.get_input_embeddings()
    device = embed_tokens.weight.device

    # Rows 0, 2, 4, ... are conditional (full prompt). Rows 1, 3, 5, ... are
    # their unconditional twins: every prompt token except the first and last
    # is overwritten with the pad id.
    prompt_ids = torch.tensor(
        vl_chat_processor.tokenizer.encode(prompt), dtype=torch.long, device=device
    )
    tokens = prompt_ids.unsqueeze(0).repeat(parallel_size * 2, 1)
    tokens[1::2, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = embed_tokens(tokens)

    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int, device=device
    )
    past_key_values = None
    for step in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        # From the second step on, only the newest embedding is fed in; the
        # earlier positions come from the returned cache.
        past_key_values = outputs.past_key_values
        logits = mmgpt.gen_head(outputs.last_hidden_state[:, -1, :])

        # Classifier-free guidance: extrapolate from the unconditional logits
        # toward the conditional ones by cfg_weight.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
        generated_tokens[:, step] = next_token

        # Duplicate each sampled token as [t0, t0, t1, t1, ...] so both rows of
        # every cond/uncond pair receive it, matching the row interleaving.
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token.repeat_interleave(2))
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # NOTE(review): 8 appears to be the code embedding dimension that
    # gen_vision_model expects; confirm against its config before changing it.
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, grid, grid],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    # Map decoder output from an assumed [-1, 1] range to [0, 255]; values
    # outside that range are clipped.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    # Assigning into a uint8 buffer truncates toward zero, and it requires dec
    # to broadcast to the expected image batch shape.
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        for i, img in enumerate(visual_img):
            save_path = os.path.join(output_dir, "img_{}.jpg".format(i))
            PIL.Image.fromarray(img).save(save_path)

    return visual_img
|
|
|
|
|
|
# Script entry: sample images for the prompt built above using the function's
# default sampling settings.
generate(
    mmgpt=vl_gpt,
    vl_chat_processor=vl_chat_processor,
    prompt=prompt,
)