mirror of https://github.com/deepseek-ai/DeepSeek-VL
synced 2025-06-26 18:27:43 +00:00
chore: make format
cli_chat.py (64 lines changed)

--- a/cli_chat.py
+++ b/cli_chat.py
@@ -3,9 +3,10 @@
 import argparse
 import os
 import sys
-from PIL import Image
 from threading import Thread
 
 import torch
+from PIL import Image
 from transformers import TextIteratorStreamer
 
 from deepseek_vl.utils.io import load_pretrained_model
@@ -33,22 +34,19 @@ def get_help_message(image_token):
 
 
 @torch.inference_mode()
-def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
+def response(
+    args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config
+):
     prompt = conv.get_prompt()
     prepare_inputs = vl_chat_processor.__call__(
-        prompt=prompt,
-        images=pil_images,
-        force_batchify=True
+        prompt=prompt, images=pil_images, force_batchify=True
     ).to(vl_gpt.device)
 
     # run image encoder to get the image embeddings
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     streamer = TextIteratorStreamer(
-        tokenizer=tokenizer,
-        skip_prompt=True,
-        skip_special_tokens=True
+        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
    )
     generation_config["inputs_embeds"] = inputs_embeds
     generation_config["attention_mask"] = prepare_inputs.attention_mask
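For context on this hunk: response() builds a TextIteratorStreamer, which only yields text while generate() runs concurrently, and that is what the `from threading import Thread` import above is for. A minimal sketch of the standard pattern, assuming the model exposes a Hugging Face generate() method (the function and variable names here are illustrative, not the verbatim tail of response()):

from threading import Thread

def stream_generate(model, streamer, generation_config):
    # generate() blocks until decoding finishes, so it runs in a worker
    # thread while the caller consumes decoded text from the streamer.
    generation_config["streamer"] = streamer
    Thread(target=model.generate, kwargs=generation_config).start()
    yield from streamer  # yields decoded text pieces as tokens arrive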
@@ -79,7 +77,6 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
     help_msg = get_help_message(image_token)
 
     while True:
-
         print(help_msg)
 
         pil_images = []
@@ -87,9 +84,10 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
         roles = conv.roles
 
         while True:
-
             # get user input
-            user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")
+            user_input = get_user_input(
+                f"{roles[0]} [{image_token} indicates an image]: "
+            )
 
             if user_input == "exit":
                 print("Chat program exited.")
@@ -115,7 +113,9 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
             while cur_img_idx < num_images:
                 try:
-                    image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ")
+                    image_file = input(
+                        f"({cur_img_idx + 1}/{num_images}) Input the image file path: "
+                    )
 
                 except KeyboardInterrupt:
                     print()
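The loop above collects image paths into pil_images. The diff elides the validation and loading step, but given the `from PIL import Image` import and the file-existence error message in the next hunk, it plausibly resembles this minimal sketch (the helper name try_load_image is illustrative, not from the source):

import os
from PIL import Image

def try_load_image(image_file):
    # Re-prompt the user when the path does not exist (see the
    # "File error" branch below); otherwise decode to RGB for the model.
    if not os.path.exists(image_file):
        return None
    return Image.open(image_file).convert("RGB")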
@@ -134,11 +134,21 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
                     sys.exit(0)
 
                 else:
-                    print(f"File error, `{image_file}` does not exist. Please input the correct file path.")
+                    print(
+                        f"File error, `{image_file}` does not exist. Please input the correct file path."
+                    )
 
         # get the answer by the model's prediction
         answer = ""
-        answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
+        answer_iter = response(
+            args,
+            conv,
+            pil_images,
+            tokenizer,
+            vl_chat_processor,
+            vl_gpt,
+            generation_config,
+        )
         sys.stdout.write(f"{conv.roles[1]}: ")
         for char in answer_iter:
             answer += char
@@ -152,7 +162,6 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
 
 
 def main(args):
-
     # setup
     tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
     generation_config = dict(
@@ -163,12 +172,14 @@ def main(args):
         use_cache=True,
     )
     if args.temperature > 0:
-        generation_config.update({
-            "do_sample": True,
-            "top_p": args.top_p,
-            "temperature": args.temperature,
-            "repetition_penalty": args.repetition_penalty,
-        })
+        generation_config.update(
+            {
+                "do_sample": True,
+                "top_p": args.top_p,
+                "temperature": args.temperature,
+                "repetition_penalty": args.repetition_penalty,
+            }
+        )
     else:
         generation_config.update({"do_sample": False})
 
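The branch above reduces to one of two override sets, shown here with the argparse defaults from the end of this file (treating them as keyword arguments to a Hugging Face generate() call is an assumption, since the call itself is not visible in the diff):

# temperature > 0: nucleus sampling with a repetition penalty
sampling_overrides = {
    "do_sample": True,
    "top_p": 0.95,              # --top_p default
    "temperature": 0.2,         # --temperature default
    "repetition_penalty": 1.1,  # --repetition_penalty default
}

# temperature == 0: deterministic greedy decoding
greedy_overrides = {"do_sample": False}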
@@ -177,12 +188,15 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat",
-                        help="the huggingface model name or the local path of the downloaded huggingface model.")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="deepseek-ai/deepseek-vl-7b-chat",
+        help="the huggingface model name or the local path of the downloaded huggingface model.",
+    )
     parser.add_argument("--temperature", type=float, default=0.2)
     parser.add_argument("--top_p", type=float, default=0.95)
     parser.add_argument("--repetition_penalty", type=float, default=1.1)
     parser.add_argument("--max_gen_len", type=int, default=512)
     args = parser.parse_args()
     main(args)
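For reference, a typical invocation of the reformatted script using the defaults defined above (this command line is illustrative and not part of the commit):

python cli_chat.py --model_path deepseek-ai/deepseek-vl-7b-chat --temperature 0.2 --top_p 0.95 --max_gen_len 512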