Merge branch 'open-webui:main' into main

Authored by Henry on 2025-06-04 15:37:33 +02:00, committed by GitHub.
478 changed files with 64194 additions and 22041 deletions


@@ -73,8 +73,15 @@ def serve(
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main
from open_webui.env import UVICORN_WORKERS # Import the workers setting
uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
uvicorn.run(
"open_webui.main:app",
host=host,
port=port,
forwarded_allow_ips="*",
workers=UVICORN_WORKERS,
)
@app.command()
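For context on the hunk above: with workers greater than 1, uvicorn requires the application to be given as an import string ("open_webui.main:app") rather than an app object, otherwise it warns that an import string is needed. A minimal sketch of the same pattern, with illustrative names:

    import os
    import uvicorn

    workers = int(os.environ.get("UVICORN_WORKERS", "1"))

    # "module:attribute" lets each worker process import the app itself;
    # an app object cannot be handed to multiple worker processes.
    uvicorn.run("myservice.main:app", host="0.0.0.0", port=8080, workers=workers)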

File diff suppressed because it is too large.


@@ -31,6 +31,7 @@ class ERROR_MESSAGES(str, Enum):
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."


@@ -65,10 +65,8 @@ except Exception:
# LOGGING
####################################
log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
GLOBAL_LOG_LEVEL = os.environ.get("GLOBAL_LOG_LEVEL", "").upper()
if GLOBAL_LOG_LEVEL in log_levels:
if GLOBAL_LOG_LEVEL in logging.getLevelNamesMapping():
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL, force=True)
else:
GLOBAL_LOG_LEVEL = "INFO"
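The replacement condition relies on logging.getLevelNamesMapping(), available since Python 3.11, which returns the full name-to-number mapping and makes the hard-coded log_levels list redundant. A quick illustration:

    import logging

    levels = logging.getLevelNamesMapping()
    # e.g. {'CRITICAL': 50, 'FATAL': 50, 'ERROR': 40, 'WARN': 30, 'WARNING': 30,
    #       'INFO': 20, 'DEBUG': 10, 'NOTSET': 0}

    assert "INFO" in levels
    assert "VERBOSE" not in levels  # unknown names fall back to the INFO default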
@@ -78,6 +76,7 @@ log.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
if "cuda_error" in locals():
log.exception(cuda_error)
del cuda_error
log_sources = [
"AUDIO",
@@ -100,19 +99,19 @@ SRC_LOG_LEVELS = {}
for source in log_sources:
log_env_var = source + "_LOG_LEVEL"
SRC_LOG_LEVELS[source] = os.environ.get(log_env_var, "").upper()
if SRC_LOG_LEVELS[source] not in log_levels:
if SRC_LOG_LEVELS[source] not in logging.getLevelNamesMapping():
SRC_LOG_LEVELS[source] = GLOBAL_LOG_LEVEL
log.info(f"{log_env_var}: {SRC_LOG_LEVELS[source]}")
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
####################################
# ENV (dev,test,prod)
@@ -130,7 +129,6 @@ else:
except Exception:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
@@ -161,7 +159,6 @@ try:
except Exception:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
@@ -192,7 +189,6 @@ for version in soup.find_all("h2"):
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
####################################
@@ -209,7 +205,6 @@ ENABLE_FORWARD_USER_INFO_HEADERS = (
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true"
)
####################################
# WEBUI_BUILD_HASH
####################################
@@ -244,7 +239,6 @@ if FROM_INIT_PY:
DATA_DIR = Path(os.getenv("DATA_DIR", OPEN_WEBUI_DIR / "data"))
STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static"))
FONTS_DIR = Path(os.getenv("FONTS_DIR", OPEN_WEBUI_DIR / "static" / "fonts"))
@@ -256,7 +250,6 @@ if FROM_INIT_PY:
os.getenv("FRONTEND_BUILD_DIR", OPEN_WEBUI_DIR / "frontend")
).resolve()
####################################
# Database
####################################
@@ -321,7 +314,6 @@ RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
ENABLE_REALTIME_CHAT_SAVE = (
os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true"
)
@@ -330,7 +322,23 @@ ENABLE_REALTIME_CHAT_SAVE = (
# REDIS
####################################
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")
REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
####################################
# UVICORN WORKERS
####################################
# Number of uvicorn worker processes for handling requests
UVICORN_WORKERS = os.environ.get("UVICORN_WORKERS", "1")
try:
UVICORN_WORKERS = int(UVICORN_WORKERS)
if UVICORN_WORKERS < 1:
UVICORN_WORKERS = 1
except ValueError:
UVICORN_WORKERS = 1
log.info(f"Invalid UVICORN_WORKERS value, defaulting to {UVICORN_WORKERS}")
####################################
# WEBUI_AUTH (Required for security)
@@ -341,11 +349,19 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
)
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get(
"WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None
)
####################################
# WEBUI_SECRET_KEY
####################################
@@ -385,6 +401,11 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
@@ -396,19 +417,88 @@ else:
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
AIOHTTP_CLIENT_SESSION_SSL = (
os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true"
)
if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = None
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
)
if AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST == "":
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = int(
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = int(AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10"
)
if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "":
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None
else:
try:
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
)
except Exception:
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
)
####################################
# SENTENCE TRANSFORMERS
####################################
SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
SENTENCE_TRANSFORMERS_BACKEND = "torch"
SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
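Both *_MODEL_KWARGS settings are parsed as JSON and silently become None when empty or malformed. The kind of value they are meant to accept looks like this (the specific key is only an illustration; valid keys depend on sentence-transformers):

    import json
    import os

    # e.g. SENTENCE_TRANSFORMERS_MODEL_KWARGS='{"torch_dtype": "float16"}'
    raw = os.environ.get("SENTENCE_TRANSFORMERS_MODEL_KWARGS", "")
    try:
        model_kwargs = json.loads(raw) if raw else None
    except json.JSONDecodeError:
        model_kwargs = None  # malformed JSON is ignored rather than raised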
####################################
# OFFLINE_MODE
@@ -418,3 +508,56 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper()
try:
MAX_BODY_LOG_SIZE = int(os.environ.get("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
MAX_BODY_LOG_SIZE = 2048
# Comma separated list for urls to exclude from audit
AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(
","
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
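The excluded-path handling splits the comma-separated value, trims whitespace, and strips leading slashes, so "/chats, /chat" and "chats,chat" are treated the same. For instance:

    raw = "/chats, /chat,/folders"
    paths = [p.strip().lstrip("/") for p in raw.split(",")]
    # paths == ["chats", "chat", "folders"]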
####################################
# OPENTELEMETRY
####################################
ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true"
OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get(
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"
)
OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui")
OTEL_RESOURCE_ATTRIBUTES = os.environ.get(
"OTEL_RESOURCE_ATTRIBUTES", ""
) # e.g. key1=val1,key2=val2
OTEL_TRACES_SAMPLER = os.environ.get(
"OTEL_TRACES_SAMPLER", "parentbased_always_on"
).lower()
####################################
# TOOLS/FUNCTIONS PIP OPTIONS
####################################
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
####################################
# PROGRESSIVE WEB APP OPTIONS
####################################
EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")


@@ -2,6 +2,7 @@ import logging
import sys
import inspect
import json
import asyncio
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
@@ -27,7 +28,10 @@ from open_webui.socket.main import (
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
@@ -52,12 +56,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
@@ -76,11 +75,13 @@ async def get_function_models(request):
if hasattr(function_module, "pipes"):
sub_pipes = []
# Check if pipes is a function or a list
# Handle pipes being a list, sync function, or async function
try:
if callable(function_module.pipes):
sub_pipes = function_module.pipes()
if asyncio.iscoroutinefunction(function_module.pipes):
sub_pipes = await function_module.pipes()
else:
sub_pipes = function_module.pipes()
else:
sub_pipes = function_module.pipes
except Exception as e:
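The new branch awaits pipes when it is a coroutine function and calls it directly when it is a plain callable; a list is used as-is. A minimal sketch of that dispatch pattern (names are illustrative):

    import asyncio

    async def resolve_pipes(pipes):
        # list, sync callable, or async callable -> resolved list
        if callable(pipes):
            if asyncio.iscoroutinefunction(pipes):
                return await pipes()
            return pipes()
        return pipes

    # asyncio.run(resolve_pipes(lambda: [{"id": "a"}]))  ->  [{"id": "a"}]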
@@ -220,6 +221,9 @@ async def generate_function_chat_completion(
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__chat_id__": metadata.get("chat_id", None),
"__session_id__": metadata.get("session_id", None),
"__message_id__": metadata.get("message_id", None),
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
@@ -249,8 +253,13 @@ async def generate_function_chat_completion(
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
if params:
system = params.pop("system", None)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(
system, form_data, metadata, user
)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)


@@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url):
db = connect(db_url, unquote_password=True)
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
@@ -51,7 +51,7 @@ def register_connection(db_url):
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_password=True)
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)
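Adding unquote_user=True matters once the username itself contains percent-encoded characters, not just the password. Roughly, the URL is built and consumed like this (illustrative values; only the standard library is used here):

    from urllib.parse import quote

    user, password = "svc@example.com", "p@ss:word"
    db_url = (
        f"postgresql://{quote(user, safe='')}:{quote(password, safe='')}"
        "@db-host:5432/openwebui"
    )
    # parse(db_url, unquote_user=True, unquote_password=True) then restores
    # both fields to their original, decoded form before connecting.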

File diff suppressed because it is too large.


@@ -0,0 +1,33 @@
"""Add note table
Revision ID: 9f0c9cd09105
Revises: 3781e22d8b01
Create Date: 2025-05-03 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "9f0c9cd09105"
down_revision = "3781e22d8b01"
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"note",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("note")


@@ -129,12 +129,16 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email)
if not user:
return None
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
return user
else:
return None
@@ -155,8 +159,8 @@ class AuthsTable:
except Exception:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}")
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()


@@ -1,3 +1,4 @@
import logging
import json
import time
import uuid
@@ -5,7 +6,7 @@ from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
@@ -16,6 +17,9 @@ from sqlalchemy.sql import exists
# Chat DB Schema
####################
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Chat(Base):
__tablename__ = "chat"
@@ -373,22 +377,47 @@ class ChatTable:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self,
user_id: str,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
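The filter dict accepted by the archived and non-archived list methods supports a title substring under "query" plus an order_by/direction pair; anything else keeps the default updated_at-descending order. A call might look like this (illustrative, assuming the module-level Chats instance):

    chats = Chats.get_archived_chat_list_by_user_id(
        user_id="user-123",
        filter={"query": "roadmap", "order_by": "updated_at", "direction": "desc"},
        skip=0,
        limit=50,
    )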
def get_chat_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
@@ -397,7 +426,23 @@ class ChatTable:
if not include_archived:
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
@@ -432,7 +477,7 @@ class ChatTable:
all_chats = query.all()
# result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
# result has to be destructured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass.
return [
ChatTitleIdResponse.model_validate(
{
@@ -538,7 +583,9 @@ class ChatTable:
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
return self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit
)
search_text_words = search_text.split(" ")
@@ -670,7 +717,7 @@ class ChatTable:
# Perform pagination at the SQL level
all_chats = query.offset(skip).limit(limit).all()
print(len(all_chats))
log.info(f"The number of chats: {len(all_chats)}")
# Validate and return chats
return [ChatModel.model_validate(chat) for chat in all_chats]
@@ -731,7 +778,7 @@ class ChatTable:
query = db.query(Chat).filter_by(user_id=user_id)
tag_id = tag_name.replace(" ", "_").lower()
print(db.bind.dialect.name)
log.info(f"DB dialect name: {db.bind.dialect.name}")
if db.bind.dialect.name == "sqlite":
# SQLite JSON1 querying for tags within the meta JSON field
query = query.filter(
@@ -752,7 +799,7 @@ class ChatTable:
)
all_chats = query.all()
print("all_chats", all_chats)
log.debug(f"all_chats: {all_chats}")
return [ChatModel.model_validate(chat) for chat in all_chats]
def add_chat_tag_by_id_and_user_id_and_tag_name(
@@ -810,7 +857,7 @@ class ChatTable:
count = query.count()
# Debugging output for inspection
print(f"Count of chats for tag '{tag_name}':", count)
log.info(f"Count of chats for tag '{tag_name}': {count}")
return count


@@ -118,7 +118,7 @@ class FeedbackTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error creating a new feedback: {e}")
return None
def get_feedback_by_id(self, id: str) -> Optional[FeedbackModel]:


@@ -119,7 +119,7 @@ class FilesTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error inserting a new file: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:


@@ -9,6 +9,8 @@ from open_webui.models.chats import Chats
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from open_webui.utils.access_control import get_permissions
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -82,7 +84,7 @@ class FolderTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new folder: {e}")
return None
def get_folder_by_id_and_user_id(
@@ -234,15 +236,18 @@ class FolderTable:
log.error(f"update_folder: {e}")
return
def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> bool:
def delete_folder_by_id_and_user_id(
self, id: str, user_id: str, delete_chats=True
) -> bool:
try:
with get_db() as db:
folder = db.query(Folder).filter_by(id=id, user_id=user_id).first()
if not folder:
return False
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
if delete_chats:
# Delete all chats in the folder
Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id)
# Delete all children folders
def delete_children(folder):
@@ -250,9 +255,11 @@ class FolderTable:
folder.id, user_id
)
for folder_child in folder_children:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
if delete_chats:
Chats.delete_chats_by_user_id_and_folder_id(
user_id, folder_child.id
)
delete_children(folder_child)
folder = db.query(Folder).filter_by(id=folder_child.id).first()


@@ -105,9 +105,57 @@ class FunctionsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new function: {e}")
return None
def sync_functions(
self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try:
with get_db() as db:
# Get existing functions
existing_functions = db.query(Function).all()
existing_ids = {func.id for func in existing_functions}
# Prepare a set of new function IDs
new_function_ids = {func.id for func in functions}
# Update or insert functions
for func in functions:
if func.id in existing_ids:
db.query(Function).filter_by(id=func.id).update(
{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_func = Function(
**{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_func)
# Remove functions that are no longer present
for func in existing_functions:
if func.id not in new_function_ids:
db.delete(func)
db.commit()
return [
FunctionModel.model_validate(func)
for func in db.query(Function).all()
]
except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}")
return []
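sync_functions replaces the function table wholesale: rows whose ids appear in the passed list are updated, new ids are inserted, and any existing row missing from the list is deleted. Illustrative usage (func_a and func_b stand in for FunctionModel instances):

    synced = Functions.sync_functions(
        user_id="admin-user-id",
        functions=[func_a, func_b],  # everything not listed here is removed
    )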
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
with get_db() as db:
@@ -170,7 +218,7 @@ class FunctionsTable:
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting function valves by id {id}: {e}")
return None
def update_function_valves_by_id(
@@ -202,7 +250,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@@ -225,7 +275,9 @@ class FunctionsTable:
return user_settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:


@@ -207,5 +207,43 @@ class GroupTable:
except Exception:
return False
def sync_user_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> bool:
with get_db() as db:
try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
# Add user to new groups
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception as e:
log.exception(e)
return False
Groups = GroupTable()
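sync_user_groups_by_group_names makes a user's membership match the supplied names exactly: the user is removed from groups not in the list and added to the ones that are. Together with the new WEBUI_AUTH_TRUSTED_GROUPS_HEADER this lets a reverse proxy drive group membership, roughly:

    # e.g. the trusted header carries "engineering,admins"
    names = [n.strip() for n in "engineering,admins".split(",") if n.strip()]
    Groups.sync_user_groups_by_group_names(user_id="user-123", group_names=names)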


@@ -63,14 +63,15 @@ class MemoriesTable:
else:
return None
def update_memory_by_id(
def update_memory_by_id_and_user_id(
self,
id: str,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
try:
db.query(Memory).filter_by(id=id).update(
db.query(Memory).filter_by(id=id, user_id=user_id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()

backend/open_webui/models/models.py (file mode changed: Normal file → Executable file)

@@ -166,7 +166,7 @@ class ModelsTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Failed to insert a new model: {e}")
return None
def get_all_models(self) -> list[ModelModel]:
@@ -246,8 +246,7 @@ class ModelsTable:
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
log.exception(f"Failed to update the model by id {id}: {e}")
return None
def delete_model_by_id(self, id: str) -> bool:


@@ -0,0 +1,135 @@
import json
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from open_webui.models.users import Users, UserResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Note DB Schema
####################
class Note(Base):
__tablename__ = "note"
id = Column(Text, primary_key=True)
user_id = Column(Text)
title = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
access_control = Column(JSON, nullable=True)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class NoteModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
####################
# Forms
####################
class NoteForm(BaseModel):
title: str
data: Optional[dict] = None
meta: Optional[dict] = None
access_control: Optional[dict] = None
class NoteUserResponse(NoteModel):
user: Optional[UserResponse] = None
class NoteTable:
def insert_new_note(
self,
form_data: NoteForm,
user_id: str,
) -> Optional[NoteModel]:
with get_db() as db:
note = NoteModel(
**{
"id": str(uuid.uuid4()),
"user_id": user_id,
**form_data.model_dump(),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
}
)
new_note = Note(**note.model_dump())
db.add(new_note)
db.commit()
return note
def get_notes(self) -> list[NoteModel]:
with get_db() as db:
notes = db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[NoteModel]:
notes = self.get_notes()
return [
note
for note in notes
if note.user_id == user_id
or has_access(user_id, permission, note.access_control)
]
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None
def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
note.title = form_data.title
note.data = form_data.data
note.meta = form_data.meta
note.access_control = form_data.access_control
note.updated_at = int(time.time_ns())
db.commit()
return NoteModel.model_validate(note) if note else None
def delete_note_by_id(self, id: str):
with get_db() as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True
Notes = NoteTable()
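A note row carries a title, free-form data/meta JSON, an optional access_control dict, and nanosecond timestamps. Basic table usage (illustrative values):

    note = Notes.insert_new_note(
        NoteForm(title="Meeting notes", data={"content": "agenda"}, access_control=None),
        user_id="user-123",
    )
    mine = Notes.get_notes_by_user_id("user-123", permission="write")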


@@ -61,7 +61,7 @@ class TagTable:
else:
return None
except Exception as e:
print(e)
log.exception(f"Error inserting a new tag: {e}")
return None
def get_tag_by_name_and_user_id(


@@ -131,7 +131,7 @@ class ToolsTable:
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
log.exception(f"Error creating a new tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
@@ -175,7 +175,7 @@ class ToolsTable:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
@@ -204,7 +204,9 @@ class ToolsTable:
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None
def update_user_valves_by_id_and_user_id(
@@ -227,7 +229,9 @@ class ToolsTable:
return user_settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:


@@ -10,6 +10,8 @@ from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import or_
####################
# User DB Schema
@@ -67,6 +69,11 @@ class UserModel(BaseModel):
####################
class UserListResponse(BaseModel):
users: list[UserModel]
total: int
class UserResponse(BaseModel):
id: str
name: str
@@ -160,11 +167,63 @@ class UsersTable:
return None
def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[UserModel]:
self,
filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> UserListResponse:
with get_db() as db:
query = db.query(User)
query = db.query(User).order_by(User.created_at.desc())
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
)
)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "email":
if direction == "asc":
query = query.order_by(User.email.asc())
else:
query = query.order_by(User.email.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(User.created_at.asc())
else:
query = query.order_by(User.created_at.desc())
elif order_by == "last_active_at":
if direction == "asc":
query = query.order_by(User.last_active_at.asc())
else:
query = query.order_by(User.last_active_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(User.updated_at.asc())
else:
query = query.order_by(User.updated_at.desc())
elif order_by == "role":
if direction == "asc":
query = query.order_by(User.role.asc())
else:
query = query.order_by(User.role.desc())
else:
query = query.order_by(User.created_at.desc())
if skip:
query = query.offset(skip)
@@ -172,8 +231,10 @@ class UsersTable:
query = query.limit(limit)
users = query.all()
return [UserModel.model_validate(user) for user in users]
return {
"users": [UserModel.model_validate(user) for user in users],
"total": db.query(User).count(),
}
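get_users now returns the page of users together with the total row count, so callers can paginate. For example (illustrative):

    result = Users.get_users(
        filter={"query": "alice", "order_by": "created_at", "direction": "desc"},
        skip=0,
        limit=25,
    )
    users, total = result["users"], result["total"]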
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
@@ -330,5 +391,13 @@ class UsersTable:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users]
def get_super_admin_user(self) -> Optional[UserModel]:
with get_db() as db:
user = db.query(User).filter_by(role="admin").first()
if user:
return UserModel.model_validate(user)
else:
return None
Users = UsersTable()


@@ -0,0 +1,251 @@
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status
log = logging.getLogger(__name__)
class DatalabMarkerLoader:
def __init__(
self,
file_path: str,
api_key: str,
langs: Optional[str] = None,
use_llm: bool = False,
skip_cache: bool = False,
force_ocr: bool = False,
paginate: bool = False,
strip_existing_ocr: bool = False,
disable_image_extraction: bool = False,
output_format: str = None,
):
self.file_path = file_path
self.api_key = api_key
self.langs = langs
self.use_llm = use_llm
self.skip_cache = skip_cache
self.force_ocr = force_ocr
self.paginate = paginate
self.strip_existing_ocr = strip_existing_ocr
self.disable_image_extraction = disable_image_extraction
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower()
mime_map = {
"pdf": "application/pdf",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"odt": "application/vnd.oasis.opendocument.text",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"odp": "application/vnd.oasis.opendocument.presentation",
"html": "text/html",
"epub": "application/epub+zip",
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
"gif": "image/gif",
"tiff": "image/tiff",
}
return mime_map.get(ext, "application/octet-stream")
def check_marker_request_status(self, request_id: str) -> dict:
url = f"https://www.datalab.to/api/v1/marker/{request_id}"
headers = {"X-Api-Key": self.api_key}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
result = response.json()
log.info(f"Marker API status check for request {request_id}: {result}")
return result
except requests.HTTPError as e:
log.error(f"Error checking Marker request status: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to check Marker request: {e}",
)
except ValueError as e:
log.error(f"Invalid JSON checking Marker request: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
)
def load(self) -> List[Document]:
url = "https://www.datalab.to/api/v1/marker"
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
form_data = {
"langs": self.langs,
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"output_format": self.output_format,
}
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
try:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
response = requests.post(
url, data=form_data, files=files, headers=headers
)
response.raise_for_status()
result = response.json()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}",
)
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"):
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
)
check_url = result.get("request_check_url")
request_id = result.get("request_id")
if not check_url:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
)
for _ in range(300): # Up to 10 minutes
time.sleep(2)
try:
poll_response = requests.get(check_url, headers=headers)
poll_response.raise_for_status()
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
)
else:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
)
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
if content_key == "json":
full_text = json.dumps(raw_content, indent=2)
elif content_key in {"markdown", "html"}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}",
)
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Datalab Marker returned empty content",
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(content_key, "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
output_path = os.path.join(marker_output_dir, output_filename)
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(full_text)
log.info(f"Saved Marker output to: {output_path}")
except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}")
metadata = {
"source": filename,
"output_format": poll_result.get("output_format", self.output_format),
"page_count": poll_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
}
images = poll_result.get("images", {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))
for k, v in metadata.items():
if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v)
elif v is None:
metadata[k] = ""
return [Document(page_content=full_text, metadata=metadata)]
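Typical use of the loader, mirroring the constructor arguments above (values are illustrative):

    loader = DatalabMarkerLoader(
        file_path="/path/to/report.pdf",
        api_key="<datalab-api-key>",
        langs="en",
        output_format="markdown",
    )
    docs = loader.load()  # a single Document containing the converted text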


@@ -0,0 +1,58 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalDocumentLoader(BaseLoader):
def __init__(
self,
file_path,
url: str,
api_key: str,
mime_type=None,
**kwargs,
) -> None:
self.url = url
self.api_key = api_key
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
headers = {}
if self.mime_type is not None:
headers["Content-Type"] = self.mime_type
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
url = self.url
if url.endswith("/"):
url = url[:-1]
r = requests.put(f"{url}/process", data=data, headers=headers)
if r.ok:
res = r.json()
if res:
return [
Document(
page_content=res.get("page_content"),
metadata=res.get("metadata"),
)
]
else:
raise Exception("Error loading document: No content returned")
else:
raise Exception(f"Error loading document: {r.status_code} {r.text}")


@@ -0,0 +1,53 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalWebLoader(BaseLoader):
def __init__(
self,
web_paths: Union[str, List[str]],
external_url: str,
external_api_key: str,
continue_on_failure: bool = True,
**kwargs,
) -> None:
self.external_url = external_url
self.external_api_key = external_api_key
self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
self.continue_on_failure = continue_on_failure
def lazy_load(self) -> Iterator[Document]:
batch_size = 20
for i in range(0, len(self.urls), batch_size):
urls = self.urls[i : i + batch_size]
try:
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
"Authorization": f"Bearer {self.external_api_key}",
},
json={
"urls": urls,
},
)
response.raise_for_status()
results = response.json()
for result in results:
yield Document(
page_content=result.get("page_content", ""),
metadata=result.get("metadata", {}),
)
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}")
else:
raise e
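The loader posts URLs to the external service in batches of 20 and yields one Document per result, so it can be consumed lazily (endpoint and key below are illustrative):

    loader = ExternalWebLoader(
        web_paths=["https://example.com/a", "https://example.com/b"],
        external_url="https://loader.internal/extract",
        external_api_key="<api-key>",
    )
    for doc in loader.lazy_load():
        print(doc.metadata.get("source"), len(doc.page_content))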


@@ -4,6 +4,7 @@ import ftfy
import sys
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
BSHTMLLoader,
CSVLoader,
Docx2txtLoader,
@@ -19,6 +20,13 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
@@ -76,15 +84,18 @@ known_source_ext = [
"jsx",
"hs",
"lhs",
"json",
]
class TikaLoader:
def __init__(self, url, file_path, mime_type=None):
def __init__(self, url, file_path, mime_type=None, extract_images=None):
self.url = url
self.file_path = file_path
self.mime_type = mime_type
self.extract_images = extract_images
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
@@ -94,6 +105,9 @@ class TikaLoader:
else:
headers = {}
if self.extract_images == True:
headers["X-Tika-PDFextractInlineImages"] = "true"
endpoint = self.url
if not endpoint.endswith("/"):
endpoint += "/"
@@ -103,7 +117,7 @@ class TikaLoader:
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
@@ -115,6 +129,68 @@ class TikaLoader:
raise Exception(f"Error calling Tika: {r.reason}")
class DoclingLoader:
def __init__(self, url, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.file_path = file_path
self.mime_type = mime_type
self.params = params or {}
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
files = {
"files": (
self.file_path,
f,
self.mime_type or "application/octet-stream",
)
}
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
lang.strip()
for lang in self.params.get("ocr_lang").split(",")
if lang.strip()
]
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
if r.ok:
result = r.json()
document_data = result.get("document", {})
text = document_data.get("md_content", "<No text content found>")
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
if r.text:
try:
error_data = r.json()
if "detail" in error_data:
error_msg += f" - {error_data['detail']}"
except Exception:
error_msg += f" - {r.text}"
raise Exception(f"Error calling Docling: {error_msg}")
class Loader:
def __init__(self, engine: str = "", **kwargs):
self.engine = engine
@@ -133,27 +209,140 @@ class Loader:
for doc in docs
]
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
if (
self.engine == "external"
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
):
loader = ExternalDocumentLoader(
file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
mime_type=file_content_type,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
)
elif (
self.engine == "datalab_marker"
and self.kwargs.get("DATALAB_MARKER_API_KEY")
and file_ext
in [
"pdf",
"xls",
"xlsx",
"ods",
"doc",
"docx",
"odt",
"ppt",
"pptx",
"odp",
"html",
"epub",
"png",
"jpeg",
"jpg",
"webp",
"gif",
"tiff",
]
):
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
strip_existing_ocr=self.kwargs.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
),
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
and (
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
or file_content_type
in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]
)
):
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
)
elif (
self.engine == "mistral_ocr"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
elif (
self.engine == "external"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml":
@@ -182,9 +371,7 @@ class Loader:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TextLoader(file_path, autodetect_encoding=True)


@@ -0,0 +1,633 @@
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
"""
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__(
self,
api_key: str,
file_path: str,
timeout: int = 300, # 5 minutes default
max_retries: int = 3,
enable_debug_logging: bool = False,
):
"""
Initializes the loader with enhanced features.
Args:
api_key: Your Mistral API key.
file_path: The local path to the PDF file to process.
timeout: Request timeout in seconds.
max_retries: Maximum number of retry attempts.
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}")
self.api_key = api_key
self.file_path = file_path
self.timeout = timeout
self.max_retries = max_retries
self.debug = enable_debug_logging
# Pre-compute file info for performance
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0",
}
def _debug_log(self, message: str, *args) -> None:
"""Conditional debug logging for performance."""
if self.debug:
log.debug(message, *args)
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
"""Checks response status and returns JSON content."""
try:
response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx)
# Handle potential empty responses for certain successful requests (e.g., DELETE)
if response.status_code == 204 or not response.content:
return {} # Return empty dict if no content
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
raise
except requests.exceptions.RequestException as req_err:
log.error(f"Request exception occurred: {req_err}")
raise
except ValueError as json_err: # Includes JSONDecodeError
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
raise # Re-raise after logging
async def _handle_response_async(
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
"""Async version of response handling with better error info."""
try:
response.raise_for_status()
# Check content type
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
if response.status == 204:
return {}
text = await response.text()
raise ValueError(
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
return await response.json()
except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response"
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
raise
except aiohttp.ClientError as e:
log.error(f"Client error: {e}")
raise
except Exception as e:
log.error(f"Unexpected error processing response: {e}")
raise
def _retry_request_sync(self, request_func, *args, **kwargs):
"""Synchronous retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
except (requests.exceptions.RequestException, Exception) as e:
if attempt == self.max_retries - 1:
raise
wait_time = (2**attempt) + 0.5
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
"""Async retry logic with exponential backoff."""
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == self.max_retries - 1:
raise
wait_time = (2**attempt) + 0.5
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
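Both retry helpers share the same exponential backoff: the wait between attempts is 2**attempt + 0.5 seconds, i.e. 1.5 s then 2.5 s for the default max_retries of 3, after which the last error is re-raised. As a quick check:

    max_retries = 3
    waits = [(2 ** attempt) + 0.5 for attempt in range(max_retries - 1)]
    # waits == [1.5, 2.5]  (no sleep after the final attempt; the error propagates)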
def _upload_file(self) -> str:
"""Uploads the file to Mistral for OCR processing (sync version)."""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
def upload_request():
with open(self.file_path, "rb") as f:
files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
timeout=self.timeout,
)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
return file_id
except Exception as e:
log.error(f"Failed to upload file: {e}")
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.BASE_API_URL}/files"
async def upload_request():
# Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data")
# Add purpose field
purpose_part = writer.append("ocr")
purpose_part.set_content_disposition("form-data", name="purpose")
# Add file part with streaming
file_part = writer.append_payload(
aiohttp.streams.FilePayload(
self.file_path,
filename=self.file_name,
content_type="application/pdf",
)
)
file_part.set_content_disposition(
"form-data", name="file", filename=self.file_name
)
self._debug_log(
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
async with session.post(
url,
data=writer,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.timeout
)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
log.info("Signed URL received.")
return signed_url
except Exception as e:
log.error(f"Failed to get signed URL: {e}")
raise
async def _get_signed_url_async(
self, session: aiohttp.ClientSession, file_id: str
) -> str:
"""Async signed URL retrieval."""
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
headers = {**self.headers, "Accept": "application/json"}
async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}")
async with session.get(
url,
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
self._debug_log("Signed URL received successfully")
return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.BASE_API_URL}/ocr"
ocr_headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
},
"include_image_base64": False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.timeout
)
return self._handle_response(response)
try:
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
self._debug_log("OCR response: %s", ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
raise
async def _process_ocr_async(
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.BASE_API_URL}/ocr"
headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
},
"include_image_base64": False,
}
async def ocr_request():
log.info("Starting OCR processing via Mistral API")
start_time = time.time()
async with session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as response:
ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s")
return ocr_response
return await self._retry_request_async(ocr_request)
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}"
try:
response = requests.delete(url, headers=self.headers, timeout=30)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
async def _delete_file_async(
self, session: aiohttp.ClientSession, file_id: str
) -> None:
"""Async file deletion with error tolerance."""
try:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
async with session.delete(
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=30
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
await self._retry_request_async(delete_request)
self._debug_log(f"File {file_id} deleted successfully")
except Exception as e:
# Don't fail the entire process if cleanup fails
log.warning(f"Failed to delete file ID {file_id}: {e}")
@asynccontextmanager
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
limit=10, # Total connection limit
limit_per_host=5, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL
use_dns_cache=True,
keepalive_timeout=30,
enable_cleanup_closed=True,
)
async with aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self.timeout),
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
page_content="No text content found", metadata={"error": "no_pages"}
)
]
documents = []
total_pages = len(pages_data)
skipped_pages = 0
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is not None and page_index is not None:
# Clean up content efficiently
cleaned_content = (
page_content.strip()
if isinstance(page_content, str)
else str(page_content)
)
if cleaned_content: # Only add non-empty pages
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
},
)
)
else:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
else:
skipped_pages += 1
self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
)
if skipped_pages > 0:
log.info(
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
return [
Document(
page_content="No valid text content found in document",
metadata={"error": "no_valid_pages", "total_pages": total_pages},
)
]
return documents
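As a rough illustration, the first page of a three-page file produces a Document whose metadata mirrors the keys set above (the file name and size here are made-up example values):
# Hypothetical metadata for page index 0 of a 3-page "report.pdf"
example_metadata = {
    "page": 0,                  # 0-based index from the API
    "page_label": 1,            # 1-based label for convenience
    "total_pages": 3,
    "file_name": "report.pdf",  # assumed example file name
    "file_size": 123456,        # assumed example size in bytes
    "processing_engine": "mistral-ocr",
}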
def load(self) -> List[Document]:
"""
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
Synchronous version for backward compatibility.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
# 1. Upload file
file_id = self._upload_file()
# 2. Get Signed URL
signed_url = self._get_signed_url(file_id)
# 3. Process OCR
ocr_response = self._process_ocr(signed_url)
# 4. Process results
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Delete file (attempt even if prior steps failed after upload)
if file_id:
try:
self._delete_file(file_id)
except Exception as del_e:
# Log deletion error, but don't overwrite original error if one occurred
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
async def load_async(self) -> List[Document]:
"""
Asynchronous OCR workflow execution with optimized performance.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
async with self._get_session() as session:
# 1. Upload file with streaming
file_id = await self._upload_file_async(session)
# 2. Get signed URL
signed_url = await self._get_signed_url_async(session, file_id)
# 3. Process OCR
ocr_response = await self._process_ocr_async(session, signed_url)
# 4. Process results
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
return [
Document(
page_content=f"Error during OCR processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Cleanup - always attempt file deletion
if file_id:
try:
async with self._get_session() as session:
await self._delete_file_async(session, file_id)
except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
) -> List[List[Document]]:
"""
Process multiple files concurrently for maximum performance.
Args:
loaders: List of MistralLoader instances
Returns:
List of document lists, one for each loader
"""
if not loaders:
return []
log.info(f"Starting concurrent processing of {len(loaders)} files")
start_time = time.time()
# Process all files concurrently
tasks = [loader.load_async() for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
metadata={
"error": "batch_processing_failed",
"file_index": i,
},
)
]
)
else:
processed_results.append(result)
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
log.info(
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
)
return processed_results
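A minimal usage sketch for the concurrent path, assuming the loader is constructed with an API key and a local PDF path (the constructor argument names here are illustrative, not taken from this diff):
import asyncio

# Hypothetical construction; check the actual MistralLoader __init__ for exact arguments.
loaders = [
    MistralLoader(api_key="<MISTRAL_API_KEY>", file_path="/tmp/a.pdf"),
    MistralLoader(api_key="<MISTRAL_API_KEY>", file_path="/tmp/b.pdf"),
]
results = asyncio.run(MistralLoader.load_multiple_async(loaders))
for docs in results:
    print(len(docs), "pages extracted")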

View File

@@ -0,0 +1,93 @@
import requests
import logging
from typing import Iterator, List, Literal, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class TavilyLoader(BaseLoader):
"""Extract web page content from URLs using Tavily Extract API.
This is a LangChain document loader that uses Tavily's Extract API to
retrieve content from web pages and return it as Document objects.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.
extract_depth: Depth of extraction, either "basic" or "advanced".
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
def __init__(
self,
urls: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
) -> None:
"""Initialize Tavily Extract client.
Args:
urls: URL or list of URLs to extract content from.
api_key: The Tavily API key.

extract_depth: Depth of extraction, either "basic" or "advanced".
advanced extraction retrieves more data, including tables and
embedded content, with higher success but may increase latency.
basic costs 1 credit per 5 successful URL extractions,
advanced costs 2 credits per 5 successful URL extractions.
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
if not urls:
raise ValueError("At least one URL must be provided.")
self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract"
def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API."""
batch_size = 20
for i in range(0, len(self.urls), batch_size):
batch_urls = self.urls[i : i + batch_size]
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
# Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
# Make the API call
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
# Process successful results
for result in response_data.get("results", []):
url = result.get("url", "")
content = result.get("raw_content", "")
if not content:
log.warning(f"No content extracted from {url}")
continue
# Add URLs as metadata
metadata = {"source": url}
yield Document(
page_content=content,
metadata=metadata,
)
for failed in response_data.get("failed_results", []):
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
log.error(f"Failed to extract content from {url}: {error}")
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}")
else:
raise e
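A short usage sketch (the API key is a placeholder; batching into groups of 20 URLs happens inside lazy_load):
loader = TavilyLoader(
    urls=["https://example.com/a", "https://example.com/b"],
    api_key="<TAVILY_API_KEY>",  # placeholder
    extract_depth="basic",
)
for doc in loader.lazy_load():
    print(doc.metadata["source"], len(doc.page_content))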

View File

@@ -62,12 +62,17 @@ class YoutubeLoader:
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self.language = language
self.proxy_url = proxy_url
# Ensure language is a list
if isinstance(language, str):
self.language = [language]
else:
self.language = language
self.language = list(language)
# Add English as fallback if not already in the list
if "en" not in self.language:
self.language.append("en")
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
@@ -101,17 +106,31 @@ class YoutubeLoader:
log.exception("Loading YouTube transcript failed")
return []
try:
transcript = transcript_list.find_transcript(self.language)
except NoTranscriptFound:
transcript = transcript_list.find_transcript(["en"])
# Try each language in order of priority
for lang in self.language:
try:
transcript = transcript_list.find_transcript([lang])
log.debug(f"Found transcript for language '{lang}'")
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript_text = " ".join(
map(
lambda transcript_piece: transcript_piece.text.strip(" "),
transcript_pieces,
)
)
return [Document(page_content=transcript_text, metadata=self._metadata)]
except NoTranscriptFound:
log.debug(f"No transcript found for language '{lang}'")
continue
except Exception as e:
log.info(f"Error finding transcript for language '{lang}'")
raise e
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
transcript = " ".join(
map(
lambda transcript_piece: transcript_piece["text"].strip(" "),
transcript_pieces,
)
# If we get here, all languages failed
languages_tried = ", ".join(self.language)
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
)
raise NoTranscriptFound(
f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
)
return [Document(page_content=transcript, metadata=self._metadata)]
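A usage sketch of the new fallback behaviour, assuming the loader is constructed with a video id and a preferred language (the full constructor signature is only partially visible in this diff, so treat the call as hypothetical):
loader = YoutubeLoader("dQw4w9WgXcQ", language="de")
docs = loader.load()  # tries "de" first, then falls back to "en"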

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple
class BaseReranker(ABC):
@abstractmethod
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
pass
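Any reranker plugged into the retrieval pipeline only needs to implement predict over (query, passage) pairs. A trivial sketch of a conforming subclass, reusing the typing imports above (illustration only, not a real scoring strategy):
class LengthReranker(BaseReranker):
    """Toy example: scores passages by their length."""

    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        return [float(len(passage)) for _query, passage in sentences]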

View File

@@ -1,13 +1,21 @@
import os
import logging
import torch
import numpy as np
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
class ColBERT:
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None:
print("ColBERT: Loading model", name)
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
DOCKER = kwargs.get("env") == "docker"

View File

@@ -0,0 +1,60 @@
import logging
import requests
from typing import Optional, List, Tuple
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,
url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker",
):
self.api_key = api_key
self.url = url
self.model = model
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
payload = {
"model": self.model,
"query": query,
"documents": docs,
"top_n": len(docs),
}
try:
log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}")
r = requests.post(
f"{self.url}",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
},
json=payload,
)
r.raise_for_status()
data = r.json()
if "results" in data:
sorted_results = sorted(data["results"], key=lambda x: x["index"])
return [result["relevance_score"] for result in sorted_results]
else:
log.error("No results found in external reranking response")
return None
except Exception as e:
log.exception(f"Error in external reranking: {e}")
return None
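Usage sketch (URL and key are placeholders; the endpoint is expected to accept the rerank JSON payload shown above):
reranker = ExternalReranker(
    api_key="<RERANK_API_KEY>",  # placeholder
    url="http://localhost:8080/v1/rerank",
    model="reranker",
)
scores = reranker.predict(
    [
        ("what is open webui?", "Open WebUI is a self-hosted AI interface."),
        ("what is open webui?", "Bananas are rich in potassium."),
    ]
)
# scores is a list of relevance floats aligned with the input order, or None on error.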

View File

@@ -1,27 +1,36 @@
import logging
import os
import uuid
from typing import Optional, Union
import asyncio
import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.retrieval.vector.main import GetResult
from open_webui.env import (
SRC_LOG_LEVELS,
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from open_webui.config import (
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -46,7 +55,7 @@ class VectorSearchRetriever(BaseRetriever):
) -> list[Document]:
result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name,
vectors=[self.embedding_function(query)],
vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
limit=self.top_k,
)
@@ -69,6 +78,7 @@ def query_doc(
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
log.debug(f"query_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[query_embedding],
@@ -80,24 +90,40 @@ def query_doc(
return result
except Exception as e:
print(e)
log.exception(f"Error querying doc {collection_name} with limit {k}: {e}")
raise e
def get_doc(collection_name: str, user: UserModel = None):
try:
log.debug(f"get_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
log.info(f"query_doc:result {result.ids} {result.metadatas}")
return result
except Exception as e:
log.exception(f"Error getting doc {collection_name}: {e}")
raise e
def query_doc_with_hybrid_search(
collection_name: str,
collection_result: GetResult,
query: str,
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
bm25_retriever = BM25Retriever.from_texts(
texts=result.documents[0],
metadatas=result.metadatas[0],
texts=collection_result.documents[0],
metadatas=collection_result.metadatas[0],
)
bm25_retriever.k = k
@@ -107,12 +133,23 @@ def query_doc_with_hybrid_search(
top_k=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
if hybrid_bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_search_retriever], weights=[1.0]
)
elif hybrid_bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.0]
)
else:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever],
weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k,
top_n=k_reranker,
reranking_function=reranking_function,
r_score=r,
)
@@ -122,10 +159,23 @@ def query_doc_with_hybrid_search(
)
result = compression_retriever.invoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
if k < k_reranker:
sorted_items = sorted(
zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True
)
sorted_items = sorted_items[:k]
distances, documents, metadatas = map(list, zip(*sorted_items))
result = {
"distances": [[d.metadata.get("score") for d in result]],
"documents": [[d.page_content for d in result]],
"metadatas": [[d.metadata for d in result]],
"distances": [distances],
"documents": [documents],
"metadatas": [metadatas],
}
log.info(
@@ -134,52 +184,88 @@ def query_doc_with_hybrid_search(
)
return result
except Exception as e:
log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
raise e
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> list[dict]:
def merge_get_results(get_results: list[dict]) -> dict:
# Initialize lists to store combined data
combined_distances = []
combined_documents = []
combined_metadatas = []
combined_ids = []
for data in query_results:
combined_distances.extend(data["distances"][0])
for data in get_results:
combined_documents.extend(data["documents"][0])
combined_metadatas.extend(data["metadatas"][0])
# Create a list of tuples (distance, document, metadata)
combined = list(zip(combined_distances, combined_documents, combined_metadatas))
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
# We don't have anything :-(
if not combined:
sorted_distances = []
sorted_documents = []
sorted_metadatas = []
else:
# Unzip the sorted list
sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_documents = list(sorted_documents)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
combined_ids.extend(data["ids"][0])
# Create the output dictionary
result = {
"distances": [sorted_distances],
"documents": [sorted_documents],
"metadatas": [sorted_metadatas],
"documents": [combined_documents],
"metadatas": [combined_metadatas],
"ids": [combined_ids],
}
return result
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
# Initialize lists to store combined data
combined = dict() # To store documents with unique document hashes
for data in query_results:
distances = data["distances"][0]
documents = data["documents"][0]
metadatas = data["metadatas"][0]
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.sha256(
document.encode()
).hexdigest() # Compute a hash for uniqueness
if doc_hash not in combined.keys():
combined[doc_hash] = (distance, document, metadata)
continue # if doc is new, no further comparison is needed
# if doc is already in, but the new distance is better, update
if distance > combined[doc_hash][0]:
combined[doc_hash] = (distance, document, metadata)
combined = list(combined.values())
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=True)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
"documents": [list(sorted_documents)],
"metadatas": [list(sorted_metadatas)],
}
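To make the dedup-and-sort behaviour concrete, here is a small self-contained sketch that mirrors the same logic on two overlapping result sets (it does not import the function above):
import hashlib

results_a = {"distances": [[0.9, 0.4]], "documents": [["alpha", "beta"]], "metadatas": [[{}, {}]]}
results_b = {"distances": [[0.7]], "documents": [["alpha"]], "metadatas": [[{}]]}

best = {}
for data in (results_a, results_b):
    for dist, doc, meta in zip(data["distances"][0], data["documents"][0], data["metadatas"][0]):
        key = hashlib.sha256(doc.encode()).hexdigest()
        if key not in best or dist > best[key][0]:
            best[key] = (dist, doc, meta)

top = sorted(best.values(), key=lambda x: x[0], reverse=True)[:1]
print(top)  # [(0.9, 'alpha', {})] -- "alpha" kept once, at its best score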
def get_all_items_from_collections(collection_names: list[str]) -> dict:
results = []
for collection_name in collection_names:
if collection_name:
try:
result = get_doc(collection_name=collection_name)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
return merge_get_results(results)
def query_collection(
collection_names: list[str],
queries: list[str],
@@ -187,29 +273,49 @@ def query_collection(
k: int,
) -> dict:
results = []
for query in queries:
query_embedding = embedding_function(query)
for collection_name in collection_names:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
error = False
if VECTOR_DB == "chroma":
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
def process_query_collection(collection_name, query_embedding):
try:
if collection_name:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
return result.model_dump(), None
return None, None
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
return None, e
# Generate all query embeddings (in one call)
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
)
with ThreadPoolExecutor() as executor:
future_results = []
for query_embedding in query_embeddings:
for collection_name in collection_names:
result = executor.submit(
process_query_collection, collection_name, query_embedding
)
future_results.append(result)
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
if error and not results:
log.warning("All collection queries failed. No results returned.")
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
@@ -218,39 +324,74 @@ def query_collection_with_hybrid_search(
embedding_function,
k: int,
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
) -> dict:
results = []
error = False
# Fetch collection data once per collection sequentially
# Avoid fetching the same data multiple times later
collection_results = {}
for collection_name in collection_names:
try:
for query in queries:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
results.append(result)
except Exception as e:
log.exception(
"Error when querying the collection with " f"hybrid_search: {e}"
log.debug(
f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}"
)
error = True
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
except Exception as e:
log.exception(f"Failed to fetch collection {collection_name}: {e}")
collection_results[collection_name] = None
if error:
log.info(
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
)
def process_query(collection_name, query):
try:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
collection_result=collection_results[collection_name],
query=query,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
)
return result, None
except Exception as e:
log.exception(f"Error when querying the collection with hybrid_search: {e}")
return None, e
# Prepare tasks for all collections and queries
# Avoid running any tasks for collections that failed to fetch data (have assigned None)
tasks = [
(cn, q)
for cn in collection_names
if collection_results[cn] is not None
for q in queries
]
with ThreadPoolExecutor() as executor:
future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
if error and not results:
raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
"Hybrid search failed for all collections. Using Non-hybrid search as fallback."
)
if VECTOR_DB == "chroma":
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
return merge_and_sort_query_results(results, k=k)
def get_embedding_function(
@@ -260,58 +401,132 @@ def get_embedding_function(
url,
key,
embedding_batch_size,
azure_api_version=None,
):
if embedding_engine == "":
return lambda query, user=None: embedding_function.encode(query).tolist()
elif embedding_engine in ["ollama", "openai"]:
func = lambda query, user=None: generate_embeddings(
return lambda query, prefix=None, user=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
prefix=prefix,
url=url,
key=key,
user=user,
azure_api_version=azure_api_version,
)
def generate_multiple(query, user, func):
def generate_multiple(query, prefix, user, func):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
embeddings.extend(
func(query[i : i + embedding_batch_size], user=user)
func(
query[i : i + embedding_batch_size],
prefix=prefix,
user=user,
)
)
return embeddings
else:
return func(query, user)
return func(query, prefix, user)
return lambda query, user=None: generate_multiple(query, user, func)
return lambda query, prefix=None, user=None: generate_multiple(
query, prefix, user, func
)
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
def get_sources_from_files(
request,
files,
queries,
embedding_function,
k,
reranking_function,
k_reranker,
r,
hybrid_bm25_weight,
hybrid_search,
full_context=False,
):
log.debug(f"files: {files} {queries} {embedding_function} {reranking_function}")
log.debug(
f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}"
)
extracted_collections = []
relevant_contexts = []
for file in files:
if file.get("context") == "full":
context = None
if file.get("docs"):
# BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
context = {
"documents": [[doc.get("content") for doc in file.get("docs")]],
"metadatas": [[doc.get("metadata") for doc in file.get("docs")]],
}
elif file.get("context") == "full":
# Manual Full Mode Toggle
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]],
}
else:
context = None
elif (
file.get("type") != "web_search"
and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
):
# BYPASS_EMBEDDING_AND_RETRIEVAL
if file.get("type") == "collection":
file_ids = file.get("data", {}).get("file_ids", [])
documents = []
metadatas = []
for file_id in file_ids:
file_object = Files.get_file_by_id(file_id)
if file_object:
documents.append(file_object.data.get("content", ""))
metadatas.append(
{
"file_id": file_id,
"name": file_object.filename,
"source": file_object.filename,
}
)
context = {
"documents": [documents],
"metadatas": [metadatas],
}
elif file.get("id"):
file_object = Files.get_file_by_id(file.get("id"))
if file_object:
context = {
"documents": [[file_object.data.get("content", "")]],
"metadatas": [
[
{
"file_id": file.get("id"),
"name": file_object.filename,
"source": file_object.filename,
}
]
],
}
elif file.get("file").get("data"):
context = {
"documents": [[file.get("file").get("data", {}).get("content")]],
"metadatas": [
[file.get("file").get("data", {}).get("metadata", {})]
],
}
else:
collection_names = []
if file.get("type") == "collection":
if file.get("legacy"):
@@ -331,42 +546,52 @@ def get_sources_from_files(
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
if full_context:
try:
context = get_all_items_from_collections(collection_names)
except Exception as e:
log.exception(e)
else:
try:
context = None
if file.get("type") == "text":
context = file["content"]
else:
if hybrid_search:
try:
context = query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
reranking_function=reranking_function,
r=r,
)
except Exception as e:
log.debug(
"Error when using hybrid search, using"
" non hybrid search as fallback."
)
if (not hybrid_search) or (context is None):
context = query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
k=k,
)
except Exception as e:
log.exception(e)
except Exception as e:
log.exception(e)
extracted_collections.extend(collection_names)
if context:
if "data" in file:
del file["data"]
relevant_contexts.append({**context, "file": file})
sources = []
@@ -435,9 +660,17 @@ def generate_openai_batch_embeddings(
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post(
f"{url}/embeddings",
headers={
@@ -454,7 +687,7 @@ def generate_openai_batch_embeddings(
else {}
),
},
json={"input": texts, "model": model},
json=json_data,
)
r.raise_for_status()
data = r.json()
@@ -463,14 +696,80 @@ def generate_openai_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating openai batch embeddings: {e}")
return None
def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
url,
headers={
"Content-Type": "application/json",
"api-key": key,
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
def generate_ollama_batch_embeddings(
model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}"
)
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
r = requests.post(
f"{url}/api/embed",
headers={
@@ -487,7 +786,7 @@ def generate_ollama_batch_embeddings(
else {}
),
},
json={"input": texts, "model": model},
json=json_data,
)
r.raise_for_status()
data = r.json()
@@ -497,37 +796,55 @@ def generate_ollama_batch_embeddings(
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
def generate_embeddings(
engine: str,
model: str,
text: Union[str, list[str]],
prefix: Union[str, None] = None,
**kwargs,
):
url = kwargs.get("url", "")
key = kwargs.get("key", "")
user = kwargs.get("user")
if engine == "ollama":
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
**{"model": model, "texts": text, "url": url, "key": key, "user": user}
)
text = [f"{prefix}{text_element}" for text_element in text]
else:
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": [text],
"url": url,
"key": key,
"user": user,
}
)
text = f"{prefix}{text}"
if engine == "ollama":
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
else:
embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
embeddings = generate_openai_batch_embeddings(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = generate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings
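A usage sketch of the dispatcher (model name, URL and key are placeholders; the prefix is forwarded as a request field only when RAG_EMBEDDING_PREFIX_FIELD_NAME is configured, otherwise it is prepended to the text):
# Placeholder values for illustration; not real credentials or a guaranteed model name.
vector = generate_embeddings(
    engine="openai",
    model="text-embedding-3-small",
    text="What is retrieval augmented generation?",
    prefix=None,
    url="https://api.openai.com/v1",
    key="<OPENAI_API_KEY>",
)
# A single string in -> a single embedding (list of floats) out;
# a list of strings returns a list of embeddings in the same order.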
@@ -563,13 +880,15 @@ class RerankCompressor(BaseDocumentCompressor):
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
docs_with_scores = list(zip(documents, scores.tolist()))
docs_with_scores = list(
zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
)
if self.r_score:
docs_with_scores = [
(d, s) for d, s in docs_with_scores if s >= self.r_score

View File

@@ -1,22 +0,0 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

backend/open_webui/retrieval/vector/dbs/chroma.py Normal file → Executable file
View File

@@ -1,10 +1,16 @@
import chromadb
import logging
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@@ -16,9 +22,13 @@ from open_webui.config import (
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,
@@ -70,10 +80,16 @@ class ChromaClient:
n_results=limit,
)
# chromadb returns cosine distance in [0, 2] (0 = best, 2 = worst); rescale to a similarity in [0, 1] (1 = best)
# https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0]
distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]]
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"distances": distances,
"documents": result["documents"],
"metadatas": result["metadatas"],
}
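Worked example of the rescaling above: a raw Chroma cosine distance of 0 (identical) becomes 1.0, 1 becomes 0.5, and 2 (opposite) becomes 0.0:
raw = [0.0, 1.0, 2.0]                  # cosine distances as returned by chromadb
similarities = [(2 - d) / 2 for d in raw]
print(similarities)                    # [1.0, 0.5, 0.0]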
@@ -102,8 +118,7 @@ class ChromaClient:
}
)
return None
except Exception as e:
print(e)
except:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
@@ -162,12 +177,19 @@ class ChromaClient:
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
try:
collection = self.client.get_collection(name=collection_name)
if collection:
if ids:
collection.delete(ids=ids)
elif filter:
collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
pass
def reset(self):
# Resets the database. This will delete all collections and item entries.

View File

@@ -0,0 +1,300 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_API_KEY,
ELASTICSEARCH_USERNAME,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_CLOUD_ID,
ELASTICSEARCH_INDEX_PREFIX,
SSL_ASSERT_FINGERPRINT,
)
class ElasticsearchClient(VectorDBBase):
"""
Important:
In order to reduce the number of indexes, and since the embedding vector length is fixed, we avoid creating an index for each file. Instead, the text is stored as a field, and data is separated into different indexes based on the embedding length.
"""
def __init__(self):
self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
self.client = Elasticsearch(
hosts=[ELASTICSEARCH_URL],
ca_certs=ELASTICSEARCH_CA_CERTS,
api_key=ELASTICSEARCH_API_KEY,
cloud_id=ELASTICSEARCH_CLOUD_ID,
basic_auth=(
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
else None
),
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
)
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
if not result:
return None
ids = []
documents = []
metadatas = []
for hit in result:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
"dynamic_templates": [
{
"strings": {
"match_mapping_type": "string",
"mapping": {"type": "keyword"},
}
}
],
"properties": {
"collection": {"type": "keyword"},
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "cosine",
},
"text": {"type": "text"},
"metadata": {"type": "object"},
},
}
}
self.client.indices.create(index=self._get_index_name(dimension), body=body)
# Status: works
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : min(i + batch_size, len(items))]
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
return result.body["count"] > 0
except Exception as e:
return None
def delete_collection(self, collection_name: str):
query = {"query": {"term": {"collection": collection_name}}}
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
# Status: works
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(len(vectors[0])), body=query
)
return self._result_to_search_result(result)
# Status: only partially tested
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}*",
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
self._create_index(dimension=dimension)
# Status: works
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
return self._scan_result_to_get_result(results)
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client, actions)
# Upsert documents using the update API with doc_as_upsert=True.
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(dimension=len(item["vector"])),
"_id": item["id"],
"doc": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
bulk(self.client, actions)
# Delete specific documents from a collection by filtering on both collection and document IDs.
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
}
# logic based on chromaDB
if ids:
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
elif filter:
for field, value in filter.items():
query["query"]["bool"]["filter"].append(
{"term": {f"metadata.{field}": value}}
)
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*")
for index in indices:
self.client.indices.delete(index=index)
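A rough end-to-end sketch against this client (connection settings come from the ELASTICSEARCH_* config values; the item dicts follow the VectorItem shape used above, and the vectors are dummy values):
client = ElasticsearchClient()

items = [
    {"id": "doc-1", "vector": [0.1, 0.2, 0.3], "text": "hello world", "metadata": {"name": "a.txt"}},
    {"id": "doc-2", "vector": [0.2, 0.1, 0.0], "text": "goodbye", "metadata": {"name": "b.txt"}},
]
client.insert("my_collection", items)  # creates the *_d3 index on first use
# Note: freshly inserted documents may need an index refresh before they are searchable.
hits = client.search("my_collection", vectors=[[0.1, 0.2, 0.3]], limit=2)
if hits:
    print(hits.ids[0], hits.distances[0])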

View File

@@ -1,30 +1,42 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
MILVUS_TOKEN,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB)
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else:
self.client = Client(uri=MILVUS_URI, database=MILVUS_DB, token=MILVUS_TOKEN)
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for match in result:
_ids = []
_documents = []
@@ -33,11 +45,9 @@ class MilvusClient:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
@@ -51,24 +61,23 @@ class MilvusClient:
distances = []
documents = []
metadatas = []
for match in result:
_ids = []
_distances = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_distances.append(item.get("distance"))
# normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0
_distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
return SearchResult(
**{
"ids": ids,
@@ -101,11 +110,39 @@ class MilvusClient:
)
index_params = self.client.prepare_index_params()
# Use configurations from config.py
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
index_creation_params = {}
if index_type == "HNSW":
index_creation_params = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index(
field_name="vector",
index_type="HNSW",
metric_type="COSINE",
params={"M": 16, "efConstruction": 100},
index_type=index_type,
metric_type=metric_type,
params=index_creation_params,
)
self.client.create_collection(
@@ -113,6 +150,9 @@ class MilvusClient:
schema=schema,
index_params=index_params,
)
log.info(
f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
)
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
@@ -133,82 +173,113 @@ class MilvusClient:
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
# search_params=search_params # Potentially add later if needed
)
return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
max_limit = 16383 # The maximum number of records per request
all_results = []
if limit is None:
limit = float("inf") # Use infinity as a placeholder for no limit
# Milvus default limit for query if not specified is 16384, but docs mention iteration.
# Let's set a practical high number if "all" is intended, or handle true pagination.
# For now, if limit is None, we'll fetch in batches up to a very large number.
# This part could be refined based on expected use cases for "get all".
# For this function signature, None implies "as many as possible" up to Milvus limits.
limit = (
16384 * 10
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
log.info(
f"Limit not specified for query, fetching up to {limit} results in batches."
)
# Initialize offset and remaining to handle pagination
offset = 0
remaining = limit
try:
log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
)
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
print("remaining", remaining)
                current_fetch = min(
                    max_limit, remaining if isinstance(remaining, int) else max_limit
                )  # Determine how many items to fetch in this iteration
log.debug(
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
)
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=["*"],
output_fields=[
"id",
"data",
"metadata",
], # Explicitly list needed fields. Vector not usually needed in query.
limit=current_fetch,
offset=offset,
)
if not results:
log.debug("No more results from query.")
break
all_results.extend(results)
results_count = len(results)
remaining -= (
results_count # Decrease remaining by the number of items fetched
)
log.debug(f"Fetched {results_count} results in this batch.")
if isinstance(remaining, int):
remaining -= results_count
offset += results_count
# Break the loop if the results returned are less than the requested fetch count
# Break the loop if the results returned are less than the requested fetch count (means end of data)
if results_count < current_fetch:
log.debug(
"Fetched less than requested, assuming end of results for this query."
)
break
print(all_results)
log.info(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results])
except Exception as e:
print(e)
log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
)
return None
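
The batching above is ordinary offset pagination capped at Milvus's per-request maximum. A generic sketch of the same loop, with `fetch_page` as a hypothetical stand-in for the `self.client.query(...)` call (illustrative only):

# Illustrative sketch only, not part of this patch.
def paginate(fetch_page, total_limit: int, page_size: int = 16383) -> list:
    # fetch_page(offset, limit) stands in for the Milvus query call above.
    results, offset = [], 0
    while offset < total_limit:
        page = fetch_page(offset=offset, limit=min(page_size, total_limit - offset))
        if not page:
            break
        results.extend(page)
        offset += len(page)
        if len(page) < page_size:  # short page means we reached the end of the data
            break
    return results
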
def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection. This can be very resource-intensive for large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        # Using query with a trivial filter to get all items.
        # This will use the paginated query logic.
        return self.query(collection_name=collection_name, filter={}, limit=None)
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
@@ -216,10 +287,23 @@ class MilvusClient:
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
log.info(
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
@@ -239,10 +323,23 @@ class MilvusClient:
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
log.info(
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
@@ -262,30 +359,55 @@ class MilvusClient:
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
# Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
log.warning(
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None
if ids:
log.info(
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
elif filter:
# Convert the filter dictionary to a string using JSON_CONTAINS.
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
)
else:
log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
)
return None
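
The same `metadata["..."] == <json value>` expression is built for both query and delete. A small helper showing the shape of the generated Milvus filter string (illustrative only):

import json

# Illustrative sketch only, not part of this patch.
def build_metadata_filter(filter: dict) -> str:
    # Produces the boolean expression used above for query and delete.
    return " && ".join(
        f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()
    )

# build_metadata_filter({"file_id": "abc", "page": 3})
# -> 'metadata["file_id"] == "abc" && metadata["page"] == 3'
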
def reset(self):
# Resets the database. This will delete all collections and item entries.
# Resets the database. This will delete all collections and item entries that match the prefix.
log.warning(
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
collection_names = self.client.list_collections()
for collection_name in collection_names:
if collection_name.startswith(self.collection_prefix):
self.client.drop_collection(collection_name=collection_name)
deleted_collections = []
for collection_name_full in collection_names:
if collection_name_full.startswith(self.collection_prefix):
try:
self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}")
except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}")
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")

View File

@@ -1,7 +1,13 @@
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
@@ -11,7 +17,7 @@ from open_webui.config import (
)
class OpenSearchClient:
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(
@@ -21,7 +27,13 @@ class OpenSearchClient:
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
)
def _get_index_name(self, collection_name: str) -> str:
return f"{self.index_prefix}_{collection_name}"
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = []
documents = []
metadatas = []
@@ -31,9 +43,12 @@ class OpenSearchClient:
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _result_to_search_result(self, result) -> SearchResult:
if not result["hits"]["hits"]:
return None
ids = []
distances = []
documents = []
@@ -46,72 +61,88 @@ class OpenSearchClient:
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
def _create_index(self, index_name: str, dimension: int):
def _create_index(self, collection_name: str, dimension: int):
body = {
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": true,
"type": "knn_vector",
"dimension": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
"engine": "faiss",
"ef_construction": 128,
"m": 16,
"parameters": {
"ef_construction": 128,
"m": 16,
},
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
}
}
},
}
self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body)
self.client.indices.create(
index=self._get_index_name(collection_name), body=body
)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
def has_collection(self, index_name: str) -> bool:
def has_collection(self, collection_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
return self.client.indices.exists(index=self._get_index_name(collection_name))
def delete_colleciton(self, index_name: str):
def delete_collection(self, collection_name: str):
# delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
self.client.indices.delete(index=self._get_index_name(collection_name))
def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        try:
            if not self.has_collection(collection_name):
                return None

            query = {
                "size": limit,
                "_source": ["text", "metadata"],
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
                            "params": {
                                "field": "vector",
                                "query_value": vectors[0],
                            },  # Assuming single query vector
                        },
                    }
                },
            }

            result = self.client.search(
                index=self._get_index_name(collection_name), body=query
            )

            return self._result_to_search_result(result)
except Exception as e:
return None
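
The Painless expression above shifts the raw cosine similarity so OpenSearch returns a non-negative, bounded score; in plain Python the mapping is simply:

# Illustrative sketch only, not part of this patch.
def normalize_cosine(similarity: float) -> float:
    # Maps cosine similarity in [-1, 1] onto a [0, 1] relevance score,
    # matching "(cosineSimilarity(...) + 1.0) / 2.0" in the script_score query.
    return (similarity + 1.0) / 2.0

assert normalize_cosine(1.0) == 1.0   # identical vectors
assert normalize_cosine(-1.0) == 0.0  # opposite vectors
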
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
@@ -125,13 +156,15 @@ class OpenSearchClient:
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}_{collection_name}",
index=self._get_index_name(collection_name),
body=query_body,
size=size,
)
@@ -141,64 +174,88 @@ class OpenSearchClient:
except Exception as e:
return None
def get_or_create_index(self, index_name: str, dimension: int):
if not self.has_index(index_name):
self._create_index(index_name, dimension)
def _create_index_if_not_exists(self, collection_name: str, dimension: int):
if not self.has_collection(collection_name):
self._create_index(collection_name, dimension)
def get(self, index_name: str) -> Optional[GetResult]:
def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search(
index=f"{self.index_prefix}_{index_name}", body=query
index=self._get_index_name(collection_name), body=query
)
return self._result_to_get_result(result)
def insert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
def insert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
"_op_type": "index",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
self.client.bulk(actions)
bulk(self.client, actions)
def upsert(self, index_name: str, items: list[VectorItem]):
if not self.has_index(index_name):
self._create_index(index_name, dimension=len(items[0]["vector"]))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
{
"index": {
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
"_op_type": "update",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"doc": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"doc_as_upsert": True,
}
for item in batch
]
self.client.bulk(actions)
bulk(self.client, actions)
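
Both insert and upsert now go through `opensearchpy.helpers.bulk`, which takes the client plus an iterable of action dicts carrying `_op_type`, `_index` and `_id`. A minimal sketch of an upsert-style action (the index name and document are made up; `client` is assumed to be an already-configured OpenSearch instance):

from opensearchpy.helpers import bulk

# Illustrative sketch only, not part of this patch.
actions = [
    {
        "_op_type": "update",
        "_index": "open_webui_example",  # hypothetical index name
        "_id": "doc-1",
        "doc": {"text": "hello", "metadata": {"source": "demo"}},
        "doc_as_upsert": True,  # insert the doc if the id does not exist yet
    }
]
# bulk(client, actions)
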
def delete(self, index_name: str, ids: list[str]):
actions = [
{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
for id in ids
]
self.client.bulk(body=actions)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
if ids:
actions = [
{
"_op_type": "delete",
"_index": self._get_index_name(collection_name),
"_id": id,
}
for id in ids
]
bulk(self.client, actions)
elif filter:
query_body = {
"query": {"bool": {"filter": []}},
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"match": {"metadata." + str(field): value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")

View File

@@ -1,4 +1,5 @@
from typing import Optional, List, Dict, Any
import logging
from sqlalchemy import (
cast,
column,
@@ -21,12 +22,22 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
@@ -38,7 +49,7 @@ class DocumentChunk(Base):
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
@@ -82,10 +93,10 @@ class PgvectorClient:
)
)
self.session.commit()
print("Initialization complete.")
log.info("Initialization complete.")
except Exception as e:
self.session.rollback()
print(f"Error during initialization: {e}")
log.exception(f"Error during initialization: {e}")
raise
def check_vector_length(self) -> None:
@@ -130,9 +141,8 @@ class PgvectorClient:
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
raise Exception(
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
)
# Truncate the vector to VECTOR_LENGTH
vector = vector[:VECTOR_LENGTH]
return vector
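
With this change, over-length vectors are truncated instead of rejected, making the adjustment symmetric with the zero-padding already done for short vectors. A standalone sketch of the behaviour (illustrative only):

# Illustrative sketch only, not part of this patch.
def adjust_vector_length(vector: list[float], target_length: int) -> list[float]:
    # Pad short vectors with zeros and truncate long ones, so every stored
    # vector matches the fixed pgvector column dimension.
    if len(vector) < target_length:
        return vector + [0.0] * (target_length - len(vector))
    return vector[:target_length]

assert adjust_vector_length([1.0, 2.0], 4) == [1.0, 2.0, 0.0, 0.0]
assert adjust_vector_length([1.0, 2.0, 3.0], 2) == [1.0, 2.0]
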
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -150,12 +160,12 @@ class PgvectorClient:
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during insert: {e}")
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -184,10 +194,12 @@ class PgvectorClient:
)
self.session.add(new_chunk)
self.session.commit()
print(f"Upserted {len(items)} items into collection '{collection_name}'.")
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
print(f"Error during upsert: {e}")
log.exception(f"Error during upsert: {e}")
raise
def search(
@@ -270,7 +282,9 @@ class PgvectorClient:
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
distances[qid].append(row.distance)
# normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
# https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
distances[qid].append((2.0 - row.distance) / 2.0)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)
@@ -278,7 +292,7 @@ class PgvectorClient:
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
except Exception as e:
print(f"Error during search: {e}")
log.exception(f"Error during search: {e}")
return None
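
pgvector's cosine distance operator returns values in [0, 2] (0 = identical, 2 = opposite), so the conversion above flips it into a descending [0, 1] score. In isolation:

# Illustrative sketch only, not part of this patch.
def cosine_distance_to_score(distance: float) -> float:
    # pgvector cosine distance: 0 (identical) .. 2 (opposite).
    return (2.0 - distance) / 2.0

assert cosine_distance_to_score(0.0) == 1.0
assert cosine_distance_to_score(2.0) == 0.0
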
def query(
@@ -310,7 +324,7 @@ class PgvectorClient:
metadatas=metadatas,
)
except Exception as e:
print(f"Error during query: {e}")
log.exception(f"Error during query: {e}")
return None
def get(
@@ -334,7 +348,7 @@ class PgvectorClient:
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
print(f"Error during get: {e}")
log.exception(f"Error during get: {e}")
return None
def delete(
@@ -356,22 +370,22 @@ class PgvectorClient:
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
print(f"Deleted {deleted} items from collection '{collection_name}'.")
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
print(f"Error during delete: {e}")
log.exception(f"Error during delete: {e}")
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
except Exception as e:
self.session.rollback()
print(f"Error during reset: {e}")
log.exception(f"Error during reset: {e}")
raise
def close(self) -> None:
@@ -387,9 +401,9 @@ class PgvectorClient:
)
return exists
except Exception as e:
print(f"Error checking collection existence: {e}")
log.exception(f"Error checking collection existence: {e}")
return False
def delete_collection(self, collection_name: str) -> None:
self.delete(collection_name)
print(f"Collection '{collection_name}' deleted.")
log.info(f"Collection '{collection_name}' deleted.")

View File

@@ -0,0 +1,508 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_ENVIRONMENT,
PINECONE_INDEX_NAME,
PINECONE_DIMENSION,
PINECONE_METRIC,
PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
# Validate required configuration
self._validate_config()
# Store configuration values
self.api_key = PINECONE_API_KEY
self.environment = PINECONE_ENVIRONMENT
self.index_name = PINECONE_INDEX_NAME
self.dimension = PINECONE_DIMENSION
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance
self.client = Pinecone(api_key=self.api_key)
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
# Create index if it doesn't exist
self._initialize_index()
def _validate_config(self) -> None:
"""Validate that all required configuration variables are set."""
missing_vars = []
if not PINECONE_API_KEY:
missing_vars.append("PINECONE_API_KEY")
if not PINECONE_ENVIRONMENT:
missing_vars.append("PINECONE_ENVIRONMENT")
if not PINECONE_INDEX_NAME:
missing_vars.append("PINECONE_INDEX_NAME")
if not PINECONE_DIMENSION:
missing_vars.append("PINECONE_DIMENSION")
if not PINECONE_CLOUD:
missing_vars.append("PINECONE_CLOUD")
if missing_vars:
raise ValueError(
f"Required configuration missing: {', '.join(missing_vars)}"
)
def _initialize_index(self) -> None:
"""Initialize the Pinecone index."""
try:
# Check if index exists
if self.index_name not in self.client.list_indexes().names():
log.info(f"Creating Pinecone index '{self.index_name}'...")
self.client.create_index(
name=self.index_name,
dimension=self.dimension,
metric=self.metric,
spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
)
log.info(f"Successfully created Pinecone index '{self.index_name}'")
else:
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(self.index_name)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
"""Convert VectorItem objects to Pinecone point format."""
points = []
for item in items:
# Start with any existing metadata or an empty dict
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
# Add text to metadata if available
if "text" in item:
metadata["text"] = item["text"]
# Always add collection_name to metadata for filtering
metadata["collection_name"] = collection_name_with_prefix
point = {
"id": item["id"],
"values": item["vector"],
"metadata": metadata,
}
points.append(point)
return points
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
"""Get the collection name with prefix."""
return f"{self.collection_prefix}_{collection_name}"
def _normalize_distance(self, score: float) -> float:
"""Normalize distance score based on the metric used."""
if self.metric.lower() == "cosine":
# Cosine similarity ranges from -1 to 1, normalize to 0 to 1
return (score + 1.0) / 2.0
elif self.metric.lower() in ["euclidean", "dotproduct"]:
# These are already suitable for ranking (smaller is better for Euclidean)
return score
else:
# For other metrics, use as is
return score
def _result_to_get_result(self, matches: list) -> GetResult:
"""Convert Pinecone matches to GetResult format."""
ids = []
documents = []
metadatas = []
for match in matches:
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def has_collection(self, collection_name: str) -> bool:
"""Check if a collection exists by searching for at least one item."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Search for at least 1 item with this collection name in metadata
response = self.index.query(
vector=[0.0] * self.dimension, # dummy vector
top_k=1,
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
matches = getattr(response, "matches", []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
)
return False
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection by removing all vectors with the collection name in metadata."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
self.index.delete(filter={"collection_name": collection_name_with_prefix})
log.info(
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
)
except Exception as e:
log.warning(
f"Failed to delete collection '{collection_name_with_prefix}': {e}"
)
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert vectors into a collection."""
if not items:
log.warning("No items to insert")
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch inserts for performance
executor = self._executor
futures = []
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
futures.append(executor.submit(self.index.upsert, vectors=batch))
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log.error(f"Error inserting batch: {e}")
raise
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
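
The insert and upsert paths share the same fan-out pattern: split the points into fixed-size batches, submit each batch to the thread pool, and re-raise the first failure. A generic sketch with `upsert_batch` as a hypothetical stand-in for `self.index.upsert` (illustrative only):

import concurrent.futures

# Illustrative sketch only, not part of this patch.
def run_batches_in_parallel(upsert_batch, items: list, batch_size: int = 100, max_workers: int = 5):
    # upsert_batch(batch) stands in for index.upsert(vectors=batch).
    batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(upsert_batch, batch) for batch in batches]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raises the worker's exception, if any
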
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Upsert (insert or update) vectors into a collection."""
if not items:
log.warning("No items to upsert")
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch upserts for performance
executor = self._executor
futures = []
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
futures.append(executor.submit(self.index.upsert, vectors=batch))
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log.error(f"Error upserting batch: {e}")
raise
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of insert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to insert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of upsert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to upsert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
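
The async variants keep the SDK call blocking but push each batch onto the event loop's executor and gather the results, surfacing the first exception. A reduced sketch of that pattern (`blocking_call` is a hypothetical stand-in for `index.upsert`):

import asyncio
import functools

# Illustrative sketch only, not part of this patch.
async def run_batches_async(blocking_call, batches: list):
    loop = asyncio.get_event_loop()
    tasks = [
        loop.run_in_executor(None, functools.partial(blocking_call, batch))
        for batch in batches
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, Exception):
            raise result
    return results

# asyncio.run(run_batches_async(print, [[1, 2], [3, 4]]))
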
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
if not vectors or not vectors[0]:
log.warning("No vectors provided for search")
return None
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Search using the first vector (assuming this is the intended behavior)
query_vector = vectors[0]
# Perform the search
query_response = self.index.query(
vector=query_vector,
top_k=limit,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
documents=[[]],
metadatas=[[]],
distances=[[]],
)
# Convert to GetResult format
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=distances,
)
except Exception as e:
log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors by metadata filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Create a zero vector for the dimension as Pinecone requires a vector
zero_vector = [0.0] * self.dimension
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Perform metadata-only query
query_response = self.index.query(
vector=zero_vector,
filter=pinecone_filter,
top_k=limit,
include_metadata=True,
)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""Get all vectors in a collection."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Use a zero vector for fetching all entries
zero_vector = [0.0] * self.dimension
# Add filter to only get vectors for this collection
query_response = self.index.query(
vector=zero_vector,
top_k=NO_LIMIT,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by IDs or filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
if ids:
# Delete by IDs (in batches for large deletions)
for i in range(0, len(ids), BATCH_SIZE):
batch_ids = ids[i : i + BATCH_SIZE]
# Note: When deleting by ID, we can't filter by collection_name
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
)
elif filter:
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Delete by metadata filter
self.index.delete(filter=pinecone_filter)
log.info(
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
)
else:
log.warning("No ids or filter provided for delete operation")
except Exception as e:
log.error(f"Error deleting from collection '{collection_name}': {e}")
raise
def reset(self) -> None:
"""Reset the database by deleting all collections."""
try:
self.index.delete(delete_all=True)
log.info("All vectors successfully deleted from the index.")
except Exception as e:
log.error(f"Failed to reset Pinecone index: {e}")
raise
def close(self):
"""Shut down resources."""
try:
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to clean up Pinecone resources: {e}")
self._executor.shutdown(wait=True)
def __enter__(self):
"""Enter context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit context manager, ensuring resources are cleaned up."""
self.close()

View File

@@ -1,25 +1,60 @@
from typing import Optional
import logging
from urllib.parse import urlparse
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
def _result_to_get_result(self, points) -> GetResult:
ids = []
@@ -45,11 +80,13 @@ class QdrantClient:
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
log.info(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
@@ -94,7 +131,8 @@ class QdrantClient:
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@@ -120,7 +158,7 @@ class QdrantClient:
)
return self._result_to_get_result(points.points)
except Exception as e:
print(e)
log.exception(f"Error querying a collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:

View File

@@ -0,0 +1,712 @@
import logging
from typing import Optional, Tuple
from urllib.parse import urlparse
import grpc
from open_webui.config import (
QDRANT_API_KEY,
QDRANT_GRPC_PORT,
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
Maps the traditional collection name to multi-tenant collection and tenant ID.
Returns:
tuple: (collection_name, tenant_id)
"""
# Check for user memory collections
tenant_id = collection_name
if collection_name.startswith("user-memory-"):
return self.MEMORY_COLLECTION, tenant_id
# Check for file collections
elif collection_name.startswith("file-"):
return self.FILE_COLLECTION, tenant_id
# Check for web search collections
elif collection_name.startswith("web-search-"):
return self.WEB_SEARCH_COLLECTION, tenant_id
# Handle hash-based collections (YouTube and web URLs)
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, tenant_id
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
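
In other words, every logical collection name is routed to one of five shared physical collections, and the original name is kept as the tenant id. A standalone restatement for illustration only (prefix assumed to be "open-webui", as above):

# Illustrative sketch only, not part of this patch.
def map_collection(name: str, prefix: str = "open-webui") -> tuple[str, str]:
    if name.startswith("user-memory-"):
        return f"{prefix}_memories", name
    if name.startswith("file-"):
        return f"{prefix}_files", name
    if name.startswith("web-search-"):
        return f"{prefix}_web-search", name
    if len(name) == 63 and all(c in "0123456789abcdef" for c in name):
        return f"{prefix}_hash-based", name
    return f"{prefix}_knowledge", name

# map_collection("file-abc")       -> ("open-webui_files", "file-abc")
# map_collection("user-memory-42") -> ("open-webui_memories", "user-memory-42")
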
def _extract_error_message(self, exception):
"""
Extract error message from either HTTP or gRPC exceptions
Returns:
tuple: (status_code, error_message)
"""
# Check if it's an HTTP exception
if isinstance(exception, UnexpectedResponse):
try:
error_data = exception.structured()
error_msg = error_data.get("status", {}).get("error", "")
return exception.status_code, error_msg
except Exception as inner_e:
log.error(f"Failed to parse HTTP error: {inner_e}")
return exception.status_code, str(exception)
# Check if it's a gRPC exception
elif isinstance(exception, grpc.RpcError):
# Extract status code from gRPC error
status_code = None
if hasattr(exception, "code") and callable(exception.code):
status_code = exception.code().value[0]
# Extract error message
error_msg = str(exception)
if "details =" in error_msg:
# Parse the details line which contains the actual error message
try:
details_line = [
line.strip()
for line in error_msg.split("\n")
if "details =" in line
][0]
error_msg = details_line.split("details =")[1].strip(' "')
except (IndexError, AttributeError):
# Fall back to full message if parsing fails
pass
return status_code, error_msg
# For any other type of exception
return None, str(exception)
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
def _create_multi_tenant_collection_if_not_exists(
self, mt_collection_name: str, dimension: int = 384
):
"""
Creates a collection with multi-tenancy configuration if it doesn't exist.
Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
"""
try:
# Try to create the collection directly - will fail if it already exists
self.client.create_collection(
collection_name=mt_collection_name,
vectors_config=models.VectorParams(
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
hnsw_config=models.HnswConfigDiff(
payload_m=16, # Enable per-tenant indexing
m=0,
on_disk=self.QDRANT_ON_DISK,
),
)
# Create tenant ID payload index
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name="tenant_id",
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=True,
on_disk=self.QDRANT_ON_DISK,
),
wait=True,
)
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
)
except (UnexpectedResponse, grpc.RpcError) as e:
# Check for the specific error indicating collection already exists
status_code, error_msg = self._extract_error_message(e)
# HTTP status code 409 or gRPC ALREADY_EXISTS
if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
isinstance(e, grpc.RpcError)
and e.code() == grpc.StatusCode.ALREADY_EXISTS
):
if "already exists" in error_msg:
log.debug(f"Collection {mt_collection_name} already exists")
return
# If it's not an already exists error, re-raise
raise e
except Exception as e:
raise e
def _create_points(self, items: list[VectorItem], tenant_id: str):
"""
Create point structs from vector items with tenant ID.
"""
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"],
"tenant_id": tenant_id,
},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
"""
Check if a logical collection exists by checking for any points with the tenant ID.
"""
if not self.client:
return False
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
try:
# Try directly querying - most of the time collection should exist
response = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=1,
)
# Collection exists with this tenant ID if there are points
return len(response.points) > 0
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist")
return False
else:
# For other API errors, log and return False
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
return False
except Exception as e:
# For any other errors, log and return False
log.debug(f"Error checking collection {mt_collection}: {e}")
return False
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
"""
Delete vectors by ID or filter from a collection with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
must_conditions = [tenant_filter]
should_conditions = []
if ids:
for id_value in ids:
should_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
)
elif filter:
for key, value in filter.items():
must_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
)
try:
# Try to delete directly - most of the time collection should exist
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions)
),
)
return update_result
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, nothing to delete"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get the vector dimension from the query vector
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
try:
# Try the search operation directly - most of the time collection should exist
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Ensure vector dimensions match the collection
collection_dim = self.client.get_collection(
mt_collection
).config.params.vectors.size
if collection_dim != dimension:
if collection_dim < dimension:
vectors = [vector[:collection_dim] for vector in vectors]
else:
vectors = [
vector + [0] * (collection_dim - dimension)
for vector in vectors
]
# Search with tenant filter
prefetch_query = models.Prefetch(
filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
query_response = self.client.query_points(
collection_name=mt_collection,
query=vectors[0],
prefetch=prefetch_query,
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[
[(point.score + 1.0) / 2.0 for point in query_response.points]
],
)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, search returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during search: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error searching collection '{collection_name}': {e}")
return None
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
"""
Query points with filters and tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Set default limit if not provided
if limit is None:
limit = NO_LIMIT
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
# Create metadata filters
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
# Combine tenant filter with metadata filters
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
try:
# Try the query directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=combined_filter,
limit=limit,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(
f"Collection {mt_collection} doesn't exist, query returns None"
)
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during query: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and re-raise
log.exception(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""
Get all items in a collection with tenant isolation.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Create tenant filter
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
try:
# Try to get points directly - most of the time collection should exist
points = self.client.query_points(
collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
return self._result_to_get_result(points.points)
except (UnexpectedResponse, grpc.RpcError) as e:
if self._is_collection_not_found_error(e):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None
else:
# For other API errors, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unexpected Qdrant error during get: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, log and return None
log.exception(f"Error getting collection '{collection_name}': {e}")
return None
def _handle_operation_with_error_retry(
self, operation_name, mt_collection, points, dimension
):
"""
Private helper to handle common error cases for insert and upsert operations.
Args:
operation_name: 'insert' or 'upsert'
mt_collection: The multi-tenant collection name
points: The vector points to insert/upsert
dimension: The dimension of the vectors
Returns:
The operation result (for upsert) or None (for insert)
"""
try:
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
except (UnexpectedResponse, grpc.RpcError) as e:
# Handle collection not found
if self._is_collection_not_found_error(e):
log.info(
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
)
# Create collection with correct dimensions from our vectors
self._create_multi_tenant_collection_if_not_exists(
mt_collection_name=mt_collection, dimension=dimension
)
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
# Handle dimension mismatch
elif self._is_dimension_mismatch_error(e):
# For dimension errors, the collection must exist, so get its configuration
mt_collection_info = self.client.get_collection(mt_collection)
existing_size = mt_collection_info.config.params.vectors.size
log.info(
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
)
if existing_size < dimension:
# Truncate vectors to fit
log.info(
f"Truncating vectors from {dimension} to {existing_size} dimensions"
)
points = [
PointStruct(
id=point.id,
vector=point.vector[:existing_size],
payload=point.payload,
)
for point in points
]
elif existing_size > dimension:
# Pad vectors with zeros
log.info(
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
)
points = [
PointStruct(
id=point.id,
vector=point.vector
+ [0] * (existing_size - len(point.vector)),
payload=point.payload,
)
for point in points
]
# Try operation again with adjusted dimensions
if operation_name == "insert":
self.client.upload_points(mt_collection, points)
return None
else: # upsert
return self.client.upsert(mt_collection, points)
else:
# Not a known error we can handle, log and re-raise
_, error_msg = self._extract_error_message(e)
log.warning(f"Unhandled Qdrant error: {error_msg}")
raise
except Exception as e:
# For non-Qdrant exceptions, re-raise
raise
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
def upsert(self, collection_name: str, items: list[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"upsert", mt_collection, points, dimension
)
def reset(self):
"""
Reset the database by deleting all collections.
"""
if not self.client:
return None
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str):
"""
Delete a collection.
"""
if not self.client:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
tenant_filter = models.FieldCondition(
key="tenant_id", match=models.MatchValue(value=tenant_id)
)
field_conditions = [tenant_filter]
update_result = self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
)
if self.client.get_collection(mt_collection).points_count == 0:
self.client.delete_collection(mt_collection)
return update_result

View File

@@ -0,0 +1,55 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
class Vector:
@staticmethod
def get_vector(vector_type: str) -> VectorDBBase:
"""
get vector db instance by vector type
"""
match vector_type:
case VectorType.MILVUS:
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
return MilvusClient()
case VectorType.QDRANT:
if ENABLE_QDRANT_MULTITENANCY_MODE:
from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
QdrantClient,
)
return QdrantClient()
else:
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
return QdrantClient()
case VectorType.PINECONE:
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
return PineconeClient()
case VectorType.OPENSEARCH:
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
return OpenSearchClient()
case VectorType.PGVECTOR:
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
return PgvectorClient()
case VectorType.ELASTICSEARCH:
from open_webui.retrieval.vector.dbs.elasticsearch import (
ElasticsearchClient,
)
return ElasticsearchClient()
case VectorType.CHROMA:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
return ChromaClient()
case _:
raise ValueError(f"Unsupported vector type: {vector_type}")
VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)
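
Callers only interact with the shared singleton; the backend is selected once from VECTOR_DB at import time. A hypothetical usage sketch (the module path, collection name, and 384-dimension query vector are assumptions for illustration):

# Illustrative sketch only, not part of this patch.
# from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT  # assumed module path

def example_usage(client, collection: str = "file-example"):
    # `client` is any VectorDBBase implementation returned by Vector.get_vector().
    if client.has_collection(collection):
        # 384-dim dummy query vector, matching e.g. all-MiniLM-L6-v2 embeddings.
        return client.search(collection, vectors=[[0.1] * 384], limit=5)
    return None
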

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel
from typing import Optional, List, Any
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):
class SearchResult(GetResult):
distances: Optional[List[List[float | int]]]
class VectorDBBase(ABC):
"""
Abstract base class for all vector database backends.
Implementations of this class provide methods for collection management,
vector insertion, deletion, similarity search, and metadata filtering.
Any custom vector database integration must inherit from this class and
implement all abstract methods.
"""
@abstractmethod
def has_collection(self, collection_name: str) -> bool:
"""Check if the collection exists in the vector DB."""
pass
@abstractmethod
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection from the vector DB."""
pass
@abstractmethod
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert a list of vector items into a collection."""
pass
@abstractmethod
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert or update vector items in a collection."""
pass
@abstractmethod
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
pass
@abstractmethod
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors from a collection using metadata filter."""
pass
@abstractmethod
def get(self, collection_name: str) -> Optional[GetResult]:
"""Retrieve all vectors from a collection."""
pass
@abstractmethod
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by ID or filter from a collection."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the vector database by removing all collections or those matching a condition."""
pass
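# A minimal in-memory sketch of a custom backend implementing the contract above; purely
# illustrative and not part of this changeset. It assumes items behave like dicts with
# "id", "text", "vector", and "metadata" keys (as the Qdrant code earlier subscripts them)
# and that GetResult/SearchResult accept ids/documents/metadatas/distances lists of lists.
class InMemoryClient(VectorDBBase):
    def __init__(self):
        self.collections: Dict[str, List[VectorItem]] = {}

    def has_collection(self, collection_name: str) -> bool:
        return collection_name in self.collections

    def delete_collection(self, collection_name: str) -> None:
        self.collections.pop(collection_name, None)

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        self.collections.setdefault(collection_name, []).extend(items)

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        by_id = {item["id"]: item for item in self.collections.get(collection_name, [])}
        by_id.update({item["id"]: item for item in items})
        self.collections[collection_name] = list(by_id.values())

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        # A real backend would rank by vector distance; this stub returns the first `limit` items.
        items = self.collections.get(collection_name, [])[:limit]
        return SearchResult(
            ids=[[item["id"] for item in items]],
            documents=[[item["text"] for item in items]],
            metadatas=[[item["metadata"] for item in items]],
            distances=[[0.0 for _ in items]],
        )

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        items = [
            item
            for item in self.collections.get(collection_name, [])
            if all(item["metadata"].get(k) == v for k, v in filter.items())
        ][:limit]
        return GetResult(
            ids=[[item["id"] for item in items]],
            documents=[[item["text"] for item in items]],
            metadatas=[[item["metadata"] for item in items]],
        )

    def get(self, collection_name: str) -> Optional[GetResult]:
        return self.query(collection_name, {}, None)

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        items = self.collections.get(collection_name, [])
        if ids is not None:
            items = [item for item in items if item["id"] not in ids]
        if filter is not None:
            items = [
                item
                for item in items
                if not all(item["metadata"].get(k) == v for k, v in filter.items())
            ]
        self.collections[collection_name] = items

    def reset(self) -> None:
        self.collections.clear()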

View File

@@ -0,0 +1,11 @@
from enum import StrEnum
class VectorType(StrEnum):
MILVUS = "milvus"
QDRANT = "qdrant"
CHROMA = "chroma"
PINECONE = "pinecone"
ELASTICSEARCH = "elasticsearch"
OPENSEARCH = "opensearch"
PGVECTOR = "pgvector"

View File

@@ -3,6 +3,7 @@ from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from duckduckgo_search.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -22,29 +23,24 @@ def search_duckduckgo(
list[SearchResult]: A list of search results
"""
    # Use the DDGS context manager to create a DDGS object
    with DDGS() as ddgs:
        # Use the ddgs.text() method to perform the search
        ddgs_gen = ddgs.text(
            query, safesearch="moderate", max_results=count, backend="api"
        )
        # Check if there are search results
        if ddgs_gen:
            # Convert the search results into a list
            search_results = [r for r in ddgs_gen]
    # Create an empty list to store the SearchResult objects
    results = []
    # Iterate over each search result
    for result in search_results:
        # Create a SearchResult object and append it to the results list
        results.append(
            SearchResult(
                link=result["href"],
                title=result.get("title"),
                snippet=result.get("body"),
            )
        )
    if filter_list:
        results = get_filtered_results(results, filter_list)
    # Return the list of search results
    return results
    search_results = []
    with DDGS() as ddgs:
        try:
            search_results = ddgs.text(
                query, safesearch="moderate", max_results=count, backend="lite"
            )
        except RatelimitException as e:
            log.error(f"RatelimitException: {e}")
    if filter_list:
        search_results = get_filtered_results(search_results, filter_list)
    return [
        SearchResult(
            link=result["href"],
            title=result.get("title"),
            snippet=result.get("body"),
        )
        for result in search_results
    ]

View File

@@ -0,0 +1,47 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_external(
external_url: str,
external_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
response = requests.post(
external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
},
json={
"query": query,
"count": count,
},
)
response.raise_for_status()
results = response.json()
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("link"),
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []
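# Sketch of the request/response contract search_external assumes; the endpoint itself is
# user-provided, so the shape below is inferred from the parsing code, not a documented API.
#   POST <external_url>
#     headers: Authorization: Bearer <external_api_key>
#     body:    {"query": "...", "count": 5}
# Expected JSON response: a list of objects with "link", "title", and "snippet" keys, e.g.
example_response = [
    {
        "link": "https://example.com/post",
        "title": "Example result",
        "snippet": "Short text excerpt shown to the model.",
    }
]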

View File

@@ -0,0 +1,49 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_firecrawl(
firecrawl_url: str,
firecrawl_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
response = requests.post(
firecrawl_search_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {firecrawl_api_key}",
},
json={
"query": query,
"limit": count,
},
)
response.raise_for_status()
results = response.json().get("data", [])
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []
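# Note on the urljoin() call above: a leading "/" in the second argument replaces any path
# already present on firecrawl_url, which matters for self-hosted instances served under a
# path prefix (hostnames below are placeholders).
from urllib.parse import urljoin

urljoin("https://firecrawl.example.com/api", "/v1/search")  # -> "https://firecrawl.example.com/v1/search"
urljoin("https://firecrawl.example.com/api/", "v1/search")  # -> "https://firecrawl.example.com/api/v1/search"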

View File

@@ -0,0 +1,87 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_perplexity(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
"""
# Handle PersistentConfig object
if hasattr(api_key, "__str__"):
api_key = str(api_key)
try:
url = "https://api.perplexity.ai/chat/completions"
# Create payload for the API call
payload = {
"model": "sonar",
"messages": [
{
"role": "system",
"content": "You are a search assistant. Provide factual information with citations.",
},
{"role": "user", "content": query},
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# Make the API request
response = requests.request("POST", url, json=payload, headers=headers)
# Parse the JSON response
json_response = response.json()
# Extract citations from the response
citations = json_response.get("citations", [])
# Create search results from citations
results = []
for i, citation in enumerate(citations[:count]):
# Extract content from the response to use as snippet
content = ""
if "choices" in json_response and json_response["choices"]:
if i == 0:
content = json_response["choices"][0]["message"]["content"]
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
results.append(result)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
for result in results[:count]
]
except Exception as e:
log.error(f"Error searching with Perplexity API: {e}")
return []
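# Illustration (made-up response) of the citation mapping above: only the first citation
# carries the model's answer text as its snippet; later citations get empty snippets.
json_response = {
    "choices": [{"message": {"content": "Summary of the answer with citations [1][2]."}}],
    "citations": ["https://example.org/a", "https://example.org/b"],
}
# Resulting entries:
#   link="https://example.org/a", title="Source 1", snippet="Summary of the answer with citations [1][2]."
#   link="https://example.org/b", title="Source 2", snippet=""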

View File

@@ -42,7 +42,9 @@ def search_searchapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -42,7 +42,9 @@ def search_serpapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -0,0 +1,60 @@
import logging
import json
from typing import Optional, List
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_sougou(
sougou_api_sid: str,
sougou_api_sk: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
from tencentcloud.common.common_client import CommonClient
from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
try:
cred = credential.Credential(sougou_api_sid, sougou_api_sk)
http_profile = HttpProfile()
http_profile.endpoint = "tms.tencentcloudapi.com"
client_profile = ClientProfile()
client_profile.http_profile = http_profile
params = json.dumps({"Query": query, "Cnt": 20})
common_client = CommonClient(
"tms", "2020-12-29", cred, "", profile=client_profile
)
results = [
json.loads(page)
for page in common_client.call_json("SearchPro", json.loads(params))[
"Response"
]["Pages"]
]
sorted_results = sorted(
results, key=lambda x: x.get("scour", 0.0), reverse=True
)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("passage"),
)
for result in sorted_results[:count]
]
except TencentCloudSDKException as err:
log.error(f"Error in Sougou search: {err}")
return []

View File

@@ -1,32 +1,45 @@
import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_tavily(
api_key: str,
query: str,
count: int,
filter_list: Optional[list[str]] = None,
# **kwargs,
) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Tavily Search API key
query (str): The query to search for
count (int): The maximum number of results to return
Returns:
list[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
data = {"query": query, "max_results": count}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
json_response = response.json()
raw_search_results = json_response.get("results", [])
results = json_response.get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
@@ -34,5 +47,5 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in raw_search_results[:count]
for result in results
]

View File

@@ -1,19 +1,45 @@
import socket
import urllib.parse
import validators
from typing import Union, Sequence, Iterator
from langchain_community.document_loaders import (
WebBaseLoader,
)
from langchain_core.documents import Document
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
from open_webui.env import SRC_LOG_LEVELS
import asyncio
import logging
import socket
import ssl
import urllib.parse
import urllib.request
from collections import defaultdict
from datetime import datetime, time, timedelta
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
Literal,
)
import aiohttp
import certifi
import validators
from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
PLAYWRIGHT_WS_URL,
PLAYWRIGHT_TIMEOUT,
WEB_LOADER_ENGINE,
FIRECRAWL_API_BASE_URL,
FIRECRAWL_API_KEY,
TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -65,9 +91,472 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def extract_metadata(soup, url):
metadata = {"source": url}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", "No description found.")
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
return metadata
def verify_ssl_cert(url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
context = ssl.create_default_context(cafile=certifi.where())
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
s.connect((hostname, 443))
return True
except ssl.SSLError:
return False
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
return False
class RateLimitMixin:
async def _wait_for_rate_limit(self):
"""Wait to respect the rate limit if specified."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
await asyncio.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
def _sync_wait_for_rate_limit(self):
"""Synchronous version of rate limit wait."""
if self.requests_per_second and self.last_request_time:
min_interval = timedelta(seconds=1.0 / self.requests_per_second)
time_since_last = datetime.now() - self.last_request_time
if time_since_last < min_interval:
time.sleep((min_interval - time_since_last).total_seconds())
self.last_request_time = datetime.now()
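# Worked example of the throttle above (hypothetical values): with requests_per_second=2 the
# minimum interval is 0.5 s, so a request arriving 0.2 s after the previous one waits the
# remaining 0.3 s before proceeding; with requests_per_second=None no throttling is applied.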
class URLProcessingMixin:
def _verify_ssl_cert(self, url: str) -> bool:
"""Verify SSL certificate for a URL."""
return verify_ssl_cert(url)
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
self._sync_wait_for_rate_limit()
return True
class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths,
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
mode: Literal["crawl", "scrape", "map"] = "scrape",
proxy: Optional[Dict[str, str]] = None,
params: Optional[Dict] = None,
):
"""Concurrent document loader for FireCrawl operations.
Executes multiple FireCrawlLoader instances concurrently using thread pooling
to improve bulk processing efficiency.
Args:
web_paths: List of URLs/paths to process.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
requests_per_second: Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
api_key: API key for FireCrawl service. Defaults to None
(uses FIRE_CRAWL_API_KEY environment variable if not provided).
api_url: Base URL for FireCrawl API. Defaults to official API endpoint.
mode: Operation mode selection:
- 'crawl': Website crawling mode (default)
- 'scrape': Direct page scraping
- 'map': Site map generation
proxy: Proxy override settings for the FireCrawl API.
params: The parameters to pass to the Firecrawl API.
Examples include crawlerOptions.
For more details, visit: https://github.com/mendableai/firecrawl-py
"""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
self.web_paths = web_paths
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.trust_env = trust_env
self.continue_on_failure = continue_on_failure
self.api_key = api_key
self.api_url = api_url
self.mode = mode
self.params = params
def lazy_load(self) -> Iterator[Document]:
"""Load documents concurrently using FireCrawl."""
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
for document in loader.lazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
async def alazy_load(self):
"""Async version of lazy_load."""
for url in self.web_paths:
try:
await self._safe_process_url(url)
loader = FireCrawlLoader(
url=url,
api_key=self.api_key,
api_url=self.api_url,
mode=self.mode,
params=self.params,
)
async for document in loader.alazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def __init__(
self,
web_paths: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
continue_on_failure: bool = True,
requests_per_second: Optional[float] = None,
verify_ssl: bool = True,
trust_env: bool = False,
proxy: Optional[Dict[str, str]] = None,
):
"""Initialize SafeTavilyLoader with rate limiting and SSL verification support.
Args:
web_paths: List of URLs/paths to process.
api_key: The Tavily API key.
extract_depth: Depth of extraction ("basic" or "advanced").
continue_on_failure: Whether to continue if extraction of a URL fails.
requests_per_second: Number of requests per second to limit to.
verify_ssl: If True, verify SSL certificates.
trust_env: If True, use proxy settings from environment variables.
proxy: Optional proxy configuration.
"""
# Initialize proxy configuration if using environment variables
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
# Store parameters for creating TavilyLoader instances
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
self.api_key = api_key
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.verify_ssl = verify_ssl
self.trust_env = trust_env
self.proxy = proxy
# Add rate limiting
self.requests_per_second = requests_per_second
self.last_request_time = None
def lazy_load(self) -> Iterator[Document]:
"""Load documents with rate limiting support, delegating to TavilyLoader."""
valid_urls = []
for url in self.web_paths:
try:
self._safe_process_url_sync(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
else:
raise e
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async version with rate limiting and SSL verification."""
valid_urls = []
for url in self.web_paths:
try:
await self._safe_process_url(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
return
raise ValueError("No valid URLs to process after SSL verification")
try:
loader = TavilyLoader(
urls=valid_urls,
api_key=self.api_key,
extract_depth=self.extract_depth,
continue_on_failure=self.continue_on_failure,
)
async for document in loader.alazy_load():
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading URLs: {e}")
else:
raise e
class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessingMixin):
"""Load HTML pages safely with Playwright, supporting SSL verification, rate limiting, and remote browser connection.
Attributes:
web_paths (List[str]): List of URLs to load.
verify_ssl (bool): If True, verify SSL certificates.
trust_env (bool): If True, use proxy settings from environment variables.
requests_per_second (Optional[float]): Number of requests per second to limit to.
continue_on_failure (bool): If True, continue loading other URLs on failure.
headless (bool): If True, the browser will run in headless mode.
proxy (dict): Proxy override settings for the Playwright session.
playwright_ws_url (Optional[str]): WebSocket endpoint URI for remote browser connection.
playwright_timeout (Optional[int]): Maximum operation time in milliseconds.
"""
def __init__(
self,
web_paths: List[str],
verify_ssl: bool = True,
trust_env: bool = False,
requests_per_second: Optional[float] = None,
continue_on_failure: bool = True,
headless: bool = True,
remove_selectors: Optional[List[str]] = None,
proxy: Optional[Dict[str, str]] = None,
playwright_ws_url: Optional[str] = None,
playwright_timeout: Optional[int] = 10000,
):
"""Initialize with additional safety parameters and remote browser support."""
proxy_server = proxy.get("server") if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
super().__init__(
urls=web_paths,
continue_on_failure=continue_on_failure,
headless=headless if playwright_ws_url is None else False,
remove_selectors=remove_selectors,
proxy=proxy,
)
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
self.last_request_time = None
self.playwright_ws_url = playwright_ws_url
self.trust_env = trust_env
self.playwright_timeout = playwright_timeout
def lazy_load(self) -> Iterator[Document]:
"""Safely load URLs synchronously with support for remote browser."""
from playwright.sync_api import sync_playwright
with sync_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = p.chromium.connect(self.playwright_ws_url)
else:
browser = p.chromium.launch(headless=self.headless, proxy=self.proxy)
for url in self.urls:
try:
self._safe_process_url_sync(url)
page = browser.new_page()
response = page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = self.evaluator.evaluate(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
browser.close()
async def alazy_load(self) -> AsyncIterator[Document]:
"""Safely load URLs asynchronously with support for remote browser."""
from playwright.async_api import async_playwright
async with async_playwright() as p:
# Use remote browser if ws_endpoint is provided, otherwise use local browser
if self.playwright_ws_url:
browser = await p.chromium.connect(self.playwright_ws_url)
else:
browser = await p.chromium.launch(
headless=self.headless, proxy=self.proxy
)
for url in self.urls:
try:
await self._safe_process_url(url)
page = await browser.new_page()
response = await page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
text = await self.evaluator.evaluate_async(page, browser, response)
metadata = {"source": url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
continue
raise e
await browser.close()
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def __init__(self, trust_env: bool = False, *args, **kwargs):
"""Initialize SafeWebBaseLoader
Args:
trust_env (bool, optional): set to True if using proxy to make web requests, for example
using http(s)_proxy environment variables. Defaults to False.
"""
super().__init__(*args, **kwargs)
self.trust_env = trust_env
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
for i in range(retries):
try:
kwargs: Dict = dict(
headers=self.session.headers,
cookies=self.session.cookies.get_dict(),
)
if not self.session.verify:
kwargs["ssl"] = False
async with session.get(
url,
**(self.requests_kwargs | kwargs),
) as response:
if self.raise_for_status:
response.raise_for_status()
return await response.text()
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise
else:
log.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
def _unpack_fetch_results(
self, results: Any, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Unpack fetch results into BeautifulSoup objects."""
from bs4 import BeautifulSoup
final_results = []
for i, result in enumerate(results):
url = urls[i]
if parser is None:
if url.endswith(".xml"):
parser = "xml"
else:
parser = self.default_parser
self._check_parser(parser)
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
return final_results
async def ascrape_all(
self, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
"""Async fetch all urls, then return soups for all results."""
results = await self.fetch_all(urls)
return self._unpack_fetch_results(results, urls, parser=parser)
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
@@ -76,33 +565,86 @@ class SafeWebBaseLoader(WebBaseLoader):
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = extract_metadata(soup, path)
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
log.exception(f"Error loading {path}: {e}")
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async lazy load text from the url(s) in web_path."""
results = await self.ascrape_all(self.web_paths)
for path, soup in zip(self.web_paths, results):
text = soup.get_text(**self.bs_get_text_kwargs)
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
async def aload(self) -> list[Document]:
"""Load data into Document objects."""
return [document async for document in self.alazy_load()]
def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
trust_env: bool = False,
):
# Check if the URLs are valid
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
return SafeWebBaseLoader(
safe_urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
)
web_loader_args = {
"web_paths": safe_urls,
"verify_ssl": verify_ssl,
"requests_per_second": requests_per_second,
"continue_on_failure": True,
"trust_env": trust_env,
}
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
WebLoaderClass = SafeWebBaseLoader
if WEB_LOADER_ENGINE.value == "playwright":
WebLoaderClass = SafePlaywrightURLLoader
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000
if PLAYWRIGHT_WS_URL.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
if WEB_LOADER_ENGINE.value == "firecrawl":
WebLoaderClass = SafeFireCrawlLoader
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
if WEB_LOADER_ENGINE.value == "tavily":
WebLoaderClass = SafeTavilyLoader
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
if WEB_LOADER_ENGINE.value == "external":
WebLoaderClass = ExternalWebLoader
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
if WebLoaderClass:
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using WEB_LOADER_ENGINE %s for %s URLs",
web_loader.__class__.__name__,
len(safe_urls),
)
return web_loader
else:
raise ValueError(
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
)
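# Minimal usage sketch (URLs are placeholders); the loader class actually instantiated depends
# on WEB_LOADER_ENGINE, and the URLs are filtered through safe_validate_urls() first.
loader = get_web_loader(
    ["https://example.com/docs", "https://example.org/post"],
    verify_ssl=True,
    requests_per_second=2,
    trust_env=False,
)
docs = loader.load()  # synchronous; or: [doc async for doc in loader.alazy_load()]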

View File

@@ -0,0 +1,87 @@
import logging
from typing import Optional
import requests
from requests.auth import HTTPDigestAuth
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_yacy(
query_url: str,
username: Optional[str],
password: Optional[str],
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
Search a Yacy instance for a given query and return the results as a list of SearchResult objects.
The function accepts username and password for authenticating to Yacy.
Args:
query_url (str): The base URL of the Yacy server.
username (str): Optional YaCy username.
password (str): Optional YaCy password.
query (str): The search term or question to find in the Yacy database.
count (int): The maximum number of results to retrieve from the search.
Returns:
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
    Raises:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Use authentication if either username or password is set
yacy_auth = None
if username or password:
yacy_auth = HTTPDigestAuth(username, password)
params = {
"query": query,
"contentdom": "text",
"resource": "global",
"maximumRecords": count,
"nav": "none",
}
    # Check whether a JSON API URL was provided
    if not query_url.endswith("yacysearch.json"):
        # Append the JSON search endpoint to the base URL
        query_url = query_url.rstrip("/") + "/yacysearch.json"
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
auth=yacy_auth,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("channels", [{}])[0].get("items", [])
sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in sorted_results[:count]
]

View File

@@ -7,6 +7,9 @@ from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import aiohttp
import aiofiles
@@ -17,6 +20,7 @@ from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -33,10 +37,13 @@ from open_webui.config import (
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
CACHE_DIR,
WHISPER_LANGUAGE,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
ENV,
SRC_LOG_LEVELS,
DEVICE_TYPE,
@@ -47,13 +54,15 @@ from open_webui.env import (
router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -67,27 +76,47 @@ from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
def is_audio_conversion_required(file_path):
"""
Check if the given audio file needs conversion to mp3.
"""
SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
log.error(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
try:
info = mediainfo(file_path)
codec_name = info.get("codec_name", "").lower()
codec_type = info.get("codec_type", "").lower()
codec_tag_string = info.get("codec_tag_string", "").lower()
if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
# File is AAC/mp4a audio, recommend mp3 conversion
return True
# If the codec name is in the supported formats
if codec_name in SUPPORTED_FORMATS:
return False
return True
return False
except Exception as e:
log.error(f"Error getting audio format: {e}")
return False
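# Behaviour sketch of the check above (probe values are examples, not real files):
#   codec_name="aac", codec_type="audio", codec_tag_string="mp4a" -> True  (convert to mp3)
#   codec_name="mp3" or "flac"                                    -> False (already supported)
#   codec_name="pcm_s16le" (typical for .wav containers)          -> True  (codec name not in SUPPORTED_FORMATS)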
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
def convert_audio_to_mp3(file_path):
"""Convert audio file to mp3 format."""
try:
output_path = os.path.splitext(file_path)[0] + ".mp3"
audio = AudioSegment.from_file(file_path)
audio.export(output_path, format="mp3")
log.info(f"Converted {file_path} to {output_path}")
return output_path
except Exception as e:
log.error(f"Error converting audio file: {e}")
return None
def set_faster_whisper_model(model: str, auto_update: bool = False):
@@ -130,6 +159,7 @@ class TTSConfigForm(BaseModel):
VOICE: str
SPLIT_ON: str
AZURE_SPEECH_REGION: str
AZURE_SPEECH_BASE_URL: str
AZURE_SPEECH_OUTPUT_FORMAT: str
@@ -140,6 +170,11 @@ class STTConfigForm(BaseModel):
MODEL: str
WHISPER_MODEL: str
DEEPGRAM_API_KEY: str
AZURE_API_KEY: str
AZURE_REGION: str
AZURE_LOCALES: str
AZURE_BASE_URL: str
AZURE_MAX_SPEAKERS: str
class AudioConfigUpdateForm(BaseModel):
@@ -159,6 +194,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
@@ -168,6 +204,11 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
},
}
@@ -184,6 +225,9 @@ async def update_audio_config(
request.app.state.config.TTS_VOICE = form_data.tts.VOICE
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
request.app.state.config.TTS_AZURE_SPEECH_BASE_URL = (
form_data.tts.AZURE_SPEECH_BASE_URL
)
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
)
@@ -194,6 +238,13 @@ async def update_audio_config(
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = (
form_data.stt.AZURE_MAX_SPEAKERS
)
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -210,6 +261,7 @@ async def update_audio_config(
"VOICE": request.app.state.config.TTS_VOICE,
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
"AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL,
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
},
"stt": {
@@ -219,6 +271,11 @@ async def update_audio_config(
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
},
}
@@ -265,8 +322,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
payload["model"] = request.app.state.config.TTS_MODEL
try:
# print(payload)
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
json=payload,
@@ -284,6 +343,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -309,7 +369,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -323,7 +383,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
)
try:
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
json={
@@ -336,6 +399,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json",
"xi-api-key": request.app.state.config.TTS_API_KEY,
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -360,7 +424,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -371,7 +435,8 @@ async def speech(request: Request, user=Depends(get_verified_user)):
log.exception(e)
raise HTTPException(status_code=400, detail="Invalid JSON payload")
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
region = request.app.state.config.TTS_AZURE_SPEECH_REGION or "eastus"
base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
language = request.app.state.config.TTS_VOICE
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
@@ -380,15 +445,20 @@ async def speech(request: Request, user=Depends(get_verified_user)):
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
<voice name="{language}">{payload["input"]}</voice>
</speak>"""
async with aiohttp.ClientSession() as session:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
async with aiohttp.ClientSession(
timeout=timeout, trust_env=True
) as session:
async with session.post(
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
(base_url or f"https://{region}.tts.speech.microsoft.com")
+ "/cognitiveservices/v1",
headers={
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": output_format,
},
data=data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -413,7 +483,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status", 500),
status_code=getattr(r, "status", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
@@ -457,12 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
def transcribe(request: Request, file_path):
print("transcribe", file_path)
def transcription_handler(request, file_path, metadata):
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
metadata = metadata or {}
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -470,7 +541,12 @@ def transcribe(request: Request, file_path):
)
model = request.app.state.faster_whisper_model
segments, info = model.transcribe(file_path, beam_size=5)
segments, info = model.transcribe(
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=metadata.get("language") or WHISPER_LANGUAGE,
)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
@@ -487,11 +563,6 @@ def transcribe(request: Request, file_path):
log.debug(data)
return data
elif request.app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
r = None
try:
r = requests.post(
@@ -500,7 +571,14 @@ def transcribe(request: Request, file_path):
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
data={
"model": request.app.state.config.STT_MODEL,
**(
{"language": metadata.get("language")}
if metadata.get("language")
else {}
),
},
)
r.raise_for_status()
@@ -589,34 +667,254 @@ def transcribe(request: Request, file_path):
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
elif request.app.state.config.STT_ENGINE == "azure":
# Check file exists and size
if not os.path.exists(file_path):
raise HTTPException(status_code=400, detail="Audio file not found")
# Check file size (Azure has a larger limit of 200MB)
file_size = os.path.getsize(file_path)
if file_size > AZURE_MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
)
api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
region = request.app.state.config.AUDIO_STT_AZURE_REGION or "eastus"
locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS or 3
# IF NO LOCALES, USE DEFAULTS
if len(locales) < 2:
locales = [
"en-US",
"es-ES",
"es-MX",
"fr-FR",
"hi-IN",
"it-IT",
"de-DE",
"en-GB",
"en-IN",
"ja-JP",
"ko-KR",
"pt-BR",
"zh-CN",
]
locales = ",".join(locales)
if not api_key or not region:
raise HTTPException(
status_code=400,
detail="Azure API key is required for Azure STT",
)
r = None
try:
# Prepare the request
data = {
"definition": json.dumps(
{
"locales": locales.split(","),
"diarization": {"maxSpeakers": max_speakers, "enabled": True},
}
if locales
else {}
)
}
url = (
base_url or f"https://{region}.api.cognitive.microsoft.com"
) + "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
# Use context manager to ensure file is properly closed
with open(file_path, "rb") as audio_file:
r = requests.post(
url=url,
files={"audio": audio_file},
data=data,
headers={
"Ocp-Apim-Subscription-Key": api_key,
},
)
r.raise_for_status()
response = r.json()
# Extract transcript from response
if not response.get("combinedPhrases"):
raise ValueError("No transcription found in response")
# Get the full transcript from combinedPhrases
transcript = response["combinedPhrases"][0].get("text", "").strip()
if not transcript:
raise ValueError("Empty transcript in response")
data = {"text": transcript}
# Save transcript to json file (consistent with other providers)
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
log.debug(data)
return data
except (KeyError, IndexError, ValueError) as e:
log.exception("Error parsing Azure response")
raise HTTPException(
status_code=500,
detail=f"Failed to parse Azure response: {str(e)}",
)
except requests.exceptions.RequestException as e:
log.exception(e)
detail = None
try:
if r is not None and r.status_code != 200:
res = r.json()
if "error" in res:
detail = f"External: {res['error'].get('message', '')}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=getattr(r, "status_code", 500) if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
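# Sketch of the request the Azure branch above sends (region and locales are placeholders):
#   POST https://eastus.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15
#   headers: Ocp-Apim-Subscription-Key: <api_key>
#   files:   {"audio": <open file handle>}
#   data:    {"definition": '{"locales": ["en-US", "de-DE"], "diarization": {"maxSpeakers": 3, "enabled": true}}'}
# The transcript is then read from response["combinedPhrases"][0]["text"].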
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
log.info(f"transcribe: {file_path} {metadata}")
if is_audio_conversion_required(file_path):
file_path = convert_audio_to_mp3(file_path)
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
# Always produce a list of chunk paths (could be one entry if small)
try:
chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
print(f"Chunk paths: {chunk_paths}")
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
results = []
try:
with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path
futures = [
executor.submit(transcription_handler, request, chunk_path, metadata)
for chunk_path in chunk_paths
]
# Gather results as they complete
for future in futures:
try:
results.append(future.result())
except Exception as transcribe_exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error transcribing chunk: {transcribe_exc}",
)
finally:
# Clean up only the temporary chunks, never the original file
for chunk_path in chunk_paths:
if chunk_path != file_path and os.path.isfile(chunk_path):
try:
os.remove(chunk_path)
except Exception:
pass
return {
"text": " ".join([result["text"] for result in results]),
}
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
id = os.path.splitext(os.path.basename(file_path))[
0
] # Handles names with multiple dots
file_dir = os.path.dirname(file_path)
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
if (
os.path.getsize(compressed_path) > MAX_FILE_SIZE
): # Still larger than MAX_FILE_SIZE after compression
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
audio.export(compressed_path, format="mp3", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")
return compressed_path
else:
return file_path
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
"""
Splits audio into chunks not exceeding max_bytes.
Returns a list of chunk file paths. If audio fits, returns list with original path.
"""
file_size = os.path.getsize(file_path)
if file_size <= max_bytes:
return [file_path] # Nothing to split
audio = AudioSegment.from_file(file_path)
duration_ms = len(audio)
orig_size = file_size
approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
chunks = []
start = 0
i = 0
base, _ = os.path.splitext(file_path)
while start < duration_ms:
end = min(start + approx_chunk_ms, duration_ms)
chunk = audio[start:end]
chunk_path = f"{base}_chunk_{i}.{format}"
chunk.export(chunk_path, format=format, bitrate=bitrate)
# Reduce chunk duration if still too large
while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
end = start + ((end - start) // 2)
chunk = audio[start:end]
chunk.export(chunk_path, format=format, bitrate=bitrate)
if os.path.getsize(chunk_path) > max_bytes:
os.remove(chunk_path)
raise Exception("Audio chunk cannot be reduced below max file size.")
chunks.append(chunk_path)
start = end
i += 1
return chunks
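# Worked example of the chunk-size estimate above (made-up numbers): for a 60 MB file of
# 1_800_000 ms (30 minutes) with max_bytes = 20 MB,
#   approx_chunk_ms = max(int(1_800_000 * (20 / 60)) - 1000, 1000) = 599_000 ms (~10 min per chunk),
# and any chunk that still exceeds max_bytes is re-exported over a halved window until it fits.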
@router.post("/transcriptions")
def transcription(
request: Request,
file: UploadFile = File(...),
language: Optional[str] = Form(None),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types!
if not (
file.content_type.startswith("audio/")
or file.content_type in SUPPORTED_CONTENT_TYPES
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
@@ -637,19 +935,18 @@ def transcription(
f.write(contents)
try:
        try:
            file_path = compress_audio(file_path)
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
        data = transcribe(request, file_path)
        file_path = file_path.split("/")[-1]
        return {**data, "filename": file_path}
        metadata = None
        if language:
            metadata = {"language": language}
        result = transcribe(request, file_path, metadata)
        return {
            **result,
            "filename": os.path.basename(file_path),
        }
except Exception as e:
log.exception(e)
@@ -670,7 +967,22 @@ def transcription(
def get_available_models(request: Request) -> list[dict]:
available_models = []
if request.app.state.config.TTS_ENGINE == "openai":
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
)
response.raise_for_status()
data = response.json()
available_models = data.get("models", [])
except Exception as e:
log.error(f"Error fetching models from custom endpoint: {str(e)}")
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
else:
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
response = requests.get(
@@ -701,14 +1013,37 @@ def get_available_voices(request) -> dict:
"""Returns {voice_id: voice_name} dict"""
available_voices = {}
if request.app.state.config.TTS_ENGINE == "openai":
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
# Use custom endpoint if not using the official OpenAI API URL
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
"https://api.openai.com"
):
try:
response = requests.get(
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
)
response.raise_for_status()
data = response.json()
voices_list = data.get("voices", [])
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
except Exception as e:
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
else:
available_voices = {
"alloy": "alloy",
"echo": "echo",
"fable": "fable",
"onyx": "onyx",
"nova": "nova",
"shimmer": "shimmer",
}
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
try:
available_voices = get_elevenlabs_voices(
@@ -720,7 +1055,10 @@ def get_available_voices(request) -> dict:
elif request.app.state.config.TTS_ENGINE == "azure":
try:
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
base_url = request.app.state.config.TTS_AZURE_SPEECH_BASE_URL
url = (
base_url or f"https://{region}.tts.speech.microsoft.com"
) + "/cognitiveservices/voices/list"
headers = {
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
}
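# Response shapes the custom-endpoint branches above expect (inferred from the parsing code;
# the endpoints are user-provided, so these are assumptions rather than a documented API):
#   GET {TTS_OPENAI_API_BASE_URL}/audio/models -> {"models": [{"id": "tts-1"}, ...]}
#   GET {TTS_OPENAI_API_BASE_URL}/audio/voices -> {"voices": [{"id": "alloy", "name": "Alloy"}, ...]}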

View File

@@ -19,40 +19,45 @@ from open_webui.models.auths import (
UserResponse,
)
from open_webui.models.users import Users
from open_webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from open_webui.config import (
OPENID_PROVIDER_URL,
ENABLE_OAUTH_SIGNUP,
)
from fastapi.responses import RedirectResponse, Response, JSONResponse
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
decode_token,
create_api_key,
create_token,
get_admin_user,
get_verified_user,
get_current_user,
get_password_hash,
get_http_authorization_cred,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
if ENABLE_LDAP.value:
from ldap3 import Server, Connection, NONE, Tls
from ldap3.utils.conv import escape_filter_chars
router = APIRouter()
@@ -73,31 +78,36 @@ class SessionUserResponse(Token, UserResponse):
async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user)
):
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
auth_header = request.headers.get("Authorization")
auth_token = get_http_authorization_cred(auth_header)
token = auth_token.credentials
data = decode_token(token)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=expires_delta,
)
if data:
expires_at = data.get("exp")
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
)
if (expires_at is not None) and int(time.time()) > expires_at:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=datetime_expires_at,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=(
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
),
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
@@ -178,6 +188,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_VALIDATE_CERT = (
CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
)
LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS
@@ -189,14 +202,14 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
try:
tls = Tls(
validate=CERT_REQUIRED,
validate=LDAP_VALIDATE_CERT,
version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS,
)
except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}")
raise HTTPException(400, detail=str(e))
log.error(f"TLS configuration error: {str(e)}")
raise HTTPException(400, detail="Failed to configure TLS for LDAP connection.")
try:
server = Server(
@@ -211,7 +224,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_DN,
LDAP_APP_PASSWORD,
auto_bind="NONE",
authentication="SIMPLE",
authentication="SIMPLE" if LDAP_APP_DN else "ANONYMOUS",
)
if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed")
@@ -226,14 +239,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
],
)
if not search_success:
if not search_success or not connection_app.entries:
raise HTTPException(400, detail="User not found in the LDAP server")
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
if not mail or mail == "" or mail == "[]":
raise HTTPException(400, f"User {form_data.user} does not have mail.")
email = entry[
f"{LDAP_ATTRIBUTE_FOR_MAIL}"
].value # retrieve the Attribute value
if not email:
raise HTTPException(400, "User does not have a valid email address.")
elif isinstance(email, str):
email = email.lower()
elif isinstance(email, list):
email = email[0].lower()
else:
email = str(email).lower()
cn = str(entry["cn"])
user_dn = entry.entry_dn
@@ -246,19 +268,24 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
authentication="SIMPLE",
)
if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}")
raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(mail)
user = Users.get_user_by_email(email)
if not user:
try:
user_count = Users.get_num_users()
role = (
"admin"
if Users.get_num_users() == 0
if user_count == 0
else request.app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=mail, password=str(uuid.uuid4()), name=cn, role=role
email=email,
password=str(uuid.uuid4()),
name=cn,
role=role,
)
if not user:
@@ -269,23 +296,38 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
except HTTPException:
raise
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"LDAP user creation error: {str(err)}")
raise HTTPException(
500, detail="Internal error occurred during LDAP user creation."
)
user = Auths.authenticate_user_by_trusted_header(mail)
user = Auths.authenticate_user_by_email(email)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(
request.app.state.config.JWT_EXPIRES_IN
),
expires_delta=expires_delta,
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=(
datetime.datetime.fromtimestamp(
expires_at, datetime.timezone.utc
)
if expires_at
else None
),
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
@@ -295,6 +337,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@@ -305,12 +348,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else:
raise HTTPException(
400,
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
raise HTTPException(400, "User record mismatch.")
except Exception as e:
raise HTTPException(400, detail=str(e))
log.error(f"LDAP authentication error: {str(e)}")
raise HTTPException(400, detail="LDAP authentication failed.")
############################
@@ -324,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
trusted_name = trusted_email
email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
name = email
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
trusted_name = request.headers.get(
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
)
if not Users.get_user_by_email(trusted_email.lower()):
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
if not Users.get_user_by_email(email.lower()):
await signup(
request,
response,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
),
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
user = Auths.authenticate_user_by_email(email)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
).split(",")
group_names = [name.strip() for name in group_names if name.strip()]
if group_names:
Groups.sync_user_groups_by_group_names(user.id, group_names)
elif WEBUI_AUTH == False:
admin_email = "admin@localhost"
admin_password = "admin"
@@ -413,6 +462,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
if WEBUI_AUTH:
if (
not request.app.state.config.ENABLE_SIGNUP
@@ -427,6 +477,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
user_count = Users.get_num_users()
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -437,14 +488,15 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
try:
role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
)
if Users.get_num_users() == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(form_data.password.encode("utf-8")) > 72:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
@@ -484,6 +536,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.WEBUI_NAME,
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
@@ -497,6 +550,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
user.id, request.app.state.config.USER_PERMISSIONS
)
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
return {
"token": token,
"token_type": "Bearer",
@@ -511,7 +568,8 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Signup error: {str(err)}")
raise HTTPException(500, detail="An internal error occurred during signup.")
@router.get("/signout")
@@ -529,8 +587,14 @@ async def signout(request: Request, response: Response):
logout_url = openid_data.get("end_session_endpoint")
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
url=f"{logout_url}?id_token_hint={oauth_id_token}"
return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
},
headers=response.headers,
)
else:
raise HTTPException(
@@ -538,9 +602,25 @@ async def signout(request: Request, response: Response):
detail="Failed to fetch OpenID configuration",
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
log.error(f"OpenID signout error: {str(e)}")
raise HTTPException(
status_code=500,
detail="Failed to sign out from the OpenID provider.",
)
return {"status": True}
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
},
headers=response.headers,
)
return JSONResponse(
status_code=200, content={"status": True}, headers=response.headers
)
############################
@@ -582,7 +662,10 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
log.error(f"Add user error: {str(err)}")
raise HTTPException(
500, detail="An internal error occurred while adding the user."
)
############################
@@ -596,7 +679,7 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email:
admin = Users.get_user_by_email(admin_email)
@@ -630,11 +713,16 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
}
@@ -645,11 +733,16 @@ class AdminConfig(BaseModel):
ENABLE_API_KEY: bool
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
API_KEY_ALLOWED_ENDPOINTS: str
ENABLE_CHANNELS: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
ENABLE_MESSAGE_RATING: bool
ENABLE_CHANNELS: bool
ENABLE_NOTES: bool
ENABLE_USER_WEBHOOKS: bool
PENDING_USER_OVERLAY_TITLE: Optional[str] = None
PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
RESPONSE_WATERMARK: Optional[str] = None
@router.post("/admin/config")
@@ -669,6 +762,7 @@ async def update_admin_config(
)
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -684,6 +778,17 @@ async def update_admin_config(
)
request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING
request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
form_data.PENDING_USER_OVERLAY_TITLE
)
request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
form_data.PENDING_USER_OVERLAY_CONTENT
)
request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
@@ -691,11 +796,16 @@ async def update_admin_config(
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
}
@@ -711,6 +821,7 @@ class LdapServerConfig(BaseModel):
search_filters: str = ""
use_tls: bool = True
certificate_path: Optional[str] = None
validate_cert: bool = True
ciphers: Optional[str] = "ALL"
@@ -728,6 +839,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@@ -750,11 +862,6 @@ async def update_ldap_server(
if not value:
raise HTTPException(400, detail=f"Required field {key} is empty")
if form_data.use_tls and not form_data.certificate_path:
raise HTTPException(
400, detail="TLS is enabled but certificate file path is missing"
)
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
request.app.state.config.LDAP_SERVER_HOST = form_data.host
request.app.state.config.LDAP_SERVER_PORT = form_data.port
@@ -768,6 +875,7 @@ async def update_ldap_server(
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
return {
@@ -782,6 +890,7 @@ async def update_ldap_server(
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
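
A minimal standalone sketch of the bcrypt guard introduced in the signup path above; the helper name is illustrative, and the 72-byte limit comes from bcrypt itself, which silently truncates longer inputs:

def is_bcrypt_compatible(password: str) -> bool:
    # bcrypt only hashes the first 72 bytes, so longer passwords are rejected
    # up front instead of being silently truncated.
    return len(password.encode("utf-8")) <= 72

assert is_bcrypt_compatible("a" * 72)
assert not is_bcrypt_compatible("a" * 73)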
View File
@@ -192,7 +192,7 @@ async def get_channel_messages(
############################
async def send_notification(webui_url, channel, message, active_user_ids):
async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
for user in users:
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
if webhook_url:
post_webhook(
name,
webhook_url,
f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
{
@@ -302,6 +303,7 @@ async def post_new_message(
background_tasks.add_task(
send_notification,
request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL,
channel,
message,
View File
@@ -2,6 +2,8 @@ import json
import logging
from typing import Optional
from open_webui.socket.main import get_event_emitter
from open_webui.models.chats import (
ChatForm,
ChatImportForm,
@@ -74,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str,
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
):
if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
)
@@ -192,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
############################
@router.get("/pinned", response_model=list[ChatResponse])
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
@@ -265,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=list[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_verified_user),
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_archived_chat_list_by_user_id(
user.id,
filter=filter,
skip=skip,
limit=limit,
)
]
return chat_list
############################
@@ -372,6 +419,107 @@ async def update_chat_by_id(
)
############################
# UpdateChatMessageById
############################
class MessageForm(BaseModel):
content: str
@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
async def update_chat_message_by_id(
id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.upsert_message_to_chat_by_id_and_message_id(
id,
message_id,
{
"content": form_data.content,
},
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
},
False,
)
if event_emitter:
await event_emitter(
{
"type": "chat:message",
"data": {
"chat_id": id,
"message_id": message_id,
"content": form_data.content,
},
}
)
return ChatResponse(**chat.model_dump())
############################
# SendChatMessageEventById
############################
class EventForm(BaseModel):
type: str
data: dict
@router.post("/{id}/messages/{message_id}/event", response_model=Optional[bool])
async def send_chat_message_event_by_id(
id: str, message_id: str, form_data: EventForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(id)
if not chat:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if chat.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
event_emitter = get_event_emitter(
{
"user_id": user.id,
"chat_id": id,
"message_id": message_id,
}
)
try:
if event_emitter:
await event_emitter(form_data.model_dump())
else:
return False
return True
except:
return False
############################
# DeleteChatById
############################
@@ -476,7 +624,12 @@ async def clone_chat_by_id(
@router.post("/{id}/clone/shared", response_model=Optional[ChatResponse])
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_share_id(id)
if user.role == "admin":
chat = Chats.get_chat_by_id(id)
else:
chat = Chats.get_chat_by_share_id(id)
if chat:
updated_chat = {
**chat.chat,
@@ -530,8 +683,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
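
The chat list endpoints above share the same page-to-offset translation (1-based page, 60 items per page whenever a page number is supplied); a small worked example with an illustrative filter payload:

page = 3
limit = 60
skip = (page - 1) * limit  # 120: skip the first two pages, return items 121-180

filter = {}
for key, value in [("query", "roadmap"), ("order_by", "updated_at"), ("direction", "desc")]:
    if value:
        filter[key] = value
# filter, skip, and limit are then passed through to Chats.get_chat_list_by_user_id
# or Chats.get_archived_chat_list_by_user_id.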
View File
@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
from typing import Optional
@@ -7,6 +7,8 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.config import get_config, save_config
from open_webui.config import BannerModel
from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data
router = APIRouter()
@@ -66,10 +68,86 @@ async def set_direct_connections_config(
}
############################
# ToolServers Config
############################
class ToolServerConnection(BaseModel):
url: str
path: str
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
model_config = ConfigDict(extra="allow")
class ToolServersConfigForm(BaseModel):
TOOL_SERVER_CONNECTIONS: list[ToolServerConnection]
@router.get("/tool_servers", response_model=ToolServersConfigForm)
async def get_tool_servers_config(request: Request, user=Depends(get_admin_user)):
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers", response_model=ToolServersConfigForm)
async def set_tool_servers_config(
request: Request,
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
}
@router.post("/tool_servers/verify")
async def verify_tool_servers_config(
request: Request, form_data: ToolServerConnection, user=Depends(get_admin_user)
):
"""
Verify the connection to the tool server.
"""
try:
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
url = f"{form_data.url}/{form_data.path}"
return await get_tool_server_data(token, url)
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to connect to the tool server: {str(e)}",
)
############################
# CodeInterpreterConfig
############################
class CodeInterpreterConfigForm(BaseModel):
ENABLE_CODE_EXECUTION: bool
CODE_EXECUTION_ENGINE: str
CODE_EXECUTION_JUPYTER_URL: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int]
ENABLE_CODE_INTERPRETER: bool
CODE_INTERPRETER_ENGINE: str
CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str]
@@ -77,11 +155,19 @@ class CodeInterpreterConfigForm(BaseModel):
CODE_INTERPRETER_JUPYTER_AUTH: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str]
CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str]
CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int]
@router.get("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def get_code_interpreter_config(request: Request, user=Depends(get_admin_user)):
@router.get("/code_execution", response_model=CodeInterpreterConfigForm)
async def get_code_execution_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -89,13 +175,34 @@ async def get_code_interpreter_config(request: Request, user=Depends(get_admin_u
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}
@router.post("/code_interpreter", response_model=CodeInterpreterConfigForm)
async def set_code_interpreter_config(
@router.post("/code_execution", response_model=CodeInterpreterConfigForm)
async def set_code_execution_config(
request: Request, form_data: CodeInterpreterConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_CODE_EXECUTION = form_data.ENABLE_CODE_EXECUTION
request.app.state.config.CODE_EXECUTION_ENGINE = form_data.CODE_EXECUTION_ENGINE
request.app.state.config.CODE_EXECUTION_JUPYTER_URL = (
form_data.CODE_EXECUTION_JUPYTER_URL
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH = (
form_data.CODE_EXECUTION_JUPYTER_AUTH
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
)
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT = (
form_data.CODE_EXECUTION_JUPYTER_TIMEOUT
)
request.app.state.config.ENABLE_CODE_INTERPRETER = form_data.ENABLE_CODE_INTERPRETER
request.app.state.config.CODE_INTERPRETER_ENGINE = form_data.CODE_INTERPRETER_ENGINE
request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE = (
@@ -116,8 +223,18 @@ async def set_code_interpreter_config(
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD = (
form_data.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
)
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT = (
form_data.CODE_INTERPRETER_JUPYTER_TIMEOUT
)
return {
"ENABLE_CODE_EXECUTION": request.app.state.config.ENABLE_CODE_EXECUTION,
"CODE_EXECUTION_ENGINE": request.app.state.config.CODE_EXECUTION_ENGINE,
"CODE_EXECUTION_JUPYTER_URL": request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
"CODE_EXECUTION_JUPYTER_AUTH": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH,
"CODE_EXECUTION_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN,
"CODE_EXECUTION_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD,
"CODE_EXECUTION_JUPYTER_TIMEOUT": request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
"ENABLE_CODE_INTERPRETER": request.app.state.config.ENABLE_CODE_INTERPRETER,
"CODE_INTERPRETER_ENGINE": request.app.state.config.CODE_INTERPRETER_ENGINE,
"CODE_INTERPRETER_PROMPT_TEMPLATE": request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE,
@@ -125,6 +242,7 @@ async def set_code_interpreter_config(
"CODE_INTERPRETER_JUPYTER_AUTH": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH,
"CODE_INTERPRETER_JUPYTER_AUTH_TOKEN": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN,
"CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD": request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD,
"CODE_INTERPRETER_JUPYTER_TIMEOUT": request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
}
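
An illustrative payload for the new tool_servers endpoints above; every value here is made up, but the field names follow the ToolServerConnection model, and the verify endpoint probes the concatenation of url and path:

connection = {
    "url": "http://localhost:8000",  # hypothetical tool server
    "path": "openapi.json",
    "auth_type": "bearer",           # or "session" to reuse the caller's token
    "key": "sk-example",
    "config": {},
}
# Posting this body to the tool_servers/verify route above fetches
# http://localhost:8000/openapi.json using the bearer key.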
View File
@@ -56,19 +56,35 @@ async def update_config(
}
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class FeedbackUserResponse(FeedbackResponse):
user: Optional[UserModel] = None
user: Optional[UserResponse] = None
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackUserResponse(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
feedback_list = []
for feedback in feedbacks:
user = Users.get_user_by_id(feedback.user_id)
feedback_list.append(
FeedbackUserResponse(
**feedback.model_dump(),
user=UserResponse(**user.model_dump()) if user else None,
)
)
for feedback in feedbacks
]
return feedback_list
@router.delete("/feedbacks/all")
@@ -80,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackModel(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
return feedbacks
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
View File
@@ -1,21 +1,39 @@
import logging
import os
import uuid
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import Optional
from urllib.parse import quote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
UploadFile,
status,
Query,
)
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users
from open_webui.models.files import (
FileForm,
FileModel,
FileModelResponse,
Files,
)
from open_webui.models.knowledge import Knowledges
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe
from open_webui.storage.provider import Storage
from open_webui.utils.auth import get_admin_user, get_verified_user
from pydantic import BaseModel
@@ -26,6 +44,39 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# Check if the current user has access to a file through any knowledge bases the user may be in.
############################
def has_access_to_file(
file_id: Optional[str], access_type: str, user=Depends(get_verified_user)
) -> bool:
file = Files.get_file_by_id(file_id)
log.debug(f"Checking if user has {access_type} access to file")
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_access = False
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type
)
for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id:
has_access = True
break
return has_access
############################
# Upload File
############################
@@ -35,19 +86,55 @@ router = APIRouter()
def upload_file(
request: Request,
file: UploadFile = File(...),
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user),
file_metadata: dict = {},
):
log.info(f"file.content_type: {file.content_type}")
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
)
file_metadata = metadata if metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_extension = os.path.splitext(filename)[1]
# Remove the leading dot from the file extension
file_extension = file_extension[1:] if file_extension else ""
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
]
if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(
f"File type {file_extension} is not allowed"
),
)
# replace filename with uuid
id = str(uuid.uuid4())
name = filename
filename = f"{id}_{filename}"
contents, file_path = Storage.upload_file(file.file, filename)
tags = {
"OpenWebUI-User-Email": user.email,
"OpenWebUI-User-Id": user.id,
"OpenWebUI-User-Name": user.name,
"OpenWebUI-File-Id": id,
}
contents, file_path = Storage.upload_file(file.file, filename, tags)
file_item = Files.insert_new_file(
user.id,
@@ -65,19 +152,40 @@ def upload_file(
}
),
)
if process:
try:
if file.content_type:
if file.content_type.startswith("audio/") or file.content_type in {
"video/webm"
}:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
try:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
log.error(f"Error processing file: {file_item.id}")
file_item = FileModelResponse(
**{
**file_item.model_dump(),
"error": str(e.detail) if hasattr(e, "detail") else str(e),
}
)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=id), user=user)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
except Exception as e:
log.exception(e)
log.error(f"Error processing file: {file_item.id}")
file_item = FileModelResponse(
**{
**file_item.model_dump(),
"error": str(e.detail) if hasattr(e, "detail") else str(e),
}
)
if file_item:
return file_item
@@ -91,7 +199,7 @@ def upload_file(
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
)
@@ -101,14 +209,62 @@ def upload_file(
@router.get("/", response_model=list[FileModelResponse])
async def list_files(user=Depends(get_verified_user)):
async def list_files(user=Depends(get_verified_user), content: bool = Query(True)):
if user.role == "admin":
files = Files.get_files()
else:
files = Files.get_files_by_user_id(user.id)
if not content:
for file in files:
if "content" in file.data:
del file.data["content"]
return files
############################
# Search Files
############################
@router.get("/search", response_model=list[FileModelResponse])
async def search_files(
filename: str = Query(
...,
description="Filename pattern to search for. Supports wildcards such as '*.txt'",
),
content: bool = Query(True),
user=Depends(get_verified_user),
):
"""
Search for files by filename with support for wildcard patterns.
"""
# Get files according to user role
if user.role == "admin":
files = Files.get_files()
else:
files = Files.get_files_by_user_id(user.id)
# Get matching files
matching_files = [
file for file in files if fnmatch(file.filename.lower(), filename.lower())
]
if not matching_files:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No files found matching the pattern.",
)
if not content:
for file in matching_files:
if "content" in file.data:
del file.data["content"]
return matching_files
############################
# Delete All Files
############################
@@ -144,7 +300,17 @@ async def delete_all_files(user=Depends(get_admin_user)):
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
return file
else:
raise HTTPException(
@@ -162,7 +328,17 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
return {"content": file.data.get("content", "")}
else:
raise HTTPException(
@@ -186,7 +362,17 @@ async def update_file_data_content_by_id(
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
):
try:
process_file(
request,
@@ -212,9 +398,22 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_content_by_id(
id: str, user=Depends(get_verified_user), attachment: bool = Query(False)
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
@@ -225,17 +424,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename) # RFC5987 encoding
content_type = file.meta.get("content_type")
filename = file.meta.get("name", file.filename)
encoded_filename = quote(filename)
headers = {}
if file.meta.get("content_type") not in [
"application/pdf",
"text/plain",
]:
headers = {
**headers,
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
}
return FileResponse(file_path, headers=headers)
if attachment:
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
else:
if content_type == "application/pdf" or filename.lower().endswith(
".pdf"
):
headers["Content-Disposition"] = (
f"inline; filename*=UTF-8''{encoded_filename}"
)
content_type = "application/pdf"
elif content_type != "text/plain":
headers["Content-Disposition"] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
return FileResponse(file_path, headers=headers, media_type=content_type)
else:
raise HTTPException(
@@ -259,14 +470,32 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/html")
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
file_user = Users.get_user_by_id(file.user_id)
if not file_user.role == "admin":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
log.info(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
@@ -291,7 +520,17 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "read", user)
):
file_path = file.path
# Handle Unicode filenames
@@ -342,7 +581,18 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"
or has_access_to_file(id, "write", user)
):
# We should add Chroma cleanup here
result = Files.delete_file_by_id(id)

View File
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -228,7 +230,19 @@ async def update_folder_is_expanded_by_id(
@router.delete("/{id}")
async def delete_folder_by_id(id: str, user=Depends(get_verified_user)):
async def delete_folder_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
chat_delete_permission = has_permission(
user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS
)
if user.role != "admin" and not chat_delete_permission:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
folder = Folders.get_folder_by_id_and_user_id(id, user.id)
if folder:
try:
View File
@@ -1,4 +1,8 @@
import os
import re
import logging
import aiohttp
from pathlib import Path
from typing import Optional
@@ -8,11 +12,22 @@ from open_webui.models.functions import (
FunctionResponse,
Functions,
)
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.utils.plugin import (
load_function_module_by_id,
replace_imports,
get_function_module_from_cache,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, HttpUrl
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -36,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
@router.post("/load/url", response_model=Optional[dict])
async def load_function_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT a SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
function_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the function"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": function_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
############################
# SyncFunctions
############################
class SyncFunctionsForm(FunctionForm):
functions: list[FunctionModel] = []
@router.post("/sync", response_model=Optional[FunctionModel])
async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
):
return Functions.sync_functions(user.id, form_data.functions)
############################
# CreateNewFunction
############################
@@ -68,7 +174,7 @@ async def create_new_function(
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
@@ -79,7 +185,7 @@ async def create_new_function(
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
log.exception(f"Failed to create a new function: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -183,7 +289,7 @@ async def update_function_by_id(
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
@@ -256,11 +362,9 @@ async def get_function_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -284,11 +388,9 @@ async def update_function_valves_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -299,7 +401,7 @@ async def update_function_valves_by_id(
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function values by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -347,11 +449,9 @@ async def get_function_user_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -371,11 +471,9 @@ async def update_function_user_valves_by_id(
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -388,7 +486,7 @@ async def update_function_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Error updating function user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
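
A quick illustration of the URL rewriting done by the github_url_to_raw_url helper above; the repository names are made up:

blob_url = "https://github.com/acme/tools/blob/main/functions/filter.py"
tree_url = "https://github.com/acme/tools/tree/main/functions/my_filter"

github_url_to_raw_url(blob_url)
# -> "https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/filter.py"
github_url_to_raw_url(tree_url)
# -> "https://raw.githubusercontent.com/acme/tools/refs/heads/main/functions/my_filter/main.py"
github_url_to_raw_url("https://example.com/some/script.py")
# -> returned unchanged, since neither the blob nor the tree pattern matches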
backend/open_webui/routers/groups.py Normal file → Executable file
View File
@@ -1,7 +1,7 @@
import os
from pathlib import Path
from typing import Optional
import logging
from open_webui.models.users import Users
from open_webui.models.groups import (
@@ -14,7 +14,13 @@ from open_webui.models.groups import (
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -37,7 +43,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)):
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
if group:
@@ -48,7 +54,7 @@ async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)
detail=ERROR_MESSAGES.DEFAULT("Error creating group"),
)
except Exception as e:
print(e)
log.exception(f"Error creating a new group: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -94,7 +100,7 @@ async def update_group_by_id(
detail=ERROR_MESSAGES.DEFAULT("Error updating group"),
)
except Exception as e:
print(e)
log.exception(f"Error updating group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
@@ -118,7 +124,7 @@ async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DEFAULT("Error deleting group"),
)
except Exception as e:
print(e)
log.exception(f"Error deleting group {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
View File
@@ -25,7 +25,7 @@ from pydantic import BaseModel
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR = CACHE_DIR / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -55,6 +55,10 @@ async def get_config(request: Request, user=Depends(get_admin_user)):
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@@ -78,6 +82,11 @@ class ComfyUIConfigForm(BaseModel):
COMFYUI_WORKFLOW_NODES: list[dict]
class GeminiConfigForm(BaseModel):
GEMINI_API_BASE_URL: str
GEMINI_API_KEY: str
class ConfigForm(BaseModel):
enabled: bool
engine: str
@@ -85,6 +94,7 @@ class ConfigForm(BaseModel):
openai: OpenAIConfigForm
automatic1111: Automatic1111ConfigForm
comfyui: ComfyUIConfigForm
gemini: GeminiConfigForm
@router.post("/config/update")
@@ -103,6 +113,11 @@ async def update_config(
)
request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
request.app.state.config.IMAGES_GEMINI_API_BASE_URL = (
form_data.gemini.GEMINI_API_BASE_URL
)
request.app.state.config.IMAGES_GEMINI_API_KEY = form_data.gemini.GEMINI_API_KEY
request.app.state.config.AUTOMATIC1111_BASE_URL = (
form_data.automatic1111.AUTOMATIC1111_BASE_URL
)
@@ -129,6 +144,8 @@ async def update_config(
request.app.state.config.COMFYUI_BASE_URL = (
form_data.comfyui.COMFYUI_BASE_URL.strip("/")
)
request.app.state.config.COMFYUI_API_KEY = form_data.comfyui.COMFYUI_API_KEY
request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
request.app.state.config.COMFYUI_WORKFLOW_NODES = (
form_data.comfyui.COMFYUI_WORKFLOW_NODES
@@ -155,6 +172,10 @@ async def update_config(
"COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
"COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
},
"gemini": {
"GEMINI_API_BASE_URL": request.app.state.config.IMAGES_GEMINI_API_BASE_URL,
"GEMINI_API_KEY": request.app.state.config.IMAGES_GEMINI_API_KEY,
},
}
@@ -184,9 +205,17 @@ async def verify_url(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_IMAGE_GENERATION = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
headers = None
if request.app.state.config.COMFYUI_API_KEY:
headers = {
"Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
}
try:
r = requests.get(
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
headers=headers,
)
r.raise_for_status()
return True
@@ -224,6 +253,12 @@ def get_image_model(request):
if request.app.state.config.IMAGE_GENERATION_MODEL
else "dall-e-2"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
if request.app.state.config.IMAGE_GENERATION_MODEL
else "imagen-3.0-generate-002"
)
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
return (
request.app.state.config.IMAGE_GENERATION_MODEL
@@ -298,6 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
@@ -322,7 +362,7 @@ def get_models(request: Request, user=Depends(get_verified_user)):
if model_node_id:
model_list_key = None
print(workflow[model_node_id]["class_type"])
log.info(workflow[model_node_id]["class_type"])
for key in info[workflow[model_node_id]["class_type"]]["input"][
"required"
]:
@@ -411,7 +451,7 @@ def load_url_image_data(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
def upload_image(request, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
@@ -420,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@@ -461,7 +501,11 @@ async def image_generations(
if form_data.size
else request.app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
**(
{}
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
else {"response_format": "b64_json"}
),
}
# Use asyncio.to_thread for the requests.post call
@@ -478,11 +522,50 @@ async def image_generations(
images = []
for image in res["data"]:
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
if image_url := image.get("url", None):
image_data, content_type = load_url_image_data(image_url, headers)
else:
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
headers = {}
headers["Content-Type"] = "application/json"
headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY
model = get_image_model(request)
data = {
"instances": {"prompt": form_data.prompt},
"parameters": {
"sampleCount": form_data.n,
"outputOptions": {"mimeType": "image/png"},
},
}
# Use asyncio.to_thread for the requests.post call
r = await asyncio.to_thread(
requests.post,
url=f"{request.app.state.config.IMAGES_GEMINI_API_BASE_URL}/models/{model}:predict",
json=data,
headers=headers,
)
r.raise_for_status()
res = r.json()
images = []
for image in res["predictions"]:
image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
@@ -529,9 +612,9 @@ async def image_generations(
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
form_data.model_dump(exclude_none=True),
user,
)
images.append({"url": url})
@@ -541,7 +624,7 @@ async def image_generations(
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
):
if form_data.model:
set_image_model(form_data.model)
set_image_model(request, form_data.model)
data = {
"prompt": form_data.prompt,
@@ -582,9 +665,9 @@ async def image_generations(
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
{**data, "info": res["info"]},
user,
)
images.append({"url": url})
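For reference, the Gemini branch added above reduces to a single :predict call. A minimal standalone sketch of that request, with a placeholder base URL, API key, and prompt standing in for the WebUI config objects:

import requests

GEMINI_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"  # placeholder
GEMINI_API_KEY = "YOUR_KEY"  # placeholder
model = "imagen-3.0-generate-002"

# Same body shape the router builds: one prompt, n samples, PNG output
body = {
    "instances": {"prompt": "a watercolor fox"},
    "parameters": {"sampleCount": 1, "outputOptions": {"mimeType": "image/png"}},
}

r = requests.post(
    f"{GEMINI_API_BASE_URL}/models/{model}:predict",
    json=body,
    headers={"Content-Type": "application/json", "x-goog-api-key": GEMINI_API_KEY},
)
r.raise_for_status()
# Each prediction carries the image as base64 under "bytesBase64Encoded"
images_b64 = [p["bytesBase64Encoded"] for p in r.json()["predictions"]]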

View File

@@ -9,8 +9,8 @@ from open_webui.models.knowledge import (
KnowledgeResponse,
KnowledgeUserResponse,
)
from open_webui.models.files import Files, FileModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.models.files import Files, FileModel, FileMetadataResponse
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
process_file,
ProcessFileForm,
@@ -161,13 +161,94 @@ async def create_new_knowledge(
)
############################
# ReindexKnowledgeFiles
############################
@router.post("/reindex", response_model=bool)
async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
knowledge_bases = Knowledges.get_knowledge_bases()
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
deleted_knowledge_bases = []
for knowledge_base in knowledge_bases:
# -- Robust error handling for missing or invalid data
if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
log.warning(
f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
)
try:
Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
deleted_knowledge_bases.append(knowledge_base.id)
except Exception as e:
log.error(
f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
)
continue
try:
file_ids = knowledge_base.data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids)
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
VECTOR_DB_CLIENT.delete_collection(
collection_name=knowledge_base.id
)
except Exception as e:
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
continue # Skip, don't raise
failed_files = []
for file in files:
try:
process_file(
request,
ProcessFileForm(
file_id=file.id, collection_name=knowledge_base.id
),
user=user,
)
except Exception as e:
log.error(
f"Error processing file {file.filename} (ID: {file.id}): {str(e)}"
)
failed_files.append({"file_id": file.id, "error": str(e)})
continue
except Exception as e:
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
# Don't raise, just continue
continue
if failed_files:
log.warning(
f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}"
)
for failed in failed_files:
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
log.info(
f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
)
return True
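Operationally the new endpoint is a single admin-only POST. Assuming the knowledge router keeps its usual /api/v1/knowledge mount prefix (the prefix itself is not shown in this hunk), triggering a full reindex could look like:

import requests

BASE_URL = "http://localhost:8080"   # placeholder Open WebUI instance
TOKEN = "ADMIN_API_TOKEN"            # placeholder admin token

r = requests.post(
    f"{BASE_URL}/api/v1/knowledge/reindex",
    headers={"Authorization": f"Bearer {TOKEN}"},
)
r.raise_for_status()
print(r.json())  # True once every surviving knowledge base has been reprocessed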
############################
# GetKnowledgeById
############################
class KnowledgeFilesResponse(KnowledgeResponse):
files: list[FileModel]
files: list[FileMetadataResponse]
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
@@ -183,7 +264,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -311,7 +392,7 @@ def add_file_to_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -388,7 +469,7 @@ def update_file_from_knowledge_by_id(
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -437,14 +518,24 @@ def remove_file_from_knowledge_by_id(
)
# Remove content from the vector database
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
try:
VECTOR_DB_CLIENT.delete(
collection_name=knowledge.id, filter={"file_id": form_data.file_id}
)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
try:
# Remove the file's collection from vector database
file_collection = f"file-{form_data.file_id}"
if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection):
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection)
except Exception as e:
log.debug("This was most likely caused by bypassing embedding processing")
log.debug(e)
pass
# Delete file from database
Files.delete_file_by_id(form_data.file_id)
@@ -460,7 +551,7 @@ def remove_file_from_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -614,7 +705,7 @@ def add_files_to_knowledge_batch(
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
log.info(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
@@ -656,7 +747,7 @@ def add_files_to_knowledge_batch(
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids),
files=Files.get_file_metadatas_by_ids(existing_file_ids),
warnings={
"message": "Some files failed to process",
"errors": error_details,
@@ -664,5 +755,6 @@ def add_files_to_knowledge_batch(
)
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
**knowledge.model_dump(),
files=Files.get_file_metadatas_by_ids(existing_file_ids),
)

View File

@@ -4,7 +4,7 @@ import logging
from typing import Optional
from open_webui.models.memories import Memories, MemoryModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS
@@ -57,7 +57,9 @@ async def add_memory(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": {"created_at": memory.created_at},
}
],
@@ -82,7 +84,7 @@ async def query_memory(
):
results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
limit=form_data.k,
)
@@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
@@ -149,7 +153,9 @@ async def update_memory_by_id(
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
memory = Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content
)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
@@ -161,7 +167,7 @@ async def update_memory_by_id(
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user
memory.content, user=user
),
"metadata": {
"created_at": memory.created_at,

View File

@@ -0,0 +1,218 @@
import json
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel
from open_webui.models.users import Users, UserResponse
from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetNotes
############################
@router.get("/", response_model=list[NoteUserResponse])
async def get_notes(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
notes = [
NoteUserResponse(
**{
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
for note in Notes.get_notes_by_user_id(user.id, "write")
]
return notes
@router.get("/list", response_model=list[NoteUserResponse])
async def get_note_list(request: Request, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
notes = [
NoteUserResponse(
**{
**note.model_dump(),
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
}
)
for note in Notes.get_notes_by_user_id(user.id, "read")
]
return notes
############################
# CreateNewNote
############################
@router.post("/create", response_model=Optional[NoteModel])
async def create_new_note(
request: Request, form_data: NoteForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
try:
note = Notes.insert_new_note(form_data, user.id)
return note
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetNoteById
############################
@router.get("/{id}", response_model=Optional[NoteModel])
async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
return note
############################
# UpdateNoteById
############################
@router.post("/{id}/update", response_model=Optional[NoteModel])
async def update_note_by_id(
request: Request, id: str, form_data: NoteForm, user=Depends(get_verified_user)
):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
note = Notes.update_note_by_id(id, form_data)
return note
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteNoteById
############################
@router.delete("/{id}/delete", response_model=bool)
async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role != "admin" and not has_permission(
user.id, "features.notes", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
note = Notes.get_note_by_id(id)
if not note:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
note = Notes.delete_note_by_id(id)
return True
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)

View File

@@ -9,11 +9,18 @@ import os
import random
import re
import time
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
from aiocache import cached
import requests
from open_webui.models.users import UserModel
from open_webui.env import (
ENABLE_FORWARD_USER_INFO_HEADERS,
)
from fastapi import (
Depends,
@@ -26,7 +33,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, validator
from starlette.background import BackgroundTask
@@ -49,8 +56,9 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES
@@ -66,12 +74,27 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
return await response.json()
except Exception as e:
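The X-OpenWebUI-User-* block above is repeated verbatim for every outbound request touched by this diff. As a compact illustration, the helper below (hypothetical, not part of the codebase) builds the same headers that end up on the wire when ENABLE_FORWARD_USER_INFO_HEADERS is enabled:

def forwarded_user_headers(user, key=None, enabled=True):
    # Mirrors the inline header construction used throughout this diff
    headers = {"Content-Type": "application/json"}
    if key:
        headers["Authorization"] = f"Bearer {key}"
    if enabled and user:
        headers.update(
            {
                "X-OpenWebUI-User-Name": user.name,
                "X-OpenWebUI-User-Id": user.id,
                "X-OpenWebUI-User-Email": user.email,
                "X-OpenWebUI-User-Role": user.role,
            }
        )
    return headers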
@@ -96,6 +119,7 @@ async def send_post_request(
stream: bool = True,
key: Optional[str] = None,
content_type: Optional[str] = None,
user: UserModel = None,
):
r = None
@@ -110,7 +134,18 @@ async def send_post_request(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
r.raise_for_status()
@@ -186,12 +221,26 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/api/version",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
detail = f"HTTP Error: {r.status}"
@@ -253,8 +302,24 @@ async def update_config(
}
@cached(ttl=3)
async def get_all_models(request: Request):
def merge_ollama_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
@@ -262,7 +327,7 @@ async def get_all_models(request: Request):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/tags"))
request_tasks.append(send_get_request(f"{url}/api/tags", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
@@ -275,7 +340,9 @@ async def get_all_models(request: Request):
key = api_config.get("key", None)
if enable:
request_tasks.append(send_get_request(f"{url}/api/tags", key))
request_tasks.append(
send_get_request(f"{url}/api/tags", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
@@ -291,7 +358,10 @@ async def get_all_models(request: Request):
), # Legacy support
)
connection_type = api_config.get("connection_type", "local")
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", [])
if len(model_ids) != 0 and "models" in response:
@@ -302,27 +372,18 @@ async def get_all_models(request: Request):
)
)
if prefix_id:
for model in response.get("models", []):
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
def merge_models_lists(model_lists):
merged_models = {}
if tags:
model["tags"] = tags
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
if connection_type:
model["connection_type"] = connection_type
models = {
"models": merge_models_lists(
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
@@ -330,6 +391,22 @@ async def get_all_models(request: Request):
)
}
try:
loaded_models = await get_ollama_loaded_models(request, user=user)
expires_map = {
m["name"]: m["expires_at"]
for m in loaded_models["models"]
if "expires_at" in m
}
for m in models["models"]:
if m["name"] in expires_map:
# Parse ISO8601 datetime with offset, get unix timestamp as int
dt = datetime.fromisoformat(expires_map[m["name"]])
m["expires_at"] = int(dt.timestamp())
except Exception as e:
log.debug(f"Failed to get loaded models: {e}")
else:
models = {"models": []}
@@ -360,7 +437,7 @@ async def get_ollama_tags(
models = []
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -370,7 +447,19 @@ async def get_ollama_tags(
r = requests.request(
method="GET",
url=f"{url}/api/tags",
headers={**({"Authorization": f"Bearer {key}"} if key else {})},
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@@ -398,24 +487,95 @@ async def get_ollama_tags(
return models
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(f"{url}/api/ps", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
models = {
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
else:
models = {"models": []}
return models
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
if request.app.state.config.ENABLE_OLLAMA_API:
if url_idx is None:
# returns lowest version
request_tasks = [
send_get_request(
f"{url}/api/version",
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
url, {}
), # Legacy support
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(
f"{url}/api/version",
key,
)
)
responses = await asyncio.gather(*request_tasks)
responses = list(filter(lambda x: x is not None, responses))
@@ -462,35 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
return {"version": False}
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [
send_get_request(
f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
class ModelNameForm(BaseModel):
name: str
@router.post("/api/unload")
async def unload_model(
request: Request,
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
model_name = form_data.name
if not model_name:
raise HTTPException(
status_code=400, detail="Missing 'name' of model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
# Canonicalize model name (if not supplied with version)
if ":" not in model_name:
model_name = f"{model_name}:latest"
if model_name not in models:
raise HTTPException(
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
)
url_indices = models[model_name]["urls"]
# Send unload to ALL url_indices
results = []
errors = []
for idx in url_indices:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
)
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id and model_name.startswith(f"{prefix_id}."):
model_name = model_name[len(f"{prefix_id}.") :]
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
try:
res = await send_post_request(
url=f"{url}/api/generate",
payload=json.dumps(payload),
stream=False,
key=key,
user=user,
)
results.append({"url_idx": idx, "success": True, "response": res})
except Exception as e:
log.exception(f"Failed to unload model on node {idx}: {e}")
errors.append({"url_idx": idx, "success": False, "error": str(e)})
if len(errors) > 0:
raise HTTPException(
status_code=500,
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
)
return {"status": True}
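The unload itself leans on documented Ollama behavior: a /api/generate call with "keep_alive": 0 and an empty prompt evicts the model from memory immediately. A minimal sketch against a single Ollama node, outside Open WebUI (URL and model name are placeholders):

import requests

OLLAMA_URL = "http://localhost:11434"  # placeholder node address

resp = requests.post(
    f"{OLLAMA_URL}/api/generate",
    json={"model": "llama3.2:latest", "keep_alive": 0, "prompt": ""},
)
resp.raise_for_status()

# /api/ps should no longer list the model once it has been evicted
print(requests.get(f"{OLLAMA_URL}/api/ps").json())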
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
@@ -509,6 +708,7 @@ async def pull_model(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -527,7 +727,7 @@ async def push_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@@ -545,6 +745,7 @@ async def push_model(
url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -571,6 +772,7 @@ async def create_model(
url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -588,7 +790,7 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.source in models:
@@ -609,6 +811,16 @@ async def copy_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -643,7 +855,7 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@@ -665,6 +877,16 @@ async def delete_model(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
@@ -693,7 +915,7 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
if form_data.name not in models:
@@ -714,6 +936,16 @@ async def show_model_info(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -757,7 +989,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -774,8 +1006,16 @@ async def embed(
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
)
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try:
r = requests.request(
method="POST",
@@ -783,6 +1023,16 @@ async def embed(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -826,7 +1076,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -843,8 +1093,16 @@ async def embeddings(
)
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}), # Legacy support
)
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
try:
r = requests.request(
method="POST",
@@ -852,6 +1110,16 @@ async def embeddings(
headers={
"Content-Type": "application/json",
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
data=form_data.model_dump_json(exclude_none=True).encode(),
)
@@ -882,7 +1150,7 @@ class GenerateCompletionForm(BaseModel):
prompt: str
suffix: Optional[str] = None
images: Optional[list[str]] = None
format: Optional[str] = None
format: Optional[Union[dict, str]] = None
options: Optional[dict] = None
system: Optional[str] = None
template: Optional[str] = None
@@ -901,7 +1169,7 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
await get_all_models(request)
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -931,20 +1199,34 @@ async def generate_completion(
url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
class ChatMessage(BaseModel):
role: str
content: str
content: Optional[str] = None
tool_calls: Optional[list[dict]] = None
images: Optional[list[str]] = None
@validator("content", pre=True)
@classmethod
def check_at_least_one_field(cls, field_value, values, **kwargs):
# Raise an error if both 'content' and 'tool_calls' are None
if field_value is None and (
"tool_calls" not in values or values["tool_calls"] is None
):
raise ValueError(
"At least one of 'content' or 'tool_calls' must be provided"
)
return field_value
class GenerateChatCompletionForm(BaseModel):
model: str
messages: list[ChatMessage]
format: Optional[dict] = None
format: Optional[Union[dict, str]] = None
options: Optional[dict] = None
template: Optional[str] = None
stream: Optional[bool] = True
@@ -1001,13 +1283,14 @@ async def generate_chat_completion(
params = model_info.params.model_dump()
if params:
if payload.get("options") is None:
payload["options"] = {}
system = params.pop("system", None)
# Unlike OpenAI, Ollama does not support params directly in the body
payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"]
params, (payload.get("options", {}) or {})
)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@@ -1040,13 +1323,14 @@ async def generate_chat_completion(
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# payload["keep_alive"] = -1 # keep alive forever
return await send_post_request(
url=f"{url}/api/chat",
payload=json.dumps(payload),
stream=form_data.stream,
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson",
user=user,
)
@@ -1058,7 +1342,7 @@ class OpenAIChatMessageContent(BaseModel):
class OpenAIChatMessage(BaseModel):
role: str
content: Union[str, list[OpenAIChatMessageContent]]
content: Union[Optional[str], list[OpenAIChatMessageContent]]
model_config = ConfigDict(extra="allow")
@@ -1149,6 +1433,7 @@ async def generate_openai_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -1187,8 +1472,10 @@ async def generate_openai_chat_completion(
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if user.role == "user":
@@ -1227,6 +1514,7 @@ async def generate_openai_chat_completion(
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS),
user=user,
)
@@ -1240,7 +1528,7 @@ async def get_openai_models(
models = []
if url_idx is None:
model_list = await get_all_models(request)
model_list = await get_all_models(request, user=user)
models = [
{
"id": model["model"],
@@ -1341,7 +1629,9 @@ async def download_file_stream(
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(file_url, headers=headers) as response:
async with session.get(
file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
@@ -1356,7 +1646,8 @@ async def download_file_stream(
if done:
file.seek(0)
hashed = calculate_sha256(file)
chunk_size = 1024 * 1024 * 2
hashed = calculate_sha256(file, chunk_size)
file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
@@ -1420,7 +1711,9 @@ async def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = os.path.join(UPLOAD_DIR, file.filename)
filename = os.path.basename(file.filename)
file_path = os.path.join(UPLOAD_DIR, filename)
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- P1: save file locally ---
@@ -1465,13 +1758,13 @@ async def upload_model(
os.remove(file_path)
# Create model in ollama
model_name, ext = os.path.splitext(file.filename)
model_name, ext = os.path.splitext(filename)
log.info(f"Created Model: {model_name}") # DEBUG
create_payload = {
"model": model_name,
# Reference the file by its original name => the uploaded blob's digest
"files": {file.filename: f"sha256:{file_hash}"},
"files": {filename: f"sha256:{file_hash}"},
}
log.info(f"Model Payload: {create_payload}") # DEBUG
@@ -1488,7 +1781,7 @@ async def upload_model(
done_msg = {
"done": True,
"blob": f"sha256:{file_hash}",
"name": file.filename,
"name": filename,
"model_created": model_name,
}
yield f"data: {json.dumps(done_msg)}\n\n"

View File

@@ -21,11 +21,13 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
ENABLE_FORWARD_USER_INFO_HEADERS,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.models.users import UserModel
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
@@ -35,6 +37,9 @@ from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from open_webui.utils.misc import (
convert_logit_bias_input_to_json,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access
@@ -51,12 +56,26 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"])
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async def send_get_request(url, key=None, user: UserModel = None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
url,
headers={
**({"Authorization": f"Bearer {key}"} if key else {}),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
return await response.json()
except Exception as e:
@@ -75,18 +94,23 @@ async def cleanup_response(
await session.close()
def openai_o1_o3_handler(payload):
def openai_o_series_handler(payload):
"""
Handle o1, o3 specific parameters
Handle "o" series specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
# Convert "max_tokens" to "max_completion_tokens" for all o-series models
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
# Handle system role conversion based on model type
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
model_lower = payload["model"].lower()
# Legacy models use "user" role instead of "system"
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
payload["messages"][0]["role"] = "user"
else:
payload["messages"][0]["role"] = "developer"
return payload
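A small worked example of what the renamed handler now does to a payload (values are illustrative only):

payload = {
    "model": "o1-mini",
    "max_tokens": 256,
    "messages": [{"role": "system", "content": "You are terse."}],
}
payload = openai_o_series_handler(payload)
# max_tokens is renamed to max_completion_tokens for every o-series model;
# the leading system message becomes "user" for o1-mini/o1-preview and
# "developer" for newer o-series models such as o3 or o4-mini.
assert payload == {
    "model": "o1-mini",
    "max_completion_tokens": 256,
    "messages": [{"role": "user", "content": "You are terse."}],
}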
@@ -172,7 +196,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
body = await request.body()
name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
@@ -247,7 +271,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def get_all_models_responses(request: Request) -> list:
async def get_all_models_responses(request: Request, user: UserModel) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
@@ -271,7 +295,9 @@ async def get_all_models_responses(request: Request) -> list:
):
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@@ -291,6 +317,7 @@ async def get_all_models_responses(request: Request) -> list:
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
user=user,
)
)
else:
@@ -326,14 +353,22 @@ async def get_all_models_responses(request: Request) -> list:
), # Legacy support
)
connection_type = api_config.get("connection_type", "external")
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
if prefix_id:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
for model in (
response if isinstance(response, list) else response.get("data", [])
):
if prefix_id:
model["id"] = f"{prefix_id}.{model['id']}"
if tags:
model["tags"] = tags
if connection_type:
model["connection_type"] = connection_type
log.debug(f"get_all_models:responses() {responses}")
return responses
@@ -351,14 +386,14 @@ async def get_filtered_models(models, user):
return filtered_models
@cached(ttl=3)
async def get_all_models(request: Request) -> dict[str, list]:
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
log.info("get_all_models()")
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user=user)
def extract_data(response):
if response and "data" in response:
@@ -373,6 +408,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
@@ -380,21 +416,25 @@ async def get_all_models(request: Request) -> dict[str, list]:
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"connection_type": model.get("connection_type", "external"),
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
if (model.get("id") or model.get("name"))
and (
"api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
)
]
)
@@ -418,65 +458,79 @@ async def get_models(
}
if url_idx is None:
models = await get_all_models(request)
models = await get_all_models(request, user=user)
else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
r = None
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
response_data = await r.json()
if api_config.get("azure", False):
models = {
"data": api_config.get("model_ids", []) or [],
"object": "list",
}
else:
headers["Authorization"] = f"Bearer {key}"
# Check if we're calling OpenAI API based on the URL
if "api.openai.com" in url:
# Filter models according to the specified conditions
response_data["data"] = [
model
for model in response_data.get("data", [])
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
async with session.get(
f"{url}/models",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
models = response_data
response_data = await r.json()
# Check if we're calling OpenAI API based on the URL
if "api.openai.com" in url:
# Filter models according to the specified conditions
response_data["data"] = [
model
for model in response_data.get("data", [])
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
models = response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
@@ -498,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
url: str
key: str
config: Optional[dict] = None
@router.post("/verify")
async def verify_connection(
@@ -506,27 +562,64 @@ async def verify_connection(
url = form_data.url
key = form_data.key
api_config = form_data.config or {}
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
response_data = await r.json()
return response_data
if api_config.get("azure", False):
headers["api-key"] = key
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
async with session.get(
url=f"{url}/openai/models?api-version={api_version}",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
else:
headers["Authorization"] = f"Bearer {key}"
async with session.get(
f"{url}/models",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
@@ -540,6 +633,63 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
def convert_to_azure_payload(
url,
payload: dict,
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = {
"messages",
"temperature",
"role",
"content",
"contentPart",
"contentPartImage",
"enhancements",
"dataSources",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"function_call",
"functions",
"tools",
"tool_choice",
"top_p",
"log_probs",
"top_logprobs",
"response_format",
"seed",
"max_completion_tokens",
}
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Remove temperature if not 1 for o-series models
if "temperature" in payload and payload["temperature"] != 1:
log.debug(
f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
)
del payload["temperature"]
# Filter out unsupported parameters
payload = {k: v for k, v in payload.items() if k in allowed_params}
url = f"{url}/openai/deployments/{model}"
return url, payload
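For illustration, how the new helper reshapes a request before it goes to Azure (resource URL and values are made up):

url, payload = convert_to_azure_payload(
    "https://example-resource.openai.azure.com",
    {
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": False,
        "metadata": {"chat_id": "123"},  # not in allowed_params, so it is dropped
    },
)
# url     -> "https://example-resource.openai.azure.com/openai/deployments/gpt-4o"
# payload -> {"messages": [{"role": "user", "content": "hi"}], "stream": False}
#            ("model" and "metadata" are filtered out; the deployment name lives in the URL)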
@router.post("/chat/completions")
async def generate_chat_completion(
request: Request,
@@ -565,8 +715,12 @@ async def generate_chat_completion(
model_id = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@@ -587,7 +741,7 @@ async def generate_chat_completion(
detail="Model not found",
)
await get_all_models(request)
await get_all_models(request, user=user)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
@@ -621,10 +775,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
if is_o1_o3:
payload = openai_o1_o3_handler(payload)
# Check if model is from "o" series
is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
if is_o_series:
payload = openai_o_series_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
@@ -635,6 +789,43 @@ async def generate_chat_completion(
del payload["max_tokens"]
# Convert the modified body back to JSON
if "logit_bias" in payload:
payload["logit_bias"] = json.loads(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
headers = {
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False):
request_url, payload = convert_to_azure_payload(url, payload)
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
else:
request_url = f"{url}/chat/completions"
headers["Authorization"] = f"Bearer {key}"
payload = json.dumps(payload)
r = None
@@ -649,30 +840,10 @@ async def generate_chat_completion(
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
url=request_url,
data=payload,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
# Check if response is SSE
@@ -801,31 +972,54 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
), # Legacy support
)
r = None
session = None
streaming = False
try:
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False):
headers["api-key"] = key
headers["api-version"] = (
api_config.get("api_version", "") or "2023-03-15-preview"
)
payload = json.loads(body)
url, payload = convert_to_azure_payload(url, payload)
body = json.dumps(payload).encode()
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=f"{url}/{path}",
url=request_url,
data=body,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
r.raise_for_status()
@@ -851,7 +1045,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None:
try:
res = await r.json()
print(res)
log.error(res)
if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:

View File

@@ -9,6 +9,7 @@ from fastapi import (
status,
APIRouter,
)
import aiohttp
import os
import logging
import shutil
@@ -17,7 +18,7 @@ from pydantic import BaseModel
from starlette.responses import FileResponse
from typing import Optional
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
@@ -56,96 +57,111 @@ def get_sorted_filters(model_id, models):
return sorted_filters
def process_pipeline_inlet_filter(request, payload, user, models):
async def process_pipeline_inlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
try:
urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "":
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
request_data = {
"user": user,
"body": payload,
}
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
res = r.json()
try:
async with session.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
res = (
await response.json()
if response.content_type == "application/json"
else {}
)
if "detail" in res:
raise Exception(r.status_code, res["detail"])
raise Exception(response.status, res["detail"])
except Exception as e:
log.exception(f"Connection error: {e}")
return payload
def process_pipeline_outlet_filter(request, payload, user, models):
async def process_pipeline_outlet_filter(request, payload, user, models):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id, models)
model = models[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
try:
urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
r = requests.post(
if not key:
continue
headers = {"Authorization": f"Bearer {key}"}
request_data = {
"user": user,
"body": payload,
}
try:
async with session.post(
f"{url}/{filter['id']}/filter/outlet",
headers={"Authorization": f"Bearer {key}"},
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
data = r.json()
payload = data
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
except aiohttp.ClientResponseError as e:
try:
res = r.json()
res = (
await response.json()
if "application/json" in response.content_type
else {}
)
if "detail" in res:
return Exception(r.status_code, res)
raise Exception(response.status, res)
except Exception:
pass
else:
pass
except Exception as e:
log.exception(f"Connection error: {e}")
return payload
@@ -161,7 +177,7 @@ router = APIRouter()
@router.get("/list")
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
responses = await get_all_models_responses(request)
responses = await get_all_models_responses(request, user)
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
urlIdxs = [
@@ -188,9 +204,11 @@ async def upload_pipeline(
file: UploadFile = File(...),
user=Depends(get_admin_user),
):
print("upload_pipeline", urlIdx, file.filename)
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
filename = os.path.basename(file.filename)
# Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")):
if not (filename and filename.endswith(".py")):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
@@ -198,7 +216,7 @@ async def upload_pipeline(
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
file_path = os.path.join(upload_folder, filename)
r = None
try:
@@ -223,7 +241,7 @@ async def upload_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
status_code = status.HTTP_404_NOT_FOUND
@@ -274,7 +292,7 @@ async def add_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -319,7 +337,7 @@ async def delete_pipeline(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -353,7 +371,7 @@ async def get_pipelines(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -392,7 +410,7 @@ async def get_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -432,7 +450,7 @@ async def get_pipeline_valves_spec(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None
if r is not None:
@@ -474,7 +492,7 @@ async def update_pipeline_valves(
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
log.exception(f"Connection error: {e}")
detail = None

File diff suppressed because it is too large Load Diff

View File

@@ -20,6 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
@@ -182,35 +183,28 @@ async def generate_title(
else:
template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
messages = form_data["messages"]
# Remove reasoning details from the messages
for message in messages:
message["content"] = re.sub(
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
"",
message["content"],
flags=re.S,
).strip()
content = title_generation_template(
template,
messages,
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
max_tokens = (
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 1000}
if models[task_model_id]["owned_by"] == "ollama"
{"max_tokens": max_tokens}
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 1000,
"max_completion_tokens": max_tokens,
}
),
"metadata": {
@@ -221,6 +215,12 @@ async def generate_title(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -290,6 +290,12 @@ async def generate_chat_tags(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -356,6 +362,12 @@ async def generate_image_prompt(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -433,6 +445,12 @@ async def generate_queries(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -514,6 +532,12 @@ async def generate_autocompletion(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -571,7 +595,7 @@ async def generate_emoji(
"stream": False,
**(
{"max_tokens": 4}
if models[task_model_id]["owned_by"] == "ollama"
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 4,
}
@@ -584,6 +608,12 @@ async def generate_emoji(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
@@ -613,17 +643,6 @@ async def generate_moa_response(
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(f"generating MOA model {task_model_id} for user {user.email} ")
template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE
content = moa_response_generation_template(
@@ -633,7 +652,7 @@ async def generate_moa_response(
)
payload = {
"model": task_model_id,
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": form_data.get("stream", False),
"metadata": {
@@ -644,6 +663,12 @@ async def generate_moa_response(
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:

View File

@@ -1,5 +1,10 @@
import logging
from pathlib import Path
from typing import Optional
import time
import re
import aiohttp
from pydantic import BaseModel, HttpUrl
from open_webui.models.tools import (
ToolForm,
@@ -8,13 +13,20 @@ from open_webui.models.tools import (
ToolUserResponse,
Tools,
)
from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@@ -25,11 +37,51 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(user=Depends(get_verified_user)):
if user.role == "admin":
tools = Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "read")
async def get_tools(request: Request, user=Depends(get_verified_user)):
if not request.app.state.TOOL_SERVERS:
# If the tool servers are not set, we need to set them
# This is done only once when the server starts
# This is done to avoid loading the tool servers every time
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
tools = Tools.get_tools()
for server in request.app.state.TOOL_SERVERS:
tools.append(
ToolUserResponse(
**{
"id": f"server:{server['idx']}",
"user_id": f"server:{server['idx']}",
"name": server.get("openapi", {})
.get("info", {})
.get("title", "Tool Server"),
"meta": {
"description": server.get("openapi", {})
.get("info", {})
.get("description", ""),
},
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
server["idx"]
]
.get("config", {})
.get("access_control", None),
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
)
if user.role != "admin":
tools = [
tool
for tool in tools
if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control)
]
return tools
@@ -47,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
return tools
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
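For illustration only (hypothetical repository URLs, not part of the diff): how the helper rewrites GitHub links before fetching.
# Illustrative sketch with made-up org/repo/branch names.
github_url_to_raw_url("https://github.com/org/repo/tree/main/tools/my_tool")
# -> "https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/my_tool/main.py"
github_url_to_raw_url("https://github.com/org/repo/blob/main/tools/my_tool.py")
# -> "https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/my_tool.py"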
@router.post("/load/url", response_model=Optional[dict])
async def load_tool_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT an SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
tool_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the tool"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": tool_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
############################
# ExportTools
############################
@@ -89,18 +216,18 @@ async def create_new_tools(
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
tools_module, frontmatter = load_tools_module_by_id(
tool_module, frontmatter = load_tool_module_by_id(
form_data.id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = tools_module
TOOLS[form_data.id] = tool_module
specs = get_tools_specs(TOOLS[form_data.id])
specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
if tools:
@@ -111,7 +238,7 @@ async def create_new_tools(
detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
)
except Exception as e:
print(e)
log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -178,22 +305,20 @@ async def update_tools_by_id(
try:
form_data.content = replace_imports(form_data.content)
tools_module, frontmatter = load_tools_module_by_id(
id, content=form_data.content
)
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[id] = tools_module
TOOLS[id] = tool_module
specs = get_tools_specs(TOOLS[id])
specs = get_tool_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
if tools:
@@ -284,7 +409,7 @@ async def get_tools_valves_spec_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"):
@@ -327,7 +452,7 @@ async def update_tools_valves_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"):
@@ -343,7 +468,7 @@ async def update_tools_valves_by_id(
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update tool valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),
@@ -383,7 +508,7 @@ async def get_tools_user_valves_spec_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
@@ -407,7 +532,7 @@ async def update_tools_user_valves_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tools_module_by_id(id)
tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
@@ -421,7 +546,7 @@ async def update_tools_user_valves_by_id(
)
return user_valves.model_dump()
except Exception as e:
print(e)
log.exception(f"Failed to update user valves by id {id}: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(str(e)),

View File

@@ -2,9 +2,11 @@ import logging
from typing import Optional
from open_webui.models.auths import Auths
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
UserListResponse,
UserRoleUpdateForm,
Users,
UserSettings,
@@ -17,7 +19,10 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
from open_webui.utils.access_control import get_permissions, has_permission
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -29,13 +34,38 @@ router = APIRouter()
############################
@router.get("/", response_model=list[UserModel])
PAGE_ITEM_COUNT = 30
@router.get("/", response_model=UserListResponse)
async def get_users(
skip: Optional[int] = None,
limit: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_admin_user),
):
return Users.get_users(skip, limit)
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit)
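As a quick arithmetic check of the paging above (illustrative, not part of the diff): with PAGE_ITEM_COUNT = 30, requesting page 3 resolves to skip = (3 - 1) * 30 = 60 and limit = 30, i.e. rows 60 through 89.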
@router.get("/all", response_model=UserListResponse)
async def get_all_users(
user=Depends(get_admin_user),
):
return Users.get_users()
############################
@@ -45,7 +75,7 @@ async def get_users(
@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
return Groups.get_groups_by_member_id(user.id)
############################
@@ -54,8 +84,12 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions")
async def get_user_permissisions(user=Depends(get_verified_user)):
return Users.get_user_groups(user.id)
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return user_permissions
############################
@@ -68,32 +102,52 @@ class WorkspacePermissions(BaseModel):
tools: bool = False
class SharingPermissions(BaseModel):
public_models: bool = True
public_knowledge: bool = True
public_prompts: bool = True
public_tools: bool = True
class ChatPermissions(BaseModel):
controls: bool = True
file_upload: bool = True
delete: bool = True
edit: bool = True
share: bool = True
export: bool = True
stt: bool = True
tts: bool = True
call: bool = True
multiple_models: bool = True
temporary: bool = True
temporary_enforced: bool = False
class FeaturesPermissions(BaseModel):
direct_tool_servers: bool = False
web_search: bool = True
image_generation: bool = True
code_interpreter: bool = True
notes: bool = True
class UserPermissions(BaseModel):
workspace: WorkspacePermissions
sharing: SharingPermissions
chat: ChatPermissions
features: FeaturesPermissions
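For illustration only (hypothetical stored config, not part of the diff): when the endpoint below builds these typed models, missing keys fall back to the defaults declared above.
# Illustrative sketch: a partially populated chat permissions entry.
stored_chat = {"temporary_enforced": True}
ChatPermissions(**stored_chat)
# -> all fields keep their defaults except temporary_enforced, which becomes True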
@router.get("/default/permissions", response_model=UserPermissions)
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
return {
"workspace": WorkspacePermissions(
**request.app.state.config.USER_PERMISSIONS.get("workspace", {})
),
"sharing": SharingPermissions(
**request.app.state.config.USER_PERMISSIONS.get("sharing", {})
),
"chat": ChatPermissions(
**request.app.state.config.USER_PERMISSIONS.get("chat", {})
),
@@ -104,7 +158,7 @@ async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
@router.post("/default/permissions")
async def update_user_permissions(
async def update_default_user_permissions(
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
@@ -151,9 +205,22 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
request: Request, form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_settings_by_id(user.id, form_data.model_dump())
updated_user_settings = form_data.model_dump()
if (
user.role != "admin"
and "toolServers" in updated_user_settings.get("ui").keys()
and not has_permission(
user.id,
"features.direct_tool_servers",
request.app.state.config.USER_PERMISSIONS,
)
):
# If the user is not an admin and does not have permission to use tool servers, remove the key
updated_user_settings["ui"].pop("toolServers", None)
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
if user:
return user.settings
else:
@@ -263,6 +330,21 @@ async def update_user_by_id(
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
):
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id and session_user.id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not verify primary admin status.",
)
user = Users.get_user_by_id(user_id)
if user:
@@ -310,6 +392,21 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
# Prevent deletion of the primary admin user
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not verify primary admin status.",
)
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
@@ -321,6 +418,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
# Prevent self-deletion
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,

View File

@@ -1,48 +1,84 @@
import black
import logging
import markdown
from open_webui.models.chats import ChatTitleMessagesForm
from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from pydantic import BaseModel
from starlette.responses import FileResponse
from open_webui.utils.misc import get_gravatar_url
from open_webui.utils.pdf_generator import PDFGenerator
from open_webui.utils.auth import get_admin_user
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
@router.get("/gravatar")
async def get_gravatar(
email: str,
):
async def get_gravatar(email: str, user=Depends(get_verified_user)):
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
class CodeForm(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
return {"code": form_data.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/code/execute")
async def execute_code(
request: Request, form_data: CodeForm, user=Depends(get_verified_user)
):
if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
output = await execute_code_jupyter(
request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
form_data.code,
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
else None
),
(
request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
else None
),
request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
)
return output
else:
raise HTTPException(
status_code=400,
detail="Code execution engine not supported",
)
class MarkdownForm(BaseModel):
md: str
@router.post("/markdown")
async def get_html_from_markdown(
form_data: MarkdownForm,
form_data: MarkdownForm, user=Depends(get_verified_user)
):
return {"html": markdown.markdown(form_data.md)}
@@ -54,7 +90,7 @@ class ChatForm(BaseModel):
@router.post("/pdf")
async def download_chat_as_pdf(
form_data: ChatTitleMessagesForm,
form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
):
try:
pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
@@ -65,7 +101,7 @@ async def download_chat_as_pdf(
headers={"Content-Disposition": "attachment;filename=chat.pdf"},
)
except Exception as e:
print(e)
log.exception(f"Error generating PDF: {e}")
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -3,15 +3,23 @@ import socketio
import logging
import sys
import time
from redis import asyncio as aioredis
from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
from open_webui.utils.redis import (
get_sentinels_from_env,
get_sentinel_url_from_env,
)
from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
@@ -28,7 +36,14 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"])
if WEBSOCKET_MANAGER == "redis":
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
if WEBSOCKET_SENTINEL_HOSTS:
mgr = socketio.AsyncRedisManager(
get_sentinel_url_from_env(
WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
)
)
else:
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
@@ -54,14 +69,30 @@ TIMEOUT_DURATION = 3
if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
redis_sentinels = get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
)
SESSION_POOL = RedisDict(
"open-webui:session_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
)
USER_POOL = RedisDict(
"open-webui:user_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
)
USAGE_POOL = RedisDict(
"open-webui:usage_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
)
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=TIMEOUT_DURATION * 2,
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
redis_sentinels=redis_sentinels,
)
aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock
@@ -128,18 +159,19 @@ def get_models_in_use():
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Record the timestamp for the last update
current_time = int(time.time())
if sid in SESSION_POOL:
model_id = data["model"]
# Record the timestamp for the last update
current_time = int(time.time())
# Store the new usage data and task
USAGE_POOL[model_id] = {
**(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
sid: {"updated_at": current_time},
}
# Store the new usage data and task
USAGE_POOL[model_id] = {
**(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
sid: {"updated_at": current_time},
}
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
@sio.event
@@ -247,7 +279,8 @@ async def channel_events(sid, data):
@sio.on("user-list")
async def user_list(sid):
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
if sid in SESSION_POOL:
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
@sio.event
@@ -268,15 +301,23 @@ async def disconnect(sid):
# print(f"Unknown session ID {sid} disconnected")
def get_event_emitter(request_info):
def get_event_emitter(request_info, update_db=True):
async def __event_emitter__(event_data):
user_id = request_info["user_id"]
session_ids = list(
set(USER_POOL.get(user_id, []) + [request_info["session_id"]])
set(
USER_POOL.get(user_id, [])
+ (
[request_info.get("session_id")]
if request_info.get("session_id")
else []
)
)
)
for session_id in session_ids:
await sio.emit(
emit_tasks = [
sio.emit(
"chat-events",
{
"chat_id": request_info.get("chat_id", None),
@@ -285,41 +326,47 @@ def get_event_emitter(request_info):
},
to=session_id,
)
for session_id in session_ids
]
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
await asyncio.gather(*emit_tasks)
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
if update_db:
if "type" in event_data and event_data["type"] == "status":
Chats.add_message_status_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
event_data.get("data", {}),
)
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
if "type" in event_data and event_data["type"] == "message":
message = Chats.get_message_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if message:
content = message.get("content", "")
content += event_data.get("data", {}).get("content", "")
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
if "type" in event_data and event_data["type"] == "replace":
content = event_data.get("data", {}).get("content", "")
Chats.upsert_message_to_chat_by_id_and_message_id(
request_info["chat_id"],
request_info["message_id"],
{
"content": content,
},
)
return __event_emitter__

View File

@@ -1,15 +1,17 @@
import json
import redis
import uuid
from open_webui.utils.redis import get_redis_connection
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs):
def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True
)
def aquire_lock(self):
# nx=True will only set this key if it _hasn't_ already been set
@@ -31,9 +33,11 @@ class RedisLock:
class RedisDict:
def __init__(self, name, redis_url):
def __init__(self, name, redis_url, redis_sentinels=[]):
self.name = name
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True
)
def __setitem__(self, key, value):
serialized_value = json.dumps(value)

Binary file not shown (image, 7.3 KiB)

View File

@@ -269,11 +269,6 @@ tbody + tbody {
margin-bottom: 0;
}
/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
.markdown-section p + ul {
margin-top: 0;
}
/* List item styles */
.markdown-section li {
padding: 2px;

Binary file not shown (image, 3.7 KiB)

Binary file not shown (image, 16 KiB)

Binary file not shown (image, 15 KiB)

File diff suppressed because one or more lines are too long (image, 14 KiB)

View File

View File

@@ -0,0 +1,21 @@
{
"name": "Open WebUI",
"short_name": "WebUI",
"icons": [
{
"src": "/static/web-app-manifest-192x192.png",
"sizes": "192x192",
"type": "image/png",
"purpose": "maskable"
},
{
"src": "/static/web-app-manifest-512x512.png",
"sizes": "512x512",
"type": "image/png",
"purpose": "maskable"
}
],
"theme_color": "#ffffff",
"background_color": "#ffffff",
"display": "standalone"
}

Binary file not shown (image, 5.3 KiB)

View File

@@ -9308,5 +9308,3 @@
.json-schema-2020-12__title:first-of-type {
font-size: 16px;
}
/*# sourceMappingURL=swagger-ui.css.map*/

Binary file not shown (image, 8.2 KiB)

Binary file not shown (image, 29 KiB)

View File

@@ -1,10 +1,13 @@
import os
import shutil
import json
import logging
import re
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
from typing import BinaryIO, Tuple, Dict
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from open_webui.config import (
S3_ACCESS_KEY_ID,
@@ -13,14 +16,28 @@ from open_webui.config import (
S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
S3_USE_ACCELERATE_ENDPOINT,
S3_ADDRESSING_STYLE,
S3_ENABLE_TAGGING,
GCS_BUCKET_NAME,
GOOGLE_APPLICATION_CREDENTIALS_JSON,
AZURE_STORAGE_ENDPOINT,
AZURE_STORAGE_CONTAINER_NAME,
AZURE_STORAGE_KEY,
STORAGE_PROVIDER,
UPLOAD_DIR,
)
from google.cloud import storage
from google.cloud.exceptions import GoogleCloudError, NotFound
from open_webui.constants import ERROR_MESSAGES
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azure.core.exceptions import ResourceNotFoundError
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class StorageProvider(ABC):
@@ -29,7 +46,9 @@ class StorageProvider(ABC):
pass
@abstractmethod
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
pass
@abstractmethod
@@ -43,7 +62,9 @@ class StorageProvider(ABC):
class LocalStorageProvider(StorageProvider):
@staticmethod
def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -65,7 +86,7 @@ class LocalStorageProvider(StorageProvider):
if os.path.isfile(file_path):
os.remove(file_path)
else:
print(f"File {file_path} not found in local storage.")
log.warning(f"File {file_path} not found in local storage.")
@staticmethod
def delete_all_files() -> None:
@@ -79,32 +100,74 @@ class LocalStorageProvider(StorageProvider):
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
log.exception(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"Directory {UPLOAD_DIR} not found in local storage.")
log.warning(f"Directory {UPLOAD_DIR} not found in local storage.")
class S3StorageProvider(StorageProvider):
def __init__(self):
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config = Config(
s3={
"use_accelerate_endpoint": S3_USE_ACCELERATE_ENDPOINT,
"addressing_style": S3_ADDRESSING_STYLE,
},
)
# If access key and secret are provided, use them for authentication
if S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY:
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
aws_access_key_id=S3_ACCESS_KEY_ID,
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
config=config,
)
else:
# If no explicit credentials are provided, fall back to default AWS credentials
# This supports workload identity (IAM roles for EC2, EKS, etc.)
self.s3_client = boto3.client(
"s3",
region_name=S3_REGION_NAME,
endpoint_url=S3_ENDPOINT_URL,
config=config,
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
@staticmethod
def sanitize_tag_value(s: str) -> str:
"""Only include S3 allowed characters."""
return re.sub(r"[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]", "", s)
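For illustration only (hypothetical tag value, not part of the diff): characters outside the allowed set are stripped before tagging.
# Illustrative sketch of the sanitizer.
S3StorageProvider.sanitize_tag_value("report (final) #2024.pdf")
# -> "report final 2024.pdf"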
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
_, file_path = LocalStorageProvider.upload_file(file, filename, tags)
s3_key = os.path.join(self.key_prefix, filename)
try:
s3_key = os.path.join(self.key_prefix, filename)
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
if S3_ENABLE_TAGGING and tags:
sanitized_tags = {
self.sanitize_tag_value(k): self.sanitize_tag_value(v)
for k, v in tags.items()
}
tagging = {
"TagSet": [
{"Key": k, "Value": v} for k, v in sanitized_tags.items()
]
}
self.s3_client.put_object_tagging(
Bucket=self.bucket_name,
Key=s3_key,
Tagging=tagging,
)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + s3_key,
f"s3://{self.bucket_name}/{s3_key}",
)
except ClientError as e:
raise RuntimeError(f"Error uploading file to S3: {e}")
@@ -172,9 +235,11 @@ class GCSStorageProvider(StorageProvider):
self.gcs_client = storage.Client()
self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to GCS storage."""
contents, file_path = LocalStorageProvider.upload_file(file, filename)
contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
try:
blob = self.bucket.blob(filename)
blob.upload_from_filename(file_path)
@@ -221,6 +286,76 @@ class GCSStorageProvider(StorageProvider):
LocalStorageProvider.delete_all_files()
class AzureStorageProvider(StorageProvider):
def __init__(self):
self.endpoint = AZURE_STORAGE_ENDPOINT
self.container_name = AZURE_STORAGE_CONTAINER_NAME
storage_key = AZURE_STORAGE_KEY
if storage_key:
# Configure using the Azure Storage Account Endpoint and Key
self.blob_service_client = BlobServiceClient(
account_url=self.endpoint, credential=storage_key
)
else:
# Configure using the Azure Storage Account Endpoint and DefaultAzureCredential
# If the key is not configured, then the DefaultAzureCredential will be used to support Managed Identity authentication
self.blob_service_client = BlobServiceClient(
account_url=self.endpoint, credential=DefaultAzureCredential()
)
self.container_client = self.blob_service_client.get_container_client(
self.container_name
)
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to Azure Blob Storage."""
contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
try:
blob_client = self.container_client.get_blob_client(filename)
blob_client.upload_blob(contents, overwrite=True)
return contents, f"{self.endpoint}/{self.container_name}/{filename}"
except Exception as e:
raise RuntimeError(f"Error uploading file to Azure Blob Storage: {e}")
def get_file(self, file_path: str) -> str:
"""Handles downloading of the file from Azure Blob Storage."""
try:
filename = file_path.split("/")[-1]
local_file_path = f"{UPLOAD_DIR}/{filename}"
blob_client = self.container_client.get_blob_client(filename)
with open(local_file_path, "wb") as download_file:
download_file.write(blob_client.download_blob().readall())
return local_file_path
except ResourceNotFoundError as e:
raise RuntimeError(f"Error downloading file from Azure Blob Storage: {e}")
def delete_file(self, file_path: str) -> None:
"""Handles deletion of the file from Azure Blob Storage."""
try:
filename = file_path.split("/")[-1]
blob_client = self.container_client.get_blob_client(filename)
blob_client.delete_blob()
except ResourceNotFoundError as e:
raise RuntimeError(f"Error deleting file from Azure Blob Storage: {e}")
# Always delete from local storage
LocalStorageProvider.delete_file(file_path)
def delete_all_files(self) -> None:
"""Handles deletion of all files from Azure Blob Storage."""
try:
blobs = self.container_client.list_blobs()
for blob in blobs:
self.container_client.delete_blob(blob.name)
except Exception as e:
raise RuntimeError(f"Error deleting all files from Azure Blob Storage: {e}")
# Always delete from local storage
LocalStorageProvider.delete_all_files()
def get_storage_provider(storage_provider: str):
if storage_provider == "local":
Storage = LocalStorageProvider()
@@ -228,6 +363,8 @@ def get_storage_provider(storage_provider: str):
Storage = S3StorageProvider()
elif storage_provider == "gcs":
Storage = GCSStorageProvider()
elif storage_provider == "azure":
Storage = AzureStorageProvider()
else:
raise RuntimeError(f"Unsupported storage provider: {storage_provider}")
return Storage
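For illustration only (not part of the diff): the dispatcher simply maps the configured provider name to a provider instance.
# Illustrative sketch; "azure" assumes the AZURE_STORAGE_* settings are configured.
Storage = get_storage_provider("azure")    # -> AzureStorageProvider instance
get_storage_provider("invalid")            # -> raises RuntimeError("Unsupported storage provider: invalid")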

View File

@@ -5,16 +5,23 @@ from uuid import uuid4
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
def cleanup_task(task_id: str):
def cleanup_task(task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary
if id and task_id in chat_tasks.get(id, []):
chat_tasks[id].remove(task_id)
if not chat_tasks[id]: # If no tasks left for this ID, remove the entry
chat_tasks.pop(id, None)
def create_task(coroutine):
def create_task(coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -22,9 +29,15 @@ def create_task(coroutine):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id))
task.add_done_callback(lambda t: cleanup_task(task_id, id))
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
if chat_tasks.get(id):
chat_tasks[id].append(task_id)
else:
chat_tasks[id] = [task_id]
return task_id, task
@@ -42,6 +55,13 @@ def list_tasks():
return list(tasks.keys())
def list_task_ids_by_chat_id(id):
"""
List all tasks associated with a specific ID.
"""
return chat_tasks.get(id, [])
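For illustration only (hypothetical coroutine and chat id, not part of the diff): a minimal sketch of the chat-scoped task bookkeeping, assuming it runs inside an event loop.
# Illustrative sketch; long_running_job and "chat-123" are made up.
async def long_running_job():
    await asyncio.sleep(60)

task_id, task = create_task(long_running_job(), id="chat-123")
list_task_ids_by_chat_id("chat-123")  # -> [task_id]
# When the task completes or is cancelled, cleanup_task() removes it from both dictionaries.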
async def stop_task(task_id: str):
"""
Cancel a running task and remove it from the global task list.

View File

@@ -7,6 +7,8 @@ from moto import mock_aws
from open_webui.storage import provider
from gcp_storage_emulator.server import create_server
from google.cloud import storage
from azure.storage.blob import BlobServiceClient, ContainerClient, BlobClient
from unittest.mock import MagicMock
def mock_upload_dir(monkeypatch, tmp_path):
@@ -22,6 +24,7 @@ def test_imports():
provider.LocalStorageProvider
provider.S3StorageProvider
provider.GCSStorageProvider
provider.AzureStorageProvider
provider.Storage
@@ -32,6 +35,8 @@ def test_get_storage_provider():
assert isinstance(Storage, provider.S3StorageProvider)
Storage = provider.get_storage_provider("gcs")
assert isinstance(Storage, provider.GCSStorageProvider)
Storage = provider.get_storage_provider("azure")
assert isinstance(Storage, provider.AzureStorageProvider)
with pytest.raises(RuntimeError):
provider.get_storage_provider("invalid")
@@ -48,6 +53,7 @@ def test_class_instantiation():
provider.LocalStorageProvider()
provider.S3StorageProvider()
provider.GCSStorageProvider()
provider.AzureStorageProvider()
class TestLocalStorageProvider:
@@ -181,6 +187,17 @@ class TestS3StorageProvider:
assert not (upload_dir / self.filename).exists()
assert not (upload_dir / self.filename_extra).exists()
def test_init_without_credentials(self, monkeypatch):
"""Test that S3StorageProvider can initialize without explicit credentials."""
# Temporarily unset the environment variables
monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
# Should not raise an exception
storage = provider.S3StorageProvider()
assert storage.s3_client is not None
assert storage.bucket_name == provider.S3_BUCKET_NAME
class TestGCSStorageProvider:
Storage = provider.GCSStorageProvider()
@@ -272,3 +289,147 @@ class TestGCSStorageProvider:
assert not (upload_dir / self.filename_extra).exists()
assert self.Storage.bucket.get_blob(self.filename) == None
assert self.Storage.bucket.get_blob(self.filename_extra) == None
class TestAzureStorageProvider:
def __init__(self):
super().__init__()
@pytest.fixture(scope="class")
def setup_storage(self, monkeypatch):
# Create mock Blob Service Client and related clients
mock_blob_service_client = MagicMock()
mock_container_client = MagicMock()
mock_blob_client = MagicMock()
# Set up return values for the mock
mock_blob_service_client.get_container_client.return_value = (
mock_container_client
)
mock_container_client.get_blob_client.return_value = mock_blob_client
# Monkeypatch the Azure classes to return our mocks
monkeypatch.setattr(
azure.storage.blob,
"BlobServiceClient",
lambda *args, **kwargs: mock_blob_service_client,
)
monkeypatch.setattr(
azure.storage.blob,
"ContainerClient",
lambda *args, **kwargs: mock_container_client,
)
monkeypatch.setattr(
azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
)
self.Storage = provider.AzureStorageProvider()
self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
self.Storage.container_name = "my-container"
self.file_content = b"test content"
self.filename = "test.txt"
self.filename_extra = "test_extra.txt"
self.file_bytesio_empty = io.BytesIO()
# Apply mocks to the Storage instance
self.Storage.blob_service_client = mock_blob_service_client
self.Storage.container_client = mock_container_client
def test_upload_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
# Simulate an error when container does not exist
self.Storage.container_client.get_blob_client.side_effect = Exception(
"Container does not exist"
)
with pytest.raises(Exception):
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Reset side effect and create container
self.Storage.container_client.get_blob_client.side_effect = None
self.Storage.create_container()
contents, azure_file_path = self.Storage.upload_file(
io.BytesIO(self.file_content), self.filename
)
# Assertions
self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
self.Storage.container_client.get_blob_client().upload_blob.assert_called_once_with(
self.file_content, overwrite=True
)
assert contents == self.file_content
assert (
azure_file_path
== f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
)
assert (upload_dir / self.filename).exists()
assert (upload_dir / self.filename).read_bytes() == self.file_content
with pytest.raises(ValueError):
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
def test_get_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock upload behavior
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Mock blob download behavior
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
self.file_content
)
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
file_path = self.Storage.get_file(file_url)
assert file_path == str(upload_dir / self.filename)
assert (upload_dir / self.filename).exists()
assert (upload_dir / self.filename).read_bytes() == self.file_content
def test_delete_file(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock file upload
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
# Mock deletion
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
self.Storage.delete_file(file_url)
self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
assert not (upload_dir / self.filename).exists()
def test_delete_all_files(self, monkeypatch, tmp_path):
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
self.Storage.create_container()
# Mock file uploads
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
# Mock listing and deletion behavior
self.Storage.container_client.list_blobs.return_value = [
{"name": self.filename},
{"name": self.filename_extra},
]
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
self.Storage.delete_all_files()
self.Storage.container_client.list_blobs.assert_called_once()
self.Storage.container_client.get_blob_client().delete_blob.assert_any_call()
assert not (upload_dir / self.filename).exists()
assert not (upload_dir / self.filename_extra).exists()
def test_get_file_not_found(self, monkeypatch):
self.Storage.create_container()
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
# Mock behavior to raise an error for missing blobs
self.Storage.container_client.get_blob_client().download_blob.side_effect = (
Exception("Blob not found")
)
with pytest.raises(Exception, match="Blob not found"):
self.Storage.get_file(file_url)

View File

@@ -0,0 +1,283 @@
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
import re
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
MutableMapping,
Optional,
cast,
)
import uuid
from asgiref.typing import (
ASGI3Application,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendCallable,
ASGISendEvent,
Scope as ASGIScope,
)
from loguru import logger
from starlette.requests import Request
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
from open_webui.models.users import UserModel
if TYPE_CHECKING:
from loguru import Logger
@dataclass(frozen=True)
class AuditLogEntry:
# `Metadata` audit level properties
id: str
user: Optional[dict[str, Any]]
audit_level: str
verb: str
request_uri: str
user_agent: Optional[str] = None
source_ip: Optional[str] = None
# `Request` audit level properties
request_object: Any = None
# `Request Response` level
response_object: Any = None
response_status_code: Optional[int] = None
class AuditLevel(str, Enum):
NONE = "NONE"
METADATA = "METADATA"
REQUEST = "REQUEST"
REQUEST_RESPONSE = "REQUEST_RESPONSE"
class AuditLogger:
"""
A helper class that encapsulates audit logging functionality. It uses Loguru's logger with an `auditable` binding to ensure that audit log entries are filtered correctly.
Parameters:
logger (Logger): An instance of Loguru's logger.
"""
def __init__(self, logger: "Logger"):
self.logger = logger.bind(auditable=True)
def write(
self,
audit_entry: AuditLogEntry,
*,
log_level: str = "INFO",
extra: Optional[dict] = None,
):
entry = asdict(audit_entry)
if extra:
entry["extra"] = extra
self.logger.log(
log_level,
"",
**entry,
)
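# Illustrative sketch (not part of this diff): driving AuditLogger directly.
# Field values are placeholders; the `auditable=True` binding added in
# __init__ is what the audit file sink's filter keys on (see utils/logger.py).
def _audit_logger_example() -> None:
    entry = AuditLogEntry(
        id=str(uuid.uuid4()),
        user={"id": "u1", "name": "Jane", "email": "jane@example.com", "role": "admin"},
        audit_level=AuditLevel.METADATA.value,
        verb="POST",
        request_uri="http://localhost/api/v1/auths/signin",
    )
    AuditLogger(logger).write(entry, extra={"note": "example only"})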
class AuditContext:
"""
Captures and aggregates the HTTP request and response bodies during the processing of a request. It ensures that only a configurable maximum amount of data is stored to prevent excessive memory usage.
Attributes:
request_body (bytearray): Accumulated request payload.
response_body (bytearray): Accumulated response payload.
max_body_size (int): Maximum number of bytes to capture.
metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
"""
def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
self.request_body = bytearray()
self.response_body = bytearray()
self.max_body_size = max_body_size
self.metadata: Dict[str, Any] = {}
def add_request_chunk(self, chunk: bytes):
if len(self.request_body) < self.max_body_size:
self.request_body.extend(
chunk[: self.max_body_size - len(self.request_body)]
)
def add_response_chunk(self, chunk: bytes):
if len(self.response_body) < self.max_body_size:
self.response_body.extend(
chunk[: self.max_body_size - len(self.response_body)]
)
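# Illustrative sketch (not part of this diff): AuditContext truncates captured
# bodies at max_body_size, so oversized payloads never bloat the audit entry.
def _audit_context_example() -> None:
    ctx = AuditContext(max_body_size=10)
    ctx.add_request_chunk(b"hello ")
    ctx.add_request_chunk(b"world, this tail is dropped")
    assert bytes(ctx.request_body) == b"hello worl"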
class AuditLoggingMiddleware:
"""
ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.
"""
AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}
def __init__(
self,
app: ASGI3Application,
*,
excluded_paths: Optional[list[str]] = None,
max_body_size: int = MAX_BODY_LOG_SIZE,
audit_level: AuditLevel = AuditLevel.NONE,
) -> None:
self.app = app
self.audit_logger = AuditLogger(logger)
self.excluded_paths = excluded_paths or []
self.max_body_size = max_body_size
self.audit_level = audit_level
async def __call__(
self,
scope: ASGIScope,
receive: ASGIReceiveCallable,
send: ASGISendCallable,
) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)
request = Request(scope=cast(MutableMapping, scope))
if self._should_skip_auditing(request):
return await self.app(scope, receive, send)
async with self._audit_context(request) as context:
async def send_wrapper(message: ASGISendEvent) -> None:
if self.audit_level == AuditLevel.REQUEST_RESPONSE:
await self._capture_response(message, context)
await send(message)
original_receive = receive
async def receive_wrapper() -> ASGIReceiveEvent:
nonlocal original_receive
message = await original_receive()
if self.audit_level in (
AuditLevel.REQUEST,
AuditLevel.REQUEST_RESPONSE,
):
await self._capture_request(message, context)
return message
await self.app(scope, receive_wrapper, send_wrapper)
@asynccontextmanager
async def _audit_context(
self, request: Request
) -> AsyncGenerator[AuditContext, None]:
"""
Async context manager that ensures an audit log entry is recorded after the request has been processed.
"""
context = AuditContext()
try:
yield context
finally:
await self._log_audit_entry(request, context)
async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
auth_header = request.headers.get("Authorization")
try:
user = get_current_user(
request, None, get_http_authorization_cred(auth_header)
)
return user
except Exception as e:
logger.debug(f"Failed to get authenticated user: {str(e)}")
return None
def _should_skip_auditing(self, request: Request) -> bool:
if (
request.method not in {"POST", "PUT", "PATCH", "DELETE"}
or AUDIT_LOG_LEVEL == "NONE"
):
return True
ALWAYS_LOG_ENDPOINTS = {
"/api/v1/auths/signin",
"/api/v1/auths/signout",
"/api/v1/auths/signup",
}
path = request.url.path.lower()
for endpoint in ALWAYS_LOG_ENDPOINTS:
if path.startswith(endpoint):
return False # Do NOT skip logging for auth endpoints
# Skip logging if the request is not authenticated
if not request.headers.get("authorization"):
return True
# match either /api/<resource>/...(for the endpoint /api/chat case) or /api/v1/<resource>/...
pattern = re.compile(
r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
)
if pattern.match(request.url.path):
return True
return False
async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
if message["type"] == "http.request":
body = message.get("body", b"")
context.add_request_chunk(body)
async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
if message["type"] == "http.response.start":
context.metadata["response_status_code"] = message["status"]
elif message["type"] == "http.response.body":
body = message.get("body", b"")
context.add_response_chunk(body)
async def _log_audit_entry(self, request: Request, context: AuditContext):
try:
user = await self._get_authenticated_user(request)
user = (
user.model_dump(include={"id", "name", "email", "role"}) if user else {}
)
request_body = context.request_body.decode("utf-8", errors="replace")
response_body = context.response_body.decode("utf-8", errors="replace")
# Redact sensitive information
if "password" in request_body:
request_body = re.sub(
r'"password":\s*"(.*?)"',
'"password": "********"',
request_body,
)
entry = AuditLogEntry(
id=str(uuid.uuid4()),
user=user,
audit_level=self.audit_level.value,
verb=request.method,
request_uri=str(request.url),
response_status_code=context.metadata.get("response_status_code", None),
source_ip=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
request_object=request_body,
response_object=response_body,
)
self.audit_logger.write(entry)
except Exception as e:
logger.error(f"Failed to log audit entry: {str(e)}")

View File

@@ -1,21 +1,39 @@
import logging
import uuid
import jwt
import base64
import hmac
import hashlib
import requests
import os
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta
import pytz
from pytz import UTC
from typing import Optional, Union, List, Dict
from opentelemetry import trace
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
from open_webui.env import (
WEBUI_SECRET_KEY,
TRUSTED_SIGNATURE_KEY,
STATIC_DIR,
SRC_LOG_LEVELS,
)
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
logging.getLogger("passlib").setLevel(logging.ERROR)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OAUTH"])
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"
@@ -24,6 +42,67 @@ ALGORITHM = "HS256"
# Auth Utils
##############
def verify_signature(payload: str, signature: str) -> bool:
"""
Verifies the HMAC signature of the received payload.
"""
try:
expected_signature = base64.b64encode(
hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
).decode()
# Compare securely to prevent timing attacks
return hmac.compare_digest(expected_signature, signature)
except Exception:
return False
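# Illustrative sketch (not part of this diff): how a trusted sender would
# compute a signature that verify_signature accepts. A literal bytes key stands
# in for the configured TRUSTED_SIGNATURE_KEY.
def _sign_payload_example(payload: str, key: bytes = b"shared-secret") -> str:
    return base64.b64encode(
        hmac.new(key, payload.encode(), hashlib.sha256).digest()
    ).decode()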
def override_static(path: str, content: str):
# Ensure path is safe
if "/" in path or ".." in path:
log.error(f"Invalid path: {path}")
return
file_path = os.path.join(STATIC_DIR, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
def get_license_data(app, key):
if key:
try:
res = requests.post(
"https://api.openwebui.com/api/v1/license/",
json={"key": key, "version": "1"},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
for k, v in payload.items():
if k == "resources":
for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c)
elif k == "count":
setattr(app.state, "USER_COUNT", v)
elif k == "name":
setattr(app.state, "WEBUI_NAME", v)
elif k == "metadata":
setattr(app.state, "LICENSE_METADATA", v)
return True
else:
log.error(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}")
return False
bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -66,16 +145,19 @@ def create_api_key():
return f"sk-{key}"
def get_http_authorization_cred(auth_header: str):
def get_http_authorization_cred(auth_header: Optional[str]):
if not auth_header:
return None
try:
scheme, credentials = auth_header.split(" ")
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
except Exception:
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
return None
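# Illustrative sketch (not part of this diff): with the updated helper, a
# missing or malformed Authorization header yields None instead of raising.
def _parse_header_example() -> None:
    cred = get_http_authorization_cred("Bearer sk-example-token")
    assert cred is not None and cred.scheme == "Bearer"
    assert get_http_authorization_cred(None) is None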
def get_current_user(
request: Request,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
token = None
@@ -104,12 +186,27 @@ def get_current_user(
).split(",")
]
if request.url.path not in allowed_paths:
# Check if the request path matches any allowed endpoint.
if not any(
request.url.path == allowed
or request.url.path.startswith(allowed + "/")
for allowed in allowed_paths
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
return get_current_user_by_api_key(token)
user = get_current_user_by_api_key(token)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
return user
# auth by jwt token
try:
@@ -128,7 +225,18 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "jwt")
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
return user
else:
raise HTTPException(
@@ -146,6 +254,14 @@ def get_current_user_by_api_key(api_key: str):
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id)
return user

View File

@@ -40,7 +40,10 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.models import get_all_models, check_model_access
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
@@ -66,7 +69,7 @@ async def generate_direct_chat_completion(
user: Any,
models: dict,
):
print("generate_direct_chat_completion")
log.info("generate_direct_chat_completion")
metadata = form_data.pop("metadata", {})
@@ -103,7 +106,7 @@ async def generate_direct_chat_completion(
}
)
print("res", res)
log.info(f"res: {res}")
if res.get("status", False):
# Define a generator to stream responses
@@ -149,7 +152,7 @@ async def generate_direct_chat_completion(
}
)
if "error" in res:
if "error" in res and res["error"]:
raise Exception(res["error"])
return res
@@ -186,12 +189,6 @@ async def generate_chat_completion(
if model_id not in models:
raise Exception("Model not found")
# Process the form_data through the pipeline
try:
form_data = process_pipeline_inlet_filter(request, form_data, user, models)
except Exception as e:
raise e
model = models[model_id]
if getattr(request.state, "direct", False):
@@ -206,7 +203,7 @@ async def generate_chat_completion(
except Exception as e:
raise e
if model["owned_by"] == "arena":
if model.get("owned_by") == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
@@ -259,7 +256,7 @@ async def generate_chat_completion(
return await generate_function_chat_completion(
request, form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
if model.get("owned_by") == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
response = await generate_ollama_chat_completion(
@@ -291,7 +288,7 @@ chat_completion = generate_chat_completion
async def chat_completed(request: Request, form_data: dict, user: Any):
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
@@ -308,13 +305,14 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
model = models[model_id]
try:
data = process_pipeline_outlet_filter(request, data, user, models)
data = await process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
return Exception(f"Error: {e}")
metadata = {
"chat_id": data["chat_id"],
"message_id": data["id"],
"filter_ids": data.get("filter_ids", []),
"session_id": data["session_id"],
"user_id": user.id,
}
@@ -334,9 +332,16 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
}
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
result, _ = await process_filter_functions(
request=request,
filter_ids=get_sorted_filter_ids(model),
filter_functions=filter_functions,
filter_type="outlet",
form_data=data,
extra_params=extra_params,
@@ -357,7 +362,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
raise Exception(f"Action not found: {action_id}")
if not request.app.state.MODELS:
await get_all_models(request)
await get_all_models(request, user=user)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
@@ -390,11 +395,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
}
)
if action_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
function_module, _, _ = get_function_module_from_cache(request, action_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
@@ -438,7 +439,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
)
)
except Exception as e:
print(e)
log.exception(f"Failed to get user values: {e}")
params = {**params, "__user__": __user__}

View File

@@ -1,148 +1,210 @@
import asyncio
import json
import logging
import uuid
from typing import Optional
import aiohttp
import websockets
import requests
from urllib.parse import urljoin
from pydantic import BaseModel
from open_webui.env import SRC_LOG_LEVELS
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
async def execute_code_jupyter(
jupyter_url, code, token=None, password=None, timeout=10
):
class ResultModel(BaseModel):
"""
Executes Python code in a Jupyter kernel.
Supports authentication with a token or password.
:param jupyter_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
:param token: Jupyter authentication token (optional)
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 10s)
:return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
Execute Code Result Model
"""
session = requests.Session() # Maintain cookies
headers = {} # Headers for requests
# Authenticate using password
if password and not token:
stdout: Optional[str] = ""
stderr: Optional[str] = ""
result: Optional[str] = ""
class JupyterCodeExecuter:
"""
Execute code in a Jupyter kernel
"""
def __init__(
self,
base_url: str,
code: str,
token: str = "",
password: str = "",
timeout: int = 60,
):
"""
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
:param token: Jupyter authentication token (optional)
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 60s)
"""
self.base_url = base_url
self.code = code
self.token = token
self.password = password
self.timeout = timeout
self.kernel_id = ""
if self.base_url[-1] != "/":
self.base_url += "/"
self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
self.params = {}
self.result = ResultModel()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.kernel_id:
try:
async with self.session.delete(
f"api/kernels/{self.kernel_id}", params=self.params
) as response:
response.raise_for_status()
except Exception as err:
logger.exception("close kernel failed, %s", err)
await self.session.close()
async def run(self) -> ResultModel:
try:
login_url = urljoin(jupyter_url, "/login")
response = session.get(login_url)
await self.sign_in()
await self.init_kernel()
await self.execute_code()
except Exception as err:
logger.exception("execute code failed, %s", err)
self.result.stderr = f"Error: {err}"
return self.result
async def sign_in(self) -> None:
# password authentication
if self.password and not self.token:
async with self.session.get("login") as response:
response.raise_for_status()
xsrf_token = response.cookies["_xsrf"].value
if not xsrf_token:
raise ValueError("_xsrf token not found")
self.session.cookie_jar.update_cookies(response.cookies)
self.session.headers.update({"X-XSRFToken": xsrf_token})
async with self.session.post(
"login",
data={"_xsrf": xsrf_token, "password": self.password},
allow_redirects=False,
) as response:
response.raise_for_status()
self.session.cookie_jar.update_cookies(response.cookies)
# token authentication
if self.token:
self.params.update({"token": self.token})
async def init_kernel(self) -> None:
async with self.session.post(url="api/kernels", params=self.params) as response:
response.raise_for_status()
xsrf_token = session.cookies.get("_xsrf")
if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token")
login_data = {"_xsrf": xsrf_token, "password": password}
login_response = session.post(
login_url, data=login_data, cookies=session.cookies
)
login_response.raise_for_status()
headers["X-XSRFToken"] = xsrf_token
except Exception as e:
return {
"stdout": "",
"stderr": f"Authentication Error: {str(e)}",
"result": "",
}
# Construct API URLs with authentication token if provided
params = f"?token={token}" if token else ""
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
try:
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
response.raise_for_status()
kernel_id = response.json()["id"]
websocket_url = urljoin(
jupyter_url.replace("http", "ws"),
f"/api/kernels/{kernel_id}/channels{params}",
)
kernel_data = await response.json()
self.kernel_id = kernel_data["id"]
def init_ws(self) -> (str, dict):
ws_base = self.base_url.replace("http", "ws", 1)
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
websocket_url = f"{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
ws_headers = {}
if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
cookies = {name: value for name, value in session.cookies.items()}
ws_headers["Cookie"] = "; ".join(
[f"{name}={value}" for name, value in cookies.items()]
)
if self.password and not self.token:
ws_headers = {
"Cookie": "; ".join(
[
f"{cookie.key}={cookie.value}"
for cookie in self.session.cookie_jar
]
),
**self.session.headers,
}
return websocket_url, ws_headers
async def execute_code(self) -> None:
# initialize ws
websocket_url, ws_headers = self.init_ws()
# execute
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
msg_id = str(uuid.uuid4())
execute_request = {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
await ws.send(json.dumps(execute_request))
await self.execute_in_jupyter(ws)
stdout, stderr, result = "", "", []
while True:
try:
message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
msg_type = message_data.get("msg_type")
if msg_type == "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
elif msg_type in ("execute_result", "display_data"):
data = message_data["content"]["data"]
if "image/png" in data:
result.append(
f"data:image/png;base64,{data['image/png']}"
)
elif "text/plain" in data:
result.append(data["text/plain"])
elif msg_type == "error":
stderr += "\n".join(message_data["content"]["traceback"])
elif (
msg_type == "status"
and message_data["content"]["execution_state"] == "idle"
):
async def execute_in_jupyter(self, ws) -> None:
# send message
msg_id = uuid.uuid4().hex
await ws.send(
json.dumps(
{
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": uuid.uuid4().hex,
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": self.code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
)
)
# parse message
stdout, stderr, result = "", "", []
while True:
try:
# wait for message
message = await asyncio.wait_for(ws.recv(), self.timeout)
message_data = json.loads(message)
# msg id not match, skip
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
continue
# check message type
msg_type = message_data.get("msg_type")
match msg_type:
case "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
case "execute_result" | "display_data":
data = message_data["content"]["data"]
if "image/png" in data:
result.append(f"data:image/png;base64,{data['image/png']}")
elif "text/plain" in data:
result.append(data["text/plain"])
case "error":
stderr += "\n".join(message_data["content"]["traceback"])
case "status":
if message_data["content"]["execution_state"] == "idle":
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
self.result.stdout = stdout.strip()
self.result.stderr = stderr.strip()
self.result.result = "\n".join(result).strip() if result else ""
except Exception as e:
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
finally:
if kernel_id:
requests.delete(
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
)
return {
"stdout": stdout.strip(),
"stderr": stderr.strip(),
"result": "\n".join(result).strip() if result else "",
}
async def execute_code_jupyter(
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
async with JupyterCodeExecuter(
base_url, code, token, password, timeout
) as executor:
result = await executor.run()
return result.model_dump()
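# Illustrative sketch (not part of this diff): calling the async wrapper from a
# synchronous entrypoint. URL, token and code are placeholders.
def _run_jupyter_example() -> dict:
    return asyncio.run(
        execute_code_jupyter(
            "http://localhost:8888",
            "print('hello')",
            token="my-jupyter-token",
            timeout=30,
        )
    )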

View File

@@ -1,46 +1,80 @@
import inspect
from open_webui.utils.plugin import load_function_module_by_id
import logging
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.models.functions import Functions
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model):
def get_function_module(request, function_id, load_from_db=True):
"""
Get the function module by its ID.
"""
function_module, _, _ = get_function_module_from_cache(
request, function_id, load_from_db
)
return function_module
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
if function is not None:
valves = Functions.get_function_valves_by_id(function_id)
return valves.get("priority", 0) if valves else 0
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
active_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
def get_active_status(filter_id):
function_module = get_function_module(request, filter_id)
if getattr(function_module, "toggle", None):
return filter_id in (enabled_filter_ids or [])
return True
active_filter_ids = [
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
]
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority)
return filter_ids
async def process_filter_functions(
request, filter_ids, filter_type, form_data, extra_params
request, filter_functions, filter_type, form_data, extra_params
):
skip_files = None
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
for function in filter_functions:
filter = function
filter_id = function.id
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
function_module = get_function_module(
request, filter_id, load_from_db=(filter_type != "stream")
)
# Prepare handler function
handler = getattr(function_module, filter_type, None)
if not handler:
continue
# Check if the function has a file_handler variable
if filter_type == "inlet" and hasattr(function_module, "file_handler"):
@@ -53,15 +87,15 @@ async def process_filter_functions(
**(valves if valves else {})
)
# Prepare handler function
handler = getattr(function_module, filter_type, None)
if not handler:
continue
try:
# Prepare parameters
sig = inspect.signature(handler)
params = {"body": form_data} | {
params = {"body": form_data}
if filter_type == "stream":
params = {"event": form_data}
params = params | {
k: v
for k, v in {
**extra_params,
@@ -80,7 +114,7 @@ async def process_filter_functions(
)
)
except Exception as e:
print(e)
log.exception(f"Failed to get user values: {e}")
# Execute handler
if inspect.iscoroutinefunction(handler):
@@ -89,11 +123,12 @@ async def process_filter_functions(
form_data = handler(**params)
except Exception as e:
print(f"Error in {filter_type} handler {filter_id}: {e}")
log.debug(f"Error in {filter_type} handler {filter_id}: {e}")
raise e
# Handle file cleanup for inlet
if skip_files and "files" in form_data.get("metadata", {}):
del form_data["files"]
del form_data["metadata"]["files"]
return form_data, {}
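# Illustrative sketch (not part of this diff): the typical inlet invocation,
# mirroring the outlet call in the chat middleware. `request`, `model` and
# `body` are assumed to be in scope at the real call site.
async def _run_inlet_example(request, model: dict, body: dict) -> dict:
    filter_functions = [
        Functions.get_function_by_id(filter_id)
        for filter_id in get_sorted_filter_ids(request, model)
    ]
    body, _ = await process_filter_functions(
        request=request,
        filter_functions=filter_functions,
        filter_type="inlet",
        form_data=body,
        extra_params={},
    )
    return body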

View File

@@ -0,0 +1,140 @@
import json
import logging
import sys
from typing import TYPE_CHECKING
from loguru import logger
from open_webui.env import (
AUDIT_LOG_FILE_ROTATION_SIZE,
AUDIT_LOG_LEVEL,
AUDIT_LOGS_FILE_PATH,
GLOBAL_LOG_LEVEL,
)
if TYPE_CHECKING:
from loguru import Record
def stdout_format(record: "Record") -> str:
"""
Generates a formatted string for log records that are output to the console. This format includes a timestamp, log level, source location (module, function, and line), the log message, and any extra data (serialized as JSON).
Parameters:
record (Record): A Loguru record that contains logging details including time, level, name, function, line, message, and any extra context.
Returns:
str: A formatted log string intended for stdout.
"""
record["extra"]["extra_json"] = json.dumps(record["extra"])
return (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level> - {extra[extra_json]}"
"\n{exception}"
)
class InterceptHandler(logging.Handler):
"""
Intercepts log records from Python's standard logging module
and redirects them to Loguru's logger.
"""
def emit(self, record):
"""
Called by the standard logging module for each log event.
It transforms the standard `LogRecord` into a format compatible with Loguru
and passes it to Loguru's logger.
"""
try:
level = logger.level(record.levelname).name
except ValueError:
level = record.levelno
frame, depth = sys._getframe(6), 6
while frame and frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
def file_format(record: "Record"):
"""
Formats audit log records into a structured JSON string for file output.
Parameters:
record (Record): A Loguru record containing extra audit data.
Returns:
str: A JSON-formatted string representing the audit data.
"""
audit_data = {
"id": record["extra"].get("id", ""),
"timestamp": int(record["time"].timestamp()),
"user": record["extra"].get("user", dict()),
"audit_level": record["extra"].get("audit_level", ""),
"verb": record["extra"].get("verb", ""),
"request_uri": record["extra"].get("request_uri", ""),
"response_status_code": record["extra"].get("response_status_code", 0),
"source_ip": record["extra"].get("source_ip", ""),
"user_agent": record["extra"].get("user_agent", ""),
"request_object": record["extra"].get("request_object", b""),
"response_object": record["extra"].get("response_object", b""),
"extra": record["extra"].get("extra", {}),
}
record["extra"]["file_extra"] = json.dumps(audit_data, default=str)
return "{extra[file_extra]}\n"
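# Illustrative sketch (not part of this diff): each audit record serialized by
# file_format lands in the log file as one JSON object per line, roughly:
_example_audit_line = (
    '{"id": "…", "timestamp": 1700000000, "user": {"id": "u1"}, '
    '"audit_level": "METADATA", "verb": "POST", '
    '"request_uri": "http://localhost/api/v1/auths/signin", '
    '"response_status_code": 200, "source_ip": "127.0.0.1", '
    '"user_agent": "curl/8.0", "request_object": "…", '
    '"response_object": "…", "extra": {}}'
)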
def start_logger():
"""
Initializes and configures Loguru's logger with distinct handlers:
A console (stdout) handler for general log messages (excluding those marked as auditable).
An optional file handler for audit logs if audit logging is enabled.
Additionally, this function reconfigures Python's standard logging to route through Loguru and adjusts logging levels for Uvicorn. Audit logging to file is enabled when AUDIT_LOG_LEVEL is not "NONE".
"""
logger.remove()
logger.add(
sys.stdout,
level=GLOBAL_LOG_LEVEL,
format=stdout_format,
filter=lambda record: "auditable" not in record["extra"],
)
if AUDIT_LOG_LEVEL != "NONE":
try:
logger.add(
AUDIT_LOGS_FILE_PATH,
level="INFO",
rotation=AUDIT_LOG_FILE_ROTATION_SIZE,
compression="zip",
format=file_format,
filter=lambda record: record["extra"].get("auditable") is True,
)
except Exception as e:
logger.error(f"Failed to initialize audit log file handler: {str(e)}")
logging.basicConfig(
handlers=[InterceptHandler()], level=GLOBAL_LOG_LEVEL, force=True
)
for uvicorn_logger_name in ["uvicorn", "uvicorn.error"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = []
for uvicorn_logger_name in ["uvicorn.access"]:
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
uvicorn_logger.setLevel(GLOBAL_LOG_LEVEL)
uvicorn_logger.handlers = [InterceptHandler()]
logger.info(f"GLOBAL_LOG_LEVEL: {GLOBAL_LOG_LEVEL}")
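# Illustrative sketch (not part of this diff): once start_logger() has run,
# stdlib logging is routed through Loguru via InterceptHandler, and records
# bound with auditable=True are diverted to the audit file sink (if enabled)
# instead of stdout.
def _logging_example() -> None:
    start_logger()
    logging.getLogger("my.module").info("handled by InterceptHandler, printed via stdout_format")
    logger.bind(auditable=True).info("written to the audit log file when AUDIT_LOG_LEVEL != 'NONE'")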

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More