Merge pull request #9486 from rragundez/store-images

Use DB for generated images
This commit is contained in:
Timothy Jaeryang Baek 2025-02-11 21:12:17 -08:00 committed by GitHub
commit ab70f1bb50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 89 deletions

View File

@ -3,30 +3,22 @@ import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote from urllib.parse import quote
from open_webui.storage.provider import Storage from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import ( from open_webui.models.files import (
FileForm, FileForm,
FileModel, FileModel,
FileModelResponse, FileModelResponse,
Files, Files,
) )
from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.storage.provider import Storage
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -41,7 +33,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse) @router.post("/", response_model=FileModelResponse)
def upload_file( def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user) request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = {},
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
@ -65,6 +60,7 @@ def upload_file(
"name": name, "name": name,
"content_type": file.content_type, "content_type": file.content_type,
"size": len(contents), "size": len(contents),
"data": file_metadata,
}, },
} }
), ),
@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files() Storage.delete_all_files()
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error deleting files") log.error("Error deleting files")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error getting file content") log.error("Error getting file content")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error getting file content") log.error("Error getting file content")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"), detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path) Storage.delete_file(file.path)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
log.error(f"Error deleting files") log.error("Error deleting files")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"), detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

View File

@ -1,32 +1,26 @@
import asyncio import asyncio
import base64 import base64
import io
import json import json
import logging import logging
import mimetypes import mimetypes
import re import re
import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import requests import requests
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from open_webui.config import CACHE_DIR from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import ( from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm, ComfyUIGenerateImageForm,
ComfyUIWorkflow, ComfyUIWorkflow,
comfyui_generate_image, comfyui_generate_image,
) )
from pydantic import BaseModel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"]) log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config( async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user) request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
): ):
set_image_model(request, form_data.MODEL) set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$" pattern = r"^\d+x\d+$"
@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None negative_prompt: Optional[str] = None
def save_b64_image(b64_str): def load_b64_image_data(b64_str):
try: try:
image_id = str(uuid.uuid4())
if "," in b64_str: if "," in b64_str:
header, encoded = b64_str.split(",", 1) header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0] mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded) img_data = base64.b64decode(encoded)
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
else: else:
image_filename = f"{image_id}.png" mime_type = "image/png"
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
img_data = base64.b64decode(b64_str) img_data = base64.b64decode(b64_str)
return img_data, mime_type
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
except Exception as e: except Exception as e:
log.exception(f"Error saving image: {e}") log.exception(f"Error loading image data: {e}")
return None return None
def save_url_image(url, headers=None): def load_url_image_data(url, headers=None):
image_id = str(uuid.uuid4())
try: try:
if headers: if headers:
r = requests.get(url, headers=headers) r = requests.get(url, headers=headers)
@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image": if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"] mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type) return r.content, mime_type
if not image_format:
raise ValueError("Could not determine image type from MIME type")
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_filename
else: else:
log.error("Url does not point to an image.") log.error("Url does not point to an image.")
return None return None
@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
return None return None
def upload_image(request, image_metadata, image_data, content_type, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
filename=f"generated{image_format}", # will be converted to a unique ID on upload_file
headers={
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@router.post("/generations") @router.post("/generations")
async def image_generations( async def image_generations(
request: Request, request: Request,
@ -500,13 +478,9 @@ async def image_generations(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_filename = save_b64_image(image["b64_json"]) image_data, content_type = load_b64_image_data(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_filename}"}) url = upload_image(request, data, image_data, content_type, user)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") images.append({"url": url})
with open(file_body_path, "w") as f:
json.dump(data, f)
return images return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@ -552,14 +526,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}" "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
} }
image_filename = save_url_image(image["url"], headers) image_data, content_type = load_url_image_data(image["url"], headers)
images.append({"url": f"/cache/image/generations/{image_filename}"}) url = upload_image(
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") request,
form_data.model_dump(exclude_none=True),
with open(file_body_path, "w") as f: image_data,
json.dump(form_data.model_dump(exclude_none=True), f) content_type,
user,
log.debug(f"images: {images}") )
images.append({"url": url})
return images return images
elif ( elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111" request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@ -604,13 +579,15 @@ async def image_generations(
images = [] images = []
for image in res["images"]: for image in res["images"]:
image_filename = save_b64_image(image) image_data, content_type = load_b64_image_data(image)
images.append({"url": f"/cache/image/generations/{image_filename}"}) url = upload_image(
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json") request,
{**data, "info": res["info"]},
with open(file_body_path, "w") as f: image_data,
json.dump({**data, "info": res["info"]}, f) content_type,
user,
)
images.append({"url": url})
return images return images
except Exception as e: except Exception as e:
error = e error = e