# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from argparse import ArgumentParser
from typing import List, Dict

import torch
from transformers import AutoModelForCausalLM
import PIL.Image

from deepseek_vl2.models import DeepseekVLV2ForCausalLM, DeepseekVLV2Processor
from deepseek_vl2.serve.app_modules.utils import parse_ref_bbox


def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
"""
Args :
conversations ( List [ Dict [ str , str ] ] ) : the conversations with a list of messages . An example is :
[
{
" role " : " User " ,
" content " : " <image> \n Extract all information from this image and convert them into markdown format. " ,
" images " : [ " ./examples/table_datasets.png " ]
} ,
{ " role " : " Assistant " , " content " : " " } ,
]
Returns :
pil_images ( List [ PIL . Image . Image ] ) : the list of PIL images .
"""
pil_images = [ ]
for message in conversations :
if " images " not in message :
continue
for image_path in message [ " images " ] :
pil_img = PIL . Image . open ( image_path )
pil_img = pil_img . convert ( " RGB " )
pil_images . append ( pil_img )
return pil_images
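
# Example usage (illustrative; reuses the path from the docstring above and assumes it exists on disk):
#   load_pil_images([{"role": "User", "content": "<image>\nDescribe this table.",
#                     "images": ["./examples/table_datasets.png"]}])
# would return a one-element list of RGB PIL.Image objects.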


def main(args):
    dtype = torch.bfloat16

    # specify the path to the model
    model_path = args.model_path
    vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer

    vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=dtype
    )
    vl_gpt = vl_gpt.cuda().eval()
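
    # Note: .cuda() assumes a CUDA-capable GPU is available; loading in bfloat16 roughly
    # halves the weight memory compared to float32.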

    # two-image, in-context visual grounding conversation example
    conversation = [
        {
            "role": "<|User|>",
            "content": "<image>\n<image>\n<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image.",
            "images": [
                "images/incontext_visual_grounding_1.jpeg",
                "images/icl_vg_2.jpeg"
            ],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
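
    # Alternative single-image visual grounding example using <|ref|> tags (uncomment to try):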
# conversation = [
# {
# "role": "<|User|>",
# "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
# "images": ["./images/visual_grounding_1.jpeg"],
# },
# {"role": "<|Assistant|>", "content": ""},
# ]
    # load images and prepare for inputs
    pil_images = load_pil_images(conversation)
    print(f"len(pil_images) = {len(pil_images)}")
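
    # For reference, the batched processor output is expected to carry fields like: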
# input_ids = batched_input_ids,
# attention_mask = batched_attention_mask,
# labels = batched_labels,
# images_tiles = batched_images,
# images_seq_mask = batched_images_seq_mask,
# images_spatial_crop = batched_images_spatial_crop,
# sft_format = batched_sft_format,
# seq_lens = seq_lens
    prepare_inputs = vl_chat_processor(
        conversations=conversation,
        images=pil_images,
        force_batchify=True,
        system_prompt=""
    ).to(vl_gpt.device, dtype=dtype)
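    # force_batchify=True makes the processor return batched tensors (here a batch of one
    # conversation); .to(...) moves those tensors onto the model's device and dtype.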
# for key in prepare_inputs.keys():
# value = prepare_inputs[key]
# if isinstance(value, list):
# print(key, len(value), type(value))
# elif isinstance(value, torch.Tensor):
# print(key, value.shape, type(value))
    with torch.no_grad():
        if args.chunk_size == -1:
            inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
            past_key_values = None
        else:
            # incremental_prefilling, e.g. when using a 40G GPU for vl2-small
            inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
                input_ids=prepare_inputs.input_ids,
                images=prepare_inputs.images,
                images_seq_mask=prepare_inputs.images_seq_mask,
                images_spatial_crop=prepare_inputs.images_spatial_crop,
                attention_mask=prepare_inputs.attention_mask,
                chunk_size=args.chunk_size
            )
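
        # In the chunked path, past_key_values already holds the key/value cache for the
        # prompt, so the full prompt never has to be prefilled in a single pass (the
        # memory-saving trick referenced above); in the default path it stays None and
        # generate() prefills the prompt itself.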

        # run the model to get the response
        outputs = vl_gpt.generate(
            # inputs_embeds=inputs_embeds[:, -1:],
            # input_ids=prepare_inputs.input_ids[:, -1:],
            inputs_embeds=inputs_embeds,
            input_ids=prepare_inputs.input_ids,
            images=prepare_inputs.images,
            images_seq_mask=prepare_inputs.images_seq_mask,
            images_spatial_crop=prepare_inputs.images_spatial_crop,
            attention_mask=prepare_inputs.attention_mask,
            past_key_values=past_key_values,

            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,

            # do_sample=False,
            # repetition_penalty=1.1,

            do_sample=True,
            temperature=0.4,
            top_p=0.9,
            repetition_penalty=1.1,

            use_cache=True,
        )
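
        # Drop the prompt tokens from the generated sequence and decode only the newly
        # generated tokens; special tokens are kept (skip_special_tokens=False) so the
        # grounding tags in the answer can still be parsed by parse_ref_bbox below.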
        answer = tokenizer.decode(
            outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(),
            skip_special_tokens=False
        )
        print(f"{prepare_inputs['sft_format'][0]}", answer)
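
        # parse_ref_bbox looks for box annotations in the grounded answer and, if any are
        # found, draws them onto the last input image; it returns None when nothing is found.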
        vg_image = parse_ref_bbox(answer, image=pil_images[-1])
        if vg_image is not None:
            vg_image.save("./vg.jpg", format="JPEG", quality=85)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True,
                        default="deepseek-ai/deepseek-vl2",
                        help="model name or local path to the model")
    parser.add_argument("--chunk_size", type=int, default=-1,
                        help="chunk size for incremental prefilling. "
                             "When using a 40G GPU for vl2-small, set a chunk_size to enable incremental_prefilling. "
                             "The default of -1 disables incremental prefilling.")

    args = parser.parse_args()
    main(args)