DeepSeek-VL/cli_chat.py
2024-03-13 16:16:32 +08:00

190 lines
6.2 KiB
Python

# -*- coding: utf-8 -*-
import argparse
import os
import sys
from PIL import Image
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from deepseek_vl.utils.io import load_pretrained_model
def load_image(image_file):
    """Open an image file and return its contents converted to RGB.

    Args:
        image_file: Location of the image, passed straight to ``Image.open``
            (the chat loop in this file passes a filesystem path).

    Returns:
        The image produced by ``convert("RGB")``.

    Raises:
        Whatever ``Image.open`` / ``convert`` raise for a missing or
        unreadable file; nothing is caught here.
    """
    # Image.open only reads the header and leaves the file open.  convert()
    # loads the pixel data and returns a separate image object, so the
    # source image can be closed as soon as the conversion is done.
    with Image.open(image_file) as source:
        return source.convert("RGB")
def get_help_message(image_token):
    """Build the usage text shown to the user.

    Args:
        image_token: The placeholder string the user types to mark where an
            image belongs in a message. It is interpolated into the
            instructions and into the closing example.

    Returns:
        str: The multi-line help message.
    """
    help_msg = (
        "\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n"
        "Usage: \n"
        " 1. type `exit` to quit. \n"
        f" 2. type `{image_token}` to indicate there is an image. You can enter multiple images, "
        f"e.g '{image_token} is a dot, {image_token} is a cat, and what is it in {image_token}?'. "
        f"When you type `{image_token}`, the chatbot will ask you to input image file path. \n"
        # Items are numbered consecutively (the list used to jump from 2 to 4).
        " 3. type `help` to get the help messages. \n"
        " 4. type `new` to start a new conversation. \n"
        # The example uses the real token so it stays correct if the
        # processor is configured with a different placeholder.
        f" Here is an example, you can type: '{image_token}Describe the image.'\n"
    )
    return help_msg
@torch.inference_mode()
def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    """Generate the reply for the current conversation and stream it as text.

    Args:
        args: Parsed command-line arguments. Not used by this function; kept
            so the call signature stays unchanged.
        conv: Conversation object; ``conv.get_prompt()`` supplies the prompt.
        pil_images: Images handed to the processor together with the prompt.
        tokenizer: Tokenizer given to ``TextIteratorStreamer`` for decoding.
        vl_chat_processor: Callable that turns ``prompt`` and ``images`` into
            model inputs (called with ``force_batchify=True``).
        vl_gpt: Model object; this function uses ``device``,
            ``prepare_inputs_embeds`` and ``language_model.generate``.
        generation_config: Mapping of keyword arguments for ``generate``.
            It is copied, never modified.

    Yields:
        Text chunks, in the order the streamer produces them.
    """
    prompt = conv.get_prompt()
    prepare_inputs = vl_chat_processor(
        prompt=prompt,
        images=pil_images,
        force_batchify=True,
    ).to(vl_gpt.device)

    # run image encoder to get the image embeddings
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Build a per-call kwargs dict instead of writing into the caller's
    # config: the config is reused for every turn, and storing the embeddings
    # and streamer in it would keep them alive after this turn has finished.
    generate_kwargs = dict(
        generation_config,
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        streamer=streamer,
    )

    # generate() runs in a worker thread so the chunks can be yielded from
    # this thread while generation is still in progress.
    worker = Thread(target=vl_gpt.language_model.generate, kwargs=generate_kwargs)
    worker.start()

    yield from streamer

    # The streamer is exhausted; wait for generate() itself to return so the
    # worker does not overlap with the next turn.
    worker.join()
def get_user_input(hint: str):
    """Prompt with ``hint`` until the user enters a non-empty line.

    Ctrl-C (``KeyboardInterrupt``) prints a newline and prompts again.
    End of input (``EOFError``) is reported as the string ``"exit"``.

    Args:
        hint: Prompt text passed to ``input``.

    Returns:
        str: The first non-empty line entered, or ``"exit"`` on end of input.
    """
    while True:
        try:
            line = input(f"{hint}")
        except KeyboardInterrupt:
            # Drop the interrupted line and ask again.
            print()
            continue
        except EOFError:
            return "exit"
        if line != "":
            return line
