chore: make format

Author: Bo Liu
Date: 2024-03-13 14:39:46 +08:00
parent 90a18501d7
commit 48cc0deea6
13 changed files with 345 additions and 113 deletions


@@ -3,9 +3,10 @@
 import argparse
 import os
 import sys
-from PIL import Image
 from threading import Thread
 import torch
+from PIL import Image
 from transformers import TextIteratorStreamer
 from deepseek_vl.utils.io import load_pretrained_model
@@ -33,22 +34,19 @@ def get_help_message(image_token):
 @torch.inference_mode()
-def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
+def response(
+    args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config
+):
     prompt = conv.get_prompt()
     prepare_inputs = vl_chat_processor.__call__(
-        prompt=prompt,
-        images=pil_images,
-        force_batchify=True
+        prompt=prompt, images=pil_images, force_batchify=True
     ).to(vl_gpt.device)
     # run image encoder to get the image embeddings
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
     streamer = TextIteratorStreamer(
-        tokenizer=tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True
+        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
     )
     generation_config["inputs_embeds"] = inputs_embeds
     generation_config["attention_mask"] = prepare_inputs.attention_mask
@@ -79,7 +77,6 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
     help_msg = get_help_message(image_token)
     while True:
         print(help_msg)
         pil_images = []
@@ -87,9 +84,10 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
         roles = conv.roles
         while True:
             # get user input
-            user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")
+            user_input = get_user_input(
+                f"{roles[0]} [{image_token} indicates an image]: "
+            )
             if user_input == "exit":
                 print("Chat program exited.")
@@ -115,7 +113,9 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
             while cur_img_idx < num_images:
                 try:
-                    image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ")
+                    image_file = input(
+                        f"({cur_img_idx + 1}/{num_images}) Input the image file path: "
+                    )
                 except KeyboardInterrupt:
                     print()
@@ -134,11 +134,21 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
                     sys.exit(0)
                 else:
-                    print(f"File error, `{image_file}` does not exist. Please input the correct file path.")
+                    print(
+                        f"File error, `{image_file}` does not exist. Please input the correct file path."
+                    )
             # get the answer by the model's prediction
             answer = ""
-            answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
+            answer_iter = response(
+                args,
+                conv,
+                pil_images,
+                tokenizer,
+                vl_chat_processor,
+                vl_gpt,
+                generation_config,
+            )
             sys.stdout.write(f"{conv.roles[1]}: ")
             for char in answer_iter:
                 answer += char
@@ -152,7 +162,6 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
 def main(args):
     # setup
     tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
     generation_config = dict(
@@ -163,12 +172,14 @@ def main(args):
         use_cache=True,
     )
     if args.temperature > 0:
-        generation_config.update({
-            "do_sample": True,
-            "top_p": args.top_p,
-            "temperature": args.temperature,
-            "repetition_penalty": args.repetition_penalty,
-        })
+        generation_config.update(
+            {
+                "do_sample": True,
+                "top_p": args.top_p,
+                "temperature": args.temperature,
+                "repetition_penalty": args.repetition_penalty,
+            }
+        )
     else:
         generation_config.update({"do_sample": False})
@@ -177,12 +188,15 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat",
-                        help="the huggingface model name or the local path of the downloaded huggingface model.")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="deepseek-ai/deepseek-vl-7b-chat",
+        help="the huggingface model name or the local path of the downloaded huggingface model.",
+    )
     parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--top_p", type=float, default=0.95)
     parser.add_argument("--repetition_penalty", type=float, default=1.1)
     parser.add_argument("--max_gen_len", type=int, default=512)
     args = parser.parse_args()
     main(args)