feat: add base64 images support (#51)

This commit is contained in:
Youho99 2024-04-24 07:01:06 +02:00 committed by GitHub
parent 37fcec4806
commit 681bffb451
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,6 +22,8 @@ from typing import Dict, List
import PIL.Image import PIL.Image
import torch import torch
import base64
import io
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
@ -42,6 +44,8 @@ def load_pretrained_model(model_path: str):
def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
""" """
Support file path or base64 images.
Args: Args:
conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
[ [
@ -64,8 +68,15 @@ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image
if "images" not in message: if "images" not in message:
continue continue
for image_path in message["images"]: for image_data in message["images"]:
pil_img = PIL.Image.open(image_path) if image_data.startswith("data:image"):
# Image data is in base64 format
_, image_data = image_data.split(",", 1)
image_bytes = base64.b64decode(image_data)
pil_img = PIL.Image.open(io.BytesIO(image_bytes))
else:
# Image data is a file path
pil_img = PIL.Image.open(image_data)
pil_img = pil_img.convert("RGB") pil_img = pil_img.convert("RGB")
pil_images.append(pil_img) pil_images.append(pil_img)