mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge pull request #9486 from rragundez/store-images
Use DB for generated images
This commit is contained in:
commit
ab70f1bb50
@ -3,30 +3,22 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pydantic import BaseModel
|
|
||||||
import mimetypes
|
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from open_webui.storage.provider import Storage
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.models.files import (
|
from open_webui.models.files import (
|
||||||
FileForm,
|
FileForm,
|
||||||
FileModel,
|
FileModel,
|
||||||
FileModelResponse,
|
FileModelResponse,
|
||||||
Files,
|
Files,
|
||||||
)
|
)
|
||||||
from open_webui.routers.retrieval import process_file, ProcessFileForm
|
from open_webui.routers.retrieval import ProcessFileForm, process_file
|
||||||
|
from open_webui.storage.provider import Storage
|
||||||
from open_webui.config import UPLOAD_DIR
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
|
||||||
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
|
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
@ -41,7 +33,10 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.post("/", response_model=FileModelResponse)
|
@router.post("/", response_model=FileModelResponse)
|
||||||
def upload_file(
|
def upload_file(
|
||||||
request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
|
request: Request,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
file_metadata: dict = {},
|
||||||
):
|
):
|
||||||
log.info(f"file.content_type: {file.content_type}")
|
log.info(f"file.content_type: {file.content_type}")
|
||||||
try:
|
try:
|
||||||
@ -65,6 +60,7 @@ def upload_file(
|
|||||||
"name": name,
|
"name": name,
|
||||||
"content_type": file.content_type,
|
"content_type": file.content_type,
|
||||||
"size": len(contents),
|
"size": len(contents),
|
||||||
|
"data": file_metadata,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
@ -126,7 +122,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
|
|||||||
Storage.delete_all_files()
|
Storage.delete_all_files()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
log.error(f"Error deleting files")
|
log.error("Error deleting files")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||||
@ -248,7 +244,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
log.error(f"Error getting file content")
|
log.error("Error getting file content")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||||
@ -279,7 +275,7 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
log.error(f"Error getting file content")
|
log.error("Error getting file content")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
detail=ERROR_MESSAGES.DEFAULT("Error getting file content"),
|
||||||
@ -355,7 +351,7 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
|||||||
Storage.delete_file(file.path)
|
Storage.delete_file(file.path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
log.error(f"Error deleting files")
|
log.error("Error deleting files")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||||
|
@ -1,32 +1,26 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.config import CACHE_DIR
|
from open_webui.config import CACHE_DIR
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
|
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS
|
||||||
|
from open_webui.routers.files import upload_file
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.images.comfyui import (
|
from open_webui.utils.images.comfyui import (
|
||||||
ComfyUIGenerateImageForm,
|
ComfyUIGenerateImageForm,
|
||||||
ComfyUIWorkflow,
|
ComfyUIWorkflow,
|
||||||
comfyui_generate_image,
|
comfyui_generate_image,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
|
||||||
@ -271,7 +265,6 @@ async def get_image_config(request: Request, user=Depends(get_admin_user)):
|
|||||||
async def update_image_config(
|
async def update_image_config(
|
||||||
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
set_image_model(request, form_data.MODEL)
|
set_image_model(request, form_data.MODEL)
|
||||||
|
|
||||||
pattern = r"^\d+x\d+$"
|
pattern = r"^\d+x\d+$"
|
||||||
@ -383,40 +376,22 @@ class GenerateImageForm(BaseModel):
|
|||||||
negative_prompt: Optional[str] = None
|
negative_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def save_b64_image(b64_str):
|
def load_b64_image_data(b64_str):
|
||||||
try:
|
try:
|
||||||
image_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
if "," in b64_str:
|
if "," in b64_str:
|
||||||
header, encoded = b64_str.split(",", 1)
|
header, encoded = b64_str.split(",", 1)
|
||||||
mime_type = header.split(";")[0]
|
mime_type = header.split(";")[0]
|
||||||
|
|
||||||
img_data = base64.b64decode(encoded)
|
img_data = base64.b64decode(encoded)
|
||||||
image_format = mimetypes.guess_extension(mime_type)
|
|
||||||
|
|
||||||
image_filename = f"{image_id}{image_format}"
|
|
||||||
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(img_data)
|
|
||||||
return image_filename
|
|
||||||
else:
|
else:
|
||||||
image_filename = f"{image_id}.png"
|
mime_type = "image/png"
|
||||||
file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
|
|
||||||
|
|
||||||
img_data = base64.b64decode(b64_str)
|
img_data = base64.b64decode(b64_str)
|
||||||
|
return img_data, mime_type
|
||||||
# Write the image data to a file
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(img_data)
|
|
||||||
return image_filename
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error saving image: {e}")
|
log.exception(f"Error loading image data: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def save_url_image(url, headers=None):
|
def load_url_image_data(url, headers=None):
|
||||||
image_id = str(uuid.uuid4())
|
|
||||||
try:
|
try:
|
||||||
if headers:
|
if headers:
|
||||||
r = requests.get(url, headers=headers)
|
r = requests.get(url, headers=headers)
|
||||||
@ -426,18 +401,7 @@ def save_url_image(url, headers=None):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
if r.headers["content-type"].split("/")[0] == "image":
|
if r.headers["content-type"].split("/")[0] == "image":
|
||||||
mime_type = r.headers["content-type"]
|
mime_type = r.headers["content-type"]
|
||||||
image_format = mimetypes.guess_extension(mime_type)
|
return r.content, mime_type
|
||||||
|
|
||||||
if not image_format:
|
|
||||||
raise ValueError("Could not determine image type from MIME type")
|
|
||||||
|
|
||||||
image_filename = f"{image_id}{image_format}"
|
|
||||||
|
|
||||||
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}")
|
|
||||||
with open(file_path, "wb") as image_file:
|
|
||||||
for chunk in r.iter_content(chunk_size=8192):
|
|
||||||
image_file.write(chunk)
|
|
||||||
return image_filename
|
|
||||||
else:
|
else:
|
||||||
log.error("Url does not point to an image.")
|
log.error("Url does not point to an image.")
|
||||||
return None
|
return None
|
||||||
@ -447,6 +411,20 @@ def save_url_image(url, headers=None):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||||
|
image_format = mimetypes.guess_extension(content_type)
|
||||||
|
file = UploadFile(
|
||||||
|
file=io.BytesIO(image_data),
|
||||||
|
filename=f"generated{image_format}", # will be converted to a unique ID on upload_file
|
||||||
|
headers={
|
||||||
|
"content-type": content_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||||
|
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/generations")
|
@router.post("/generations")
|
||||||
async def image_generations(
|
async def image_generations(
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -500,13 +478,9 @@ async def image_generations(
|
|||||||
images = []
|
images = []
|
||||||
|
|
||||||
for image in res["data"]:
|
for image in res["data"]:
|
||||||
image_filename = save_b64_image(image["b64_json"])
|
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
url = upload_image(request, data, image_data, content_type, user)
|
||||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
images.append({"url": url})
|
||||||
|
|
||||||
with open(file_body_path, "w") as f:
|
|
||||||
json.dump(data, f)
|
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||||
@ -552,14 +526,15 @@ async def image_generations(
|
|||||||
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
|
||||||
}
|
}
|
||||||
|
|
||||||
image_filename = save_url_image(image["url"], headers)
|
image_data, content_type = load_url_image_data(image["url"], headers)
|
||||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
url = upload_image(
|
||||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
request,
|
||||||
|
form_data.model_dump(exclude_none=True),
|
||||||
with open(file_body_path, "w") as f:
|
image_data,
|
||||||
json.dump(form_data.model_dump(exclude_none=True), f)
|
content_type,
|
||||||
|
user,
|
||||||
log.debug(f"images: {images}")
|
)
|
||||||
|
images.append({"url": url})
|
||||||
return images
|
return images
|
||||||
elif (
|
elif (
|
||||||
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
||||||
@ -604,13 +579,15 @@ async def image_generations(
|
|||||||
images = []
|
images = []
|
||||||
|
|
||||||
for image in res["images"]:
|
for image in res["images"]:
|
||||||
image_filename = save_b64_image(image)
|
image_data, content_type = load_b64_image_data(image)
|
||||||
images.append({"url": f"/cache/image/generations/{image_filename}"})
|
url = upload_image(
|
||||||
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
|
request,
|
||||||
|
{**data, "info": res["info"]},
|
||||||
with open(file_body_path, "w") as f:
|
image_data,
|
||||||
json.dump({**data, "info": res["info"]}, f)
|
content_type,
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
images.append({"url": url})
|
||||||
return images
|
return images
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = e
|
error = e
|
||||||
|
Loading…
Reference in New Issue
Block a user