diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 7afd9d106..68465e191 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -7,26 +7,21 @@ import re import uuid from pathlib import Path from typing import Optional +import io import requests - - -from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel - - +from fastapi import APIRouter, Depends, UploadFile, HTTPException, Request from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS - +from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS +from open_webui.routers.files import upload_file from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, ComfyUIWorkflow, comfyui_generate_image, ) - +from pydantic import BaseModel log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["IMAGES"]) @@ -39,7 +34,7 @@ router = APIRouter() @router.get("/config") -async def get_config(request: Request, user=Depends(get_admin_user)): +async def get_def(request: Request, user=Depends(get_admin_user)): return { "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION, "engine": request.app.state.config.IMAGE_GENERATION_ENGINE, @@ -271,7 +266,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)): async def update_image_config( request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) ): - set_image_model(request, form_data.MODEL) pattern = r"^\d+x\d+$" @@ -383,35 +377,18 @@ class GenerateImageForm(BaseModel): negative_prompt: Optional[str] = None -def save_b64_image(b64_str): +def load_b64_image_data(b64_str): try: - image_id = str(uuid.uuid4()) - if "," in b64_str: header, encoded = b64_str.split(",", 1) mime_type = header.split(";")[0] - img_data = base64.b64decode(encoded) - image_format = mimetypes.guess_extension(mime_type) - - image_filename = f"{image_id}{image_format}" - file_path = IMAGE_CACHE_DIR / f"{image_filename}" - with open(file_path, "wb") as f: - f.write(img_data) - return image_filename else: - image_filename = f"{image_id}.png" - file_path = IMAGE_CACHE_DIR.joinpath(image_filename) - + mime_type = "image/png" img_data = base64.b64decode(b64_str) - - # Write the image data to a file - with open(file_path, "wb") as f: - f.write(img_data) - return image_filename - + return img_data, mime_type except Exception as e: - log.exception(f"Error saving image: {e}") + log.exception(f"Error loading image data: {e}") return None @@ -500,13 +477,17 @@ async def image_generations( images = [] for image in res["data"]: - image_filename = save_b64_image(image["b64_json"]) - images.append({"url": f"/cache/image/generations/{image_filename}"}) - file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") - - with open(file_body_path, "w") as f: - json.dump(data, f) - + image_data, content_type = load_b64_image_data(image["b64_json"]) + file = UploadFile( + file=io.BytesIO(image_data), + filename="image", # will be converted to a unique ID on upload_file + headers={ + "content-type": content_type, + }, + ) + file_item = upload_file(request, file, user) + url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + images.append({"url": url}) return images elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": @@ -618,4 +599,4 @@ async def image_generations( data = r.json() if "error" in data: error = data["error"]["message"] - raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) \ No newline at end of file