This commit is contained in:
Jarrod Lowe 2025-05-29 17:00:42 -07:00 committed by GitHub
commit 5bd68bcea5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 105 additions and 64 deletions

View File

@ -6,6 +6,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.files import Files
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -296,6 +297,14 @@ class ChatTable:
.update({"share_id": shared_chat.id}) .update({"share_id": shared_chat.id})
) )
db.commit() db.commit()
# Make sure all the output files are shared. There names are GUIDs
# and they don't show up in search for everyone, so these can still
# only be accessed if you know their GUID and are a valid user.
for fileId in chat.meta.get("outputFileIds", []):
log.debug(f"Setting shared on file {fileId}")
Files.update_file_access_control_by_id(fileId, {"shared": True})
return shared_chat if (shared_result and result) else None return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
@ -951,5 +960,42 @@ class ChatTable:
except Exception: except Exception:
return False return False
def add_output_file_id_to_chat(self, id: str, file_id: str) -> Optional[ChatModel]:
"""Adds a new file ID to the outputFileIds list in the chat's metadata."""
try:
with get_db() as db:
chat = db.query(Chat).filter(Chat.id == id).with_for_update().first()
if chat is None:
return None
output_file_ids = chat.meta.get("outputFileIds", [])
if file_id in output_file_ids:
return ChatModel.model_validate(chat)
output_file_ids.append(file_id)
chat.meta = {**chat.meta, "outputFileIds": output_file_ids}
chat.updated_at = int(time.time())
db.commit()
return ChatModel.model_validate(chat)
except Exception as e:
log.error(f"Error adding output file ID: {e}")
return None
def get_output_file_ids_by_chat_id(self, id: str) -> list[str]:
"""
Gets all file IDs from the outputFileIds list in the chat's metadata.
"""
try:
with get_db() as db:
chat = db.query(Chat).filter(Chat.id == id).first()
if chat is None:
return []
return chat.meta.get("outputFileIds", [])
except Exception as e:
log.error(f"Error getting output file IDs: {e}")
return []
Chats = ChatTable() Chats = ChatTable()

View File

@ -211,6 +211,21 @@ class FilesTable:
except Exception: except Exception:
return None return None
def update_file_access_control_by_id(
self, id: str, access_control: dict
) -> Optional[FileModel]:
with get_db() as db:
try:
file = db.query(File).filter_by(id=id).first()
file.access_control = {
**(file.access_control if file.access_control else {}),
**access_control,
}
db.commit()
return FileModel.model_validate(file)
except Exception as e:
return None
def delete_file_by_id(self, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:

View File

@ -62,19 +62,19 @@ def has_access_to_file(
detail=ERROR_MESSAGES.NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND,
) )
has_access = False if file.access_control.get("shared", False):
knowledge_base_id = file.meta.get("collection_name") if file.meta else None return True
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id: if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type user.id, access_type
) )
for knowledge_base in knowledge_bases: for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id: if knowledge_base.id == knowledge_base_id:
has_access = True return True
break
return has_access return False
############################ ############################

View File

@ -3,6 +3,7 @@ import logging
import sys import sys
import os import os
import base64 import base64
import io
import asyncio import asyncio
from aiocache import cached from aiocache import cached
@ -18,7 +19,7 @@ from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from fastapi import Request, HTTPException from fastapi import Request, HTTPException, UploadFile
from starlette.responses import Response, StreamingResponse from starlette.responses import Response, StreamingResponse
@ -41,6 +42,7 @@ from open_webui.routers.pipelines import (
process_pipeline_inlet_filter, process_pipeline_inlet_filter,
process_pipeline_outlet_filter, process_pipeline_outlet_filter,
) )
from open_webui.routers.files import upload_file
from open_webui.routers.memories import query_memory, QueryMemoryForm from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
@ -2157,7 +2159,9 @@ async def process_chat_response(
) )
retries += 1 retries += 1
log.debug(f"Attempt count: {retries}") log.debug(
f"Attempt count: {retries}, intepreter {request.app.state.config.CODE_INTERPRETER_ENGINE}"
)
output = "" output = ""
try: try:
@ -2206,73 +2210,49 @@ async def process_chat_response(
"stdout": "Code interpreter engine not configured." "stdout": "Code interpreter engine not configured."
} }
log.debug(f"Code interpreter output: {output}")
if isinstance(output, dict): if isinstance(output, dict):
stdout = output.get("stdout", "") for sourceField in ("stdout", "result"):
source = output.get(sourceField, "")
if isinstance(stdout, str): if isinstance(source, str):
stdoutLines = stdout.split("\n") sourceLines = source.split("\n")
for idx, line in enumerate(stdoutLines): for idx, line in enumerate(sourceLines):
if "data:image/png;base64" in line: if "data:image/png;base64" in line:
id = str(uuid4()) # line looks like data:image/png;base64,<base64data>
content_type = (
# ensure the path exists line.split(",")[0]
os.makedirs( .split(";")[0]
os.path.join(CACHE_DIR, "images"), .split(":")[1]
exist_ok=True, )
) file_data = io.BytesIO(
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode( base64.b64decode(
line.split(",")[1] line.split(",")[1]
) )
) )
file_name = f"image-{metadata['chat_id']}-{metadata['message_id']}-{sourceField}-{idx}.png"
stdoutLines[idx] = ( file = UploadFile(
f"![Output Image {idx}](/cache/images/{id}.png)" filename=file_name,
) file=file_data,
headers={
output["stdout"] = "\n".join(stdoutLines) "content-type": content_type
},
result = output.get("result", "") )
file_response = upload_file(
if isinstance(result, str): request, file, user=user
resultLines = result.split("\n") )
for idx, line in enumerate(resultLines): Chats.add_output_file_id_to_chat(
if "data:image/png;base64" in line: metadata["chat_id"],
id = str(uuid4()) file_response.id,
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
) )
resultLines[idx] = ( sourceLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)" f"![Output Image {idx}](/api/v1/files/{file_response.id}/content)"
) )
output[sourceField] = "\n".join(sourceLines)
output["result"] = "\n".join(resultLines)
except Exception as e: except Exception as e:
log.exception(e)
output = str(e) output = str(e)
content_blocks[-1]["output"] = output content_blocks[-1]["output"] = output