Merge pull request #9486 from rragundez/store-images

Use DB for generated images
This commit is contained in:
Timothy Jaeryang Baek 2025-02-11 21:12:17 -08:00 committed by GitHub
commit ab70f1bb50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 89 deletions

View File

@ -3,30 +3,22 @@ import os
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel
import mimetypes
from urllib.parse import quote
from open_webui.storage.provider import Storage
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.config import UPLOAD_DIR
from open_webui.env import SRC_LOG_LEVELS
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -41,7 +33,10 @@ router = APIRouter()
@router.post("/", response_model=FileModelResponse)
def upload_file(
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
try:
@ -65,6 +60,7 @@ def upload_file(
"name": name,
"content_type": file.content_type,
"size": len(contents),
"data": file_metadata,
},
}
),
@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
Storage.delete_all_files()
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
except Exception as e:
log.exception(e)
log.error(f"Error getting file content")
log.error("Error getting file content")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
Storage.delete_file(file.path)
except Exception as e:
log.exception(e)
log.error(f"Error deleting files")
log.error("Error deleting files")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),

View File

@ -1,32 +1,26 @@
import asyncio
import base64
import io
import json
import logging
import mimetypes
import re
import uuid
from pathlib import Path
from typing import Optional
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
from open_webui.routers.files import upload_file
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.images.comfyui import (
ComfyUIGenerateImageForm,
ComfyUIWorkflow,
comfyui_generate_image,
)
from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
async def update_image_config(
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
):
set_image_model(request, form_data.MODEL)
pattern = r"^\d+x\d+$"
@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
def load_b64_image_data(b64_str):
try:
image_id = str(uuid.uuid4())
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
else:
image_filename = f"{image_id}.png"
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
mime_type = "image/png"
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_filename
return img_data, mime_type
except Exception as e:
log.exception(f"Error saving image: {e}")
log.exception(f"Error loading image data: {e}")
return None
def save_url_image(url, headers=None):
image_id = str(uuid.uuid4())
def load_url_image_data(url, headers=None):
try:
if headers:
r = requests.get(url, headers=headers)
@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_filename
return r.content, mime_type
else:
log.error("Url does not point to an image.")
return None
@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
filename=f"generated{image_format}", # will be converted to a unique ID on upload_file
headers={
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@router.post("/generations")
async def image_generations(
request: Request,
@ -500,13 +478,9 @@ async def image_generations(
images = []
for image in res["data"]:
image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(data, f)
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
@ -552,14 +526,15 @@ async def image_generations(
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
image_filename = save_url_image(image["url"], headers)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump(form_data.model_dump(exclude_none=True), f)
log.debug(f"images: {images}")
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
user,
)
images.append({"url": url})
return images
elif (
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
@ -604,13 +579,15 @@ async def image_generations(
images = []
for image in res["images"]:
image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
user,
)
images.append({"url": url})
return images
except Exception as e:
error = e