diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 68465e191..a26c06c61 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -1,5 +1,6 @@ import asyncio import base64 +import io import json import logging import mimetypes @@ -7,10 +8,9 @@ import re import uuid from pathlib import Path from typing import Optional -import io import requests -from fastapi import APIRouter, Depends, UploadFile, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS @@ -392,8 +392,7 @@ def load_b64_image_data(b64_str): return None -def save_url_image(url, headers=None): - image_id = str(uuid.uuid4()) +def load_url_image_data(url, headers=None): try: if headers: r = requests.get(url, headers=headers) @@ -403,18 +402,7 @@ def save_url_image(url, headers=None): r.raise_for_status() if r.headers["content-type"].split("/")[0] == "image": mime_type = r.headers["content-type"] - image_format = mimetypes.guess_extension(mime_type) - - if not image_format: - raise ValueError("Could not determine image type from MIME type") - - image_filename = f"{image_id}{image_format}" - - file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}") - with open(file_path, "wb") as image_file: - for chunk in r.iter_content(chunk_size=8192): - image_file.write(chunk) - return image_filename + return r.content, mime_type else: log.error("Url does not point to an image.") return None @@ -486,8 +474,14 @@ async def image_generations( }, ) file_item = upload_file(request, file, user) - url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + url = request.app.url_path_for( + "get_file_content_by_id", id=file_item.id + ) images.append({"url": url}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json") + + with open(file_body_path, "w") as f: + json.dump(data, f) return images elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": @@ -533,9 +527,20 @@ async def image_generations( "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" } - image_filename = save_url_image(image["url"], headers) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") + image_data, content_type = load_url_image_data(image["url"], headers) + file = UploadFile( + file=io.BytesIO(image_data), + filename="image", # will be converted to a unique ID on upload_file + headers={ + "content-type": content_type, + }, + ) + file_item = upload_file(request, file, user) + url = request.app.url_path_for( + "get_file_content_by_id", id=file_item.id + ) + images.append({"url": url}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json") with open(file_body_path, "w") as f: json.dump(form_data.model_dump(exclude_none=True), f) @@ -585,9 +590,20 @@ async def image_generations( images = [] for image in res["images"]: - image_filename = save_b64_image(image) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") + image_data, content_type = load_b64_image_data(image) + file = UploadFile( + file=io.BytesIO(image_data), + filename="image", # will be converted to a unique ID on upload_file + headers={ + "content-type": content_type, + }, + ) + file_item = upload_file(request, file, user) + url = request.app.url_path_for( + "get_file_content_by_id", id=file_item.id + ) + images.append({"url": url}) + file_body_path = IMAGE_CACHE_DIR.joinpath(f"{file_item.id}.json") with open(file_body_path, "w") as f: json.dump({**data, "info": res["info"]}, f) @@ -599,4 +615,4 @@ async def image_generations( data = r.json() if "error" in data: error = data["error"]["message"] - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) \ No newline at end of file + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))