def _prompt_for_image(index, total):
    """Ask for image ``index`` of ``total`` until a loadable file is given.

    Typing ``exit`` or closing stdin ends the program via ``sys.exit(0)``.
    Ctrl-C prints a newline and asks again.

    Returns:
        The image returned by ``load_image`` for the entered path.
    """
    while True:
        try:
            image_file = input(f"({index}/{total}) Input the image file path: ")
        except KeyboardInterrupt:
            print()
            continue
        except EOFError:
            # stdin is closed: every further input() would raise again, so
            # re-prompting would spin forever. Treat it like `exit`, the same
            # way get_user_input does for the main prompt.
            print()
            print("Chat program exited.")
            sys.exit(0)

        # trim whitespaces around path, enables drop-in from for example Dolphin
        image_file = image_file.strip()

        if image_file and os.path.isfile(image_file):
            try:
                return load_image(image_file)
            except OSError as err:
                # The file exists but could not be read as an image; ask
                # again instead of aborting the whole session.
                print(f"File error, `{image_file}` could not be loaded as an image ({err}). "
                      f"Please input the correct file path.")
                continue

        if image_file == "exit":
            print("Chat program exited.")
            sys.exit(0)

        print(f"File error, `{image_file}` does not exist. Please input the correct file path.")


def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    """Run the interactive chat loop until the user exits.

    Commands typed at the prompt:
        ``exit`` ends the program, ``help`` reprints the usage text and
        ``new`` clears the screen and starts a fresh conversation. Any other
        input is sent to the model; for each occurrence of the image token in
        it the user is asked for an image file path.

    Args:
        args: Parsed command-line arguments, forwarded to ``response``.
        tokenizer: Forwarded to ``response``.
        vl_chat_processor: Provides ``image_token`` and ``new_chat_template()``;
            also forwarded to ``response``.
        vl_gpt: Forwarded to ``response``.
        generation_config: Forwarded to ``response``.

    This function does not return normally; it leaves through ``sys.exit(0)``.
    """
    image_token = vl_chat_processor.image_token
    help_msg = get_help_message(image_token)

    # The loop below is only left through sys.exit(), so the usage text is
    # printed exactly once up front.
    print(help_msg)
    pil_images = []
    conv = vl_chat_processor.new_chat_template()
    roles = conv.roles

    while True:
        # get user input
        user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")

        if user_input == "exit":
            print("Chat program exited.")
            sys.exit(0)

        if user_input == "help":
            print(help_msg)
            continue

        if user_input == "new":
            # `clear` does not exist in the Windows shell; use `cls` there.
            os.system("cls" if os.name == "nt" else "clear")
            pil_images = []
            conv = vl_chat_processor.new_chat_template()
            torch.cuda.empty_cache()
            print("New conversation started.")
            continue

        conv.append_message(conv.roles[0], user_input)
        conv.append_message(conv.roles[1], None)

        # One image is requested per image token in the message. Images
        # accumulate across turns until `new` resets the list.
        num_images = user_input.count(image_token)
        for position in range(1, num_images + 1):
            pil_images.append(_prompt_for_image(position, num_images))

        # get the answer by the model's prediction
        answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
        sys.stdout.write(f"{conv.roles[1]}: ")

        pieces = []
        for piece in answer_iter:
            pieces.append(piece)
            # Echo each chunk immediately so the reply appears as it streams.
            sys.stdout.write(piece)
            sys.stdout.flush()

        sys.stdout.write("\n")
        sys.stdout.flush()

        # Replace the placeholder assistant message with the full reply.
        conv.update_last_message("".join(pieces))
def main(args):
    """Load the model, assemble the generation settings and start the chat.

    Args:
        args: Parsed command-line arguments. Reads ``model_path``,
            ``max_gen_len``, ``temperature``, ``top_p`` and
            ``repetition_penalty``.
    """
    # setup
    tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)

    generation_config = {
        # The EOS token id is used for padding as well.
        "pad_token_id": vl_chat_processor.tokenizer.eos_token_id,
        "bos_token_id": vl_chat_processor.tokenizer.bos_token_id,
        "eos_token_id": vl_chat_processor.tokenizer.eos_token_id,
        "max_new_tokens": args.max_gen_len,
        "use_cache": True,
    }

    sampling = args.temperature > 0
    generation_config["do_sample"] = sampling
    if sampling:
        # Sampling parameters are only passed when sampling is enabled.
        generation_config["top_p"] = args.top_p
        generation_config["temperature"] = args.temperature
        generation_config["repetition_penalty"] = args.repetition_penalty

    chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="deepseek-ai/deepseek-vl-7b-chat",
        help="the huggingface model name or the local path of the downloaded huggingface model.",
    )
    # Generation options: (flag, type, default), registered in this order.
    for flag, value_type, default_value in (
        ("--temperature", float, 0.2),
        ("--top_p", float, 0.95),
        ("--repetition_penalty", float, 1.1),
        ("--max_gen_len", int, 512),
    ):
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)