diff --git a/deepseek_vl/utils/io.py b/deepseek_vl/utils/io.py index 081f7a2..b71432a 100644 --- a/deepseek_vl/utils/io.py +++ b/deepseek_vl/utils/io.py @@ -22,6 +22,8 @@ from typing import Dict, List import PIL.Image import torch +import base64 +import io from transformers import AutoModelForCausalLM from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor @@ -42,6 +44,8 @@ def load_pretrained_model(model_path: str): def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: """ + Support file path or base64 images. + Args: conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : [ @@ -64,8 +68,15 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image if "images" not in message: continue - for image_path in message["images"]: - pil_img = PIL.Image.open(image_path) + for image_data in message["images"]: + if image_data.startswith("data:image"): + # Image data is in base64 format + _, image_data = image_data.split(",", 1) + image_bytes = base64.b64decode(image_data) + pil_img = PIL.Image.open(io.BytesIO(image_bytes)) + else: + # Image data is a file path + pil_img = PIL.Image.open(image_data) pil_img = pil_img.convert("RGB") pil_images.append(pil_img)