Merge branch 'dev' into pyodide-files

This commit is contained in:
Jarrod Lowe
2025-05-24 09:12:08 +12:00
committed by GitHub
207 changed files with 9220 additions and 5142 deletions

View File

@@ -989,6 +989,26 @@ DEFAULT_USER_ROLE = PersistentConfig(
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
PENDING_USER_OVERLAY_TITLE = PersistentConfig(
"PENDING_USER_OVERLAY_TITLE",
"ui.pending_user_overlay_title",
os.environ.get("PENDING_USER_OVERLAY_TITLE", ""),
)
PENDING_USER_OVERLAY_CONTENT = PersistentConfig(
"PENDING_USER_OVERLAY_CONTENT",
"ui.pending_user_overlay_content",
os.environ.get("PENDING_USER_OVERLAY_CONTENT", ""),
)
RESPONSE_WATERMARK = PersistentConfig(
"RESPONSE_WATERMARK",
"ui.watermark",
os.environ.get("RESPONSE_WATERMARK", ""),
)
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
== "true"
@@ -1731,6 +1751,9 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
ENABLE_QDRANT_MULTITENANCY_MODE = (
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true"
)
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
@@ -1825,6 +1848,18 @@ CONTENT_EXTRACTION_ENGINE = PersistentConfig(
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig(
"EXTERNAL_DOCUMENT_LOADER_URL",
"rag.external_document_loader_url",
os.environ.get("EXTERNAL_DOCUMENT_LOADER_URL", ""),
)
EXTERNAL_DOCUMENT_LOADER_API_KEY = PersistentConfig(
"EXTERNAL_DOCUMENT_LOADER_API_KEY",
"rag.external_document_loader_api_key",
os.environ.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", ""),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
@@ -1849,6 +1884,12 @@ DOCLING_OCR_LANG = PersistentConfig(
os.getenv("DOCLING_OCR_LANG", "eng,fra,deu,spa"),
)
DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
"DOCLING_DO_PICTURE_DESCRIPTION",
"rag.docling_do_picture_description",
os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
)
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
@@ -1920,6 +1961,16 @@ RAG_FILE_MAX_SIZE = PersistentConfig(
),
)
RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig(
"RAG_ALLOWED_FILE_EXTENSIONS",
"rag.file.allowed_extensions",
[
ext.strip()
for ext in os.environ.get("RAG_ALLOWED_FILE_EXTENSIONS", "").split(",")
if ext.strip()
],
)
RAG_EMBEDDING_ENGINE = PersistentConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
@@ -2126,6 +2177,12 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
)
BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig(
"BYPASS_WEB_SEARCH_WEB_LOADER",
"rag.web.search.bypass_web_loader",
os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true",
)
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
@@ -2151,6 +2208,7 @@ WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
WEB_LOADER_ENGINE = PersistentConfig(
"WEB_LOADER_ENGINE",
"rag.web.loader.engine",
@@ -2839,6 +2897,12 @@ LDAP_CA_CERT_FILE = PersistentConfig(
os.environ.get("LDAP_CA_CERT_FILE", ""),
)
LDAP_VALIDATE_CERT = PersistentConfig(
"LDAP_VALIDATE_CERT",
"ldap.server.validate_cert",
os.environ.get("LDAP_VALIDATE_CERT", "True").lower() == "true",
)
LDAP_CIPHERS = PersistentConfig(
"LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
)

View File

@@ -54,11 +54,8 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
if pipe_id not in request.app.state.FUNCTIONS:
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = request.app.state.FUNCTIONS[pipe_id]
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)

View File

@@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url):
db = connect(db_url, unquote_password=True)
db = connect(db_url, unquote_user=True, unquote_password=True)
if isinstance(db, PostgresqlDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
@@ -51,7 +51,7 @@ def register_connection(db_url):
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url, unquote_password=True)
connection = parse(db_url, unquote_user=True, unquote_password=True)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(**connection)

View File

@@ -197,6 +197,7 @@ from open_webui.config import (
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_RELEVANCE_THRESHOLD,
RAG_ALLOWED_FILE_EXTENSIONS,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
@@ -206,10 +207,13 @@ from open_webui.config import (
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
EXTERNAL_DOCUMENT_LOADER_URL,
EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL,
DOCLING_SERVER_URL,
DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
@@ -224,6 +228,7 @@ from open_webui.config import (
ENABLE_WEB_SEARCH,
WEB_SEARCH_ENGINE,
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
BYPASS_WEB_SEARCH_WEB_LOADER,
WEB_SEARCH_RESULT_COUNT,
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
@@ -291,6 +296,8 @@ from open_webui.config import (
ENABLE_EVALUATION_ARENA_MODELS,
USER_PERMISSIONS,
DEFAULT_USER_ROLE,
PENDING_USER_OVERLAY_CONTENT,
PENDING_USER_OVERLAY_TITLE,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_MODELS,
DEFAULT_ARENA_MODEL,
@@ -317,6 +324,7 @@ from open_webui.config import (
LDAP_APP_PASSWORD,
LDAP_USE_TLS,
LDAP_CA_CERT_FILE,
LDAP_VALIDATE_CERT,
LDAP_CIPHERS,
# Misc
ENV,
@@ -327,6 +335,7 @@ from open_webui.config import (
DEFAULT_LOCALE,
OAUTH_PROVIDERS,
WEBUI_URL,
RESPONSE_WATERMARK,
# Admin
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
@@ -373,6 +382,7 @@ from open_webui.env import (
OFFLINE_MODE,
ENABLE_OTEL,
EXTERNAL_PWA_MANIFEST_URL,
AIOHTTP_CLIENT_SESSION_SSL,
)
@@ -573,6 +583,11 @@ app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
@@ -609,6 +624,7 @@ app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_VALIDATE_CERT = LDAP_VALIDATE_CERT
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
@@ -631,6 +647,7 @@ app.state.FUNCTIONS = {}
app.state.config.TOP_K = RAG_TOP_K
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
@@ -641,10 +658,13 @@ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -688,6 +708,7 @@ app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
@@ -1167,11 +1188,12 @@ async def chat_completion(
"chat_id": form_data.pop("chat_id", None),
"message_id": form_data.pop("id", None),
"session_id": form_data.pop("session_id", None),
"filter_ids": form_data.pop("filter_ids", []),
"tool_ids": form_data.get("tool_ids", None),
"tool_servers": form_data.pop("tool_servers", None),
"files": form_data.get("files", None),
"features": form_data.get("features", None),
"variables": form_data.get("variables", None),
"features": form_data.get("features", {}),
"variables": form_data.get("variables", {}),
"model": model,
"direct": model_item.get("direct", False),
**(
@@ -1395,6 +1417,11 @@ async def get_app_config(request: Request):
"sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
"sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value,
},
"ui": {
"pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE,
"pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT,
"response_watermark": app.state.config.RESPONSE_WATERMARK,
},
"license_metadata": app.state.LICENSE_METADATA,
**(
{
@@ -1446,7 +1473,8 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
"https://api.github.com/repos/open-webui/open-webui/releases/latest",
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
response.raise_for_status()
data = await response.json()

View File

@@ -129,12 +129,16 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email)
if not user:
return None
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
return user
else:
return None

View File

@@ -0,0 +1,58 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalDocumentLoader(BaseLoader):
    """Load a document by delegating text extraction to an external HTTP service.

    The raw file bytes are PUT to ``{url}/process``; the service is expected to
    reply with JSON containing ``page_content`` and (optionally) ``metadata``.
    """

    def __init__(
        self,
        file_path,
        url: str,
        api_key: str,
        mime_type=None,
        timeout: int = 600,
        **kwargs,
    ) -> None:
        """Initialize the loader.

        Args:
            file_path: Local path of the file to send for processing.
            url: Base URL of the external document-loader service.
            api_key: Bearer token for the service; may be empty (no auth header).
            mime_type: Optional MIME type sent as the Content-Type header.
            timeout: Request timeout in seconds (new, defaults to 600 so the
                request can no longer hang indefinitely).
        """
        self.url = url
        self.api_key = api_key
        self.file_path = file_path
        self.mime_type = mime_type
        self.timeout = timeout

    def load(self) -> list[Document]:
        """Send the file to the external service and wrap the reply in a Document.

        Returns:
            A single-element list with the extracted page content and metadata.

        Raises:
            Exception: if the service answers with a non-2xx status or returns
                no usable content.
        """
        with open(self.file_path, "rb") as f:
            data = f.read()

        headers = {}
        if self.mime_type is not None:
            headers["Content-Type"] = self.mime_type

        # Only attach Authorization when a non-empty key is configured.
        # The previous `is not None` check sent a bogus "Bearer " header for
        # the default empty-string config value.
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # rstrip handles any number of trailing slashes (the original only
        # stripped one).
        endpoint = f"{self.url.rstrip('/')}/process"
        r = requests.put(endpoint, data=data, headers=headers, timeout=self.timeout)

        if not r.ok:
            raise Exception(f"Error loading document: {r.status_code} {r.text}")

        res = r.json()
        # Guard against an empty body or a missing page_content field; a
        # Document with page_content=None would break downstream consumers.
        if not res or res.get("page_content") is None:
            raise Exception("Error loading document: No content returned")

        return [
            Document(
                page_content=res.get("page_content"),
                metadata=res.get("metadata") or {},
            )
        ]

View File

@@ -10,7 +10,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalLoader(BaseLoader):
class ExternalWebLoader(BaseLoader):
def __init__(
self,
web_paths: Union[str, List[str]],
@@ -32,7 +32,7 @@ class ExternalLoader(BaseLoader):
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
"Authorization": f"Bearer {self.external_api_key}",
},
json={

View File

@@ -21,6 +21,8 @@ from langchain_community.document_loaders import (
)
from langchain_core.documents import Document
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
@@ -126,14 +128,12 @@ class TikaLoader:
class DoclingLoader:
def __init__(
self, url, file_path=None, mime_type=None, ocr_engine=None, ocr_lang=None
):
def __init__(self, url, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.file_path = file_path
self.mime_type = mime_type
self.ocr_engine = ocr_engine
self.ocr_lang = ocr_lang
self.params = params or {}
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
@@ -150,11 +150,19 @@ class DoclingLoader:
"table_mode": "accurate",
}
if self.ocr_engine and self.ocr_lang:
params["ocr_engine"] = self.ocr_engine
params["ocr_lang"] = [
lang.strip() for lang in self.ocr_lang.split(",") if lang.strip()
]
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
lang.strip()
for lang in self.params.get("ocr_lang").split(",")
if lang.strip()
]
endpoint = f"{self.url}/v1alpha/convert/file"
r = requests.post(endpoint, files=files, data=params)
@@ -207,7 +215,18 @@ class Loader:
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if (
self.engine == "external"
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
):
loader = ExternalDocumentLoader(
file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
mime_type=file_content_type,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
@@ -225,8 +244,13 @@ class Loader:
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
ocr_engine=self.kwargs.get("DOCLING_OCR_ENGINE"),
ocr_lang=self.kwargs.get("DOCLING_OCR_LANG"),
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
)
elif (
self.engine == "document_intelligence"
@@ -258,6 +282,15 @@ class Loader:
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
elif (
self.engine == "external"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
)
else:
if file_ext == "pdf":
loader = PyPDFLoader(

View File

@@ -1,8 +1,12 @@
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
@@ -14,18 +18,29 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
"""
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__(self, api_key: str, file_path: str):
def __init__(
self,
api_key: str,
file_path: str,
timeout: int = 300, # 5 minutes default
max_retries: int = 3,
enable_debug_logging: bool = False,
):
"""
Initializes the loader.
Initializes the loader with enhanced features.
Args:
api_key: Your Mistral API key.
file_path: The local path to the PDF file to process.
timeout: Request timeout in seconds.
max_retries: Maximum number of retry attempts.
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
@@ -34,7 +49,23 @@ class MistralLoader:
self.api_key = api_key
self.file_path = file_path
self.headers = {"Authorization": f"Bearer {self.api_key}"}
self.timeout = timeout
self.max_retries = max_retries
self.debug = enable_debug_logging
# Pre-compute file info for performance
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0",
}
def _debug_log(self, message: str, *args) -> None:
"""Conditional debug logging for performance."""
if self.debug:
log.debug(message, *args)
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
"""Checks response status and returns JSON content."""
@@ -54,24 +85,89 @@ class MistralLoader:
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
raise # Re-raise after logging
async def _handle_response_async(
    self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
    """Validate an aiohttp response and return its parsed JSON body.

    Raises the underlying aiohttp error (after logging) on HTTP/client
    failures; raises ValueError when the body is not JSON.
    """
    try:
        response.raise_for_status()

        # Check content type before parsing: a 204 No Content is treated as
        # an empty result; any other non-JSON body is an error.
        content_type = response.headers.get("content-type", "")
        if "application/json" not in content_type:
            if response.status == 204:
                return {}
            text = await response.text()
            raise ValueError(
                f"Unexpected content type: {content_type}, body: {text[:200]}..."
            )

        return await response.json()
    except aiohttp.ClientResponseError as e:
        # NOTE(review): reading response.text() after raise_for_status()
        # assumes the connection/body is still readable — confirm against the
        # aiohttp version in use.
        error_text = await response.text() if response else "No response"
        log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
        raise
    except aiohttp.ClientError as e:
        # Transport-level failure (connection, timeout wrapper, etc.).
        log.error(f"Client error: {e}")
        raise
    except Exception as e:
        # Catch-all so every failure path is logged before propagating.
        log.error(f"Unexpected error processing response: {e}")
        raise
def _retry_request_sync(self, request_func, *args, **kwargs):
    """Call *request_func* with retries and exponential backoff (sync).

    Args:
        request_func: Callable performing one request attempt.
        *args: Positional arguments forwarded to ``request_func``.
        **kwargs: Keyword arguments forwarded to ``request_func``.

    Returns:
        The return value of the first successful attempt.

    Raises:
        Exception: the last error, once ``self.max_retries`` attempts failed.
    """
    for attempt in range(self.max_retries):
        try:
            return request_func(*args, **kwargs)
        # `except Exception` already subsumes requests.exceptions.
        # RequestException; the original tuple
        # `(requests.exceptions.RequestException, Exception)` was redundant.
        except Exception as e:
            if attempt == self.max_retries - 1:
                # Out of attempts: propagate the final failure unchanged.
                raise
            # Exponential backoff: 1.5s, 2.5s, 4.5s, ...
            wait_time = (2**attempt) + 0.5
            log.warning(
                f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
    """Await *request_func* with retries and exponential backoff.

    Only transport-level failures (aiohttp client errors and asyncio
    timeouts) trigger a retry; the final failure is re-raised unchanged.
    """
    final_attempt = self.max_retries - 1
    for try_no in range(self.max_retries):
        try:
            return await request_func(*args, **kwargs)
        except (aiohttp.ClientError, asyncio.TimeoutError) as err:
            if try_no == final_attempt:
                raise
            delay = (2**try_no) + 0.5
            log.warning(
                f"Request failed (attempt {try_no + 1}/{self.max_retries}): {err}. Retrying in {delay}s..."
            )
            await asyncio.sleep(delay)
def _upload_file(self) -> str:
"""Uploads the file to Mistral for OCR processing."""
"""Uploads the file to Mistral for OCR processing (sync version)."""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
try:
def upload_request():
with open(self.file_path, "rb") as f:
files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
upload_headers = self.headers.copy() # Avoid modifying self.headers
response = requests.post(
url, headers=upload_headers, files=files, data=data
url,
headers=self.headers,
files=files,
data=data,
timeout=self.timeout,
)
response_data = self._handle_response(response)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
@@ -81,16 +177,66 @@ class MistralLoader:
log.error(f"Failed to upload file: {e}")
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
    """Upload the file to the Mistral files endpoint (async, streamed).

    Returns:
        The server-assigned file ID.

    Raises:
        ValueError: if the upload response contains no ``id`` field.
    """
    url = f"{self.BASE_API_URL}/files"

    async def upload_request():
        # Create multipart writer for streaming upload. A fresh writer is
        # built on every attempt so retries do not reuse a consumed payload.
        writer = aiohttp.MultipartWriter("form-data")

        # Add purpose field ("ocr" tells the API what the file is for).
        purpose_part = writer.append("ocr")
        purpose_part.set_content_disposition("form-data", name="purpose")

        # Add file part with streaming; FilePayload reads from disk lazily
        # instead of loading the whole file into memory.
        file_part = writer.append_payload(
            aiohttp.streams.FilePayload(
                self.file_path,
                filename=self.file_name,
                content_type="application/pdf",
            )
        )
        file_part.set_content_disposition(
            "form-data", name="file", filename=self.file_name
        )

        self._debug_log(
            f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
        )

        async with session.post(
            url,
            data=writer,
            headers=self.headers,
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        ) as response:
            return await self._handle_response_async(response)

    response_data = await self._retry_request_async(upload_request)

    file_id = response_data.get("id")
    if not file_id:
        raise ValueError("File ID not found in upload response.")

    log.info(f"File uploaded successfully. File ID: {file_id}")
    return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file."""
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.timeout
)
return self._handle_response(response)
try:
response = requests.get(url, headers=signed_url_headers, params=params)
response_data = self._handle_response(response)
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
@@ -100,8 +246,36 @@ class MistralLoader:
log.error(f"Failed to get signed URL: {e}")
raise
async def _get_signed_url_async(
    self, session: aiohttp.ClientSession, file_id: str
) -> str:
    """Fetch a temporary signed download URL for an uploaded file (async).

    Returns:
        The signed URL string.

    Raises:
        ValueError: if the response contains no ``url`` field.
    """
    url = f"{self.BASE_API_URL}/files/{file_id}/url"
    # NOTE(review): expiry unit is not visible here — presumably hours;
    # confirm against the Mistral files API documentation.
    params = {"expiry": 1}
    headers = {**self.headers, "Accept": "application/json"}

    async def url_request():
        self._debug_log(f"Getting signed URL for file ID: {file_id}")
        async with session.get(
            url,
            headers=headers,
            params=params,
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        ) as response:
            return await self._handle_response_async(response)

    response_data = await self._retry_request_async(url_request)
    signed_url = response_data.get("url")
    if not signed_url:
        raise ValueError("Signed URL not found in response.")

    self._debug_log("Signed URL received successfully")
    return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing."""
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.BASE_API_URL}/ocr"
ocr_headers = {
@@ -118,43 +292,198 @@ class MistralLoader:
"include_image_base64": False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.timeout
)
return self._handle_response(response)
try:
response = requests.post(url, headers=ocr_headers, json=payload)
ocr_response = self._handle_response(response)
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
log.debug("OCR response: %s", ocr_response)
self._debug_log("OCR response: %s", ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
raise
async def _process_ocr_async(
    self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
    """Submit the signed document URL to the OCR endpoint (async).

    Logs wall-clock processing time per attempt and returns the raw OCR
    response dict.
    """
    url = f"{self.BASE_API_URL}/ocr"
    headers = {
        **self.headers,
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {
        "model": "mistral-ocr-latest",
        "document": {
            "type": "document_url",
            "document_url": signed_url,
        },
        # Keep responses small: page images are not needed downstream.
        "include_image_base64": False,
    }

    async def ocr_request():
        log.info("Starting OCR processing via Mistral API")
        start_time = time.time()
        async with session.post(
            url,
            json=payload,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=self.timeout),
        ) as response:
            ocr_response = await self._handle_response_async(response)
            processing_time = time.time() - start_time
            log.info(f"OCR processing completed in {processing_time:.2f}s")
            return ocr_response

    return await self._retry_request_async(ocr_request)
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage."""
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}"
# No specific Accept header needed, default or Authorization is usually sufficient
try:
response = requests.delete(url, headers=self.headers)
delete_response = self._handle_response(
response
) # Check status, ignore response body unless needed
log.info(
f"File deleted successfully: {delete_response}"
) # Log the response if available
response = requests.delete(url, headers=self.headers, timeout=30)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
# Depending on requirements, you might choose to raise the error here
async def _delete_file_async(
    self, session: aiohttp.ClientSession, file_id: str
) -> None:
    """Delete an uploaded file from Mistral storage (async, best-effort).

    Deletion failures are logged as warnings and never propagated, so
    cleanup cannot break the main OCR workflow.
    """
    try:

        async def delete_request():
            self._debug_log(f"Deleting file ID: {file_id}")
            async with session.delete(
                url=f"{self.BASE_API_URL}/files/{file_id}",
                headers=self.headers,
                timeout=aiohttp.ClientTimeout(
                    total=30
                ),  # Shorter timeout for cleanup
            ) as response:
                return await self._handle_response_async(response)

        await self._retry_request_async(delete_request)
        self._debug_log(f"File {file_id} deleted successfully")
    except Exception as e:
        # Don't fail the entire process if cleanup fails
        log.warning(f"Failed to delete file ID {file_id}: {e}")
@asynccontextmanager
async def _get_session(self):
    """Yield an aiohttp ClientSession tuned for this loader.

    The session carries the loader's total timeout and a fixed User-Agent;
    the connector bounds concurrency and caches DNS lookups.
    """
    connector = aiohttp.TCPConnector(
        limit=10,  # Total connection limit
        limit_per_host=5,  # Per-host connection limit
        ttl_dns_cache=300,  # DNS cache TTL
        use_dns_cache=True,
        keepalive_timeout=30,
        # NOTE(review): enable_cleanup_closed is deprecated in newer aiohttp
        # releases — confirm the minimum supported aiohttp version.
        enable_cleanup_closed=True,
    )

    async with aiohttp.ClientSession(
        connector=connector,
        timeout=aiohttp.ClientTimeout(total=self.timeout),
        headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
    ) as session:
        yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
    """Convert a raw OCR response into a list of Document objects.

    Args:
        ocr_response: Parsed JSON from the OCR endpoint; pages are read
            from its ``pages`` list (each with ``markdown`` and ``index``).

    Returns:
        One Document per non-empty page. Never returns an empty list: when
        nothing usable is found, a single placeholder Document with an
        ``error`` metadata key is returned instead.
    """
    pages_data = ocr_response.get("pages")
    if not pages_data:
        # No pages at all — return a sentinel document rather than [].
        log.warning("No pages found in OCR response.")
        return [
            Document(
                page_content="No text content found", metadata={"error": "no_pages"}
            )
        ]

    documents = []
    total_pages = len(pages_data)
    skipped_pages = 0

    for page_data in pages_data:
        page_content = page_data.get("markdown")
        page_index = page_data.get("index")  # API uses 0-based index

        if page_content is not None and page_index is not None:
            # Clean up content efficiently; non-str content is stringified
            # rather than dropped.
            cleaned_content = (
                page_content.strip()
                if isinstance(page_content, str)
                else str(page_content)
            )

            if cleaned_content:  # Only add non-empty pages
                documents.append(
                    Document(
                        page_content=cleaned_content,
                        metadata={
                            "page": page_index,  # 0-based index from API
                            "page_label": page_index
                            + 1,  # 1-based label for convenience
                            "total_pages": total_pages,
                            "file_name": self.file_name,
                            "file_size": self.file_size,
                            "processing_engine": "mistral-ocr",
                        },
                    )
                )
            else:
                skipped_pages += 1
                self._debug_log(f"Skipping empty page {page_index}")
        else:
            skipped_pages += 1
            self._debug_log(
                f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
            )

    if skipped_pages > 0:
        log.info(
            f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
        )

    if not documents:
        # Case where pages existed but none had valid markdown/index
        log.warning(
            "OCR response contained pages, but none had valid content/index."
        )
        return [
            Document(
                page_content="No valid text content found in document",
                metadata={"error": "no_valid_pages", "total_pages": total_pages},
            )
        ]

    return documents
def load(self) -> List[Document]:
"""
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
Synchronous version for backward compatibility.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
# 1. Upload file
file_id = self._upload_file()
@@ -166,53 +495,30 @@ class MistralLoader:
ocr_response = self._process_ocr(signed_url)
# 4. Process results
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [Document(page_content="No text content found", metadata={})]
documents = self._process_results(ocr_response)
documents = []
total_pages = len(pages_data)
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is not None and page_index is not None:
documents.append(
Document(
page_content=page_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
# Add other relevant metadata from page_data if available/needed
# e.g., page_data.get('width'), page_data.get('height')
},
)
)
else:
log.warning(
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
)
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
return [
Document(
page_content="No text content found in valid pages", metadata={}
)
]
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
log.error(f"An error occurred during the loading process: {e}")
# Return an empty list or a specific error document on failure
return [Document(page_content=f"Error during processing: {e}", metadata={})]
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Delete file (attempt even if prior steps failed after upload)
if file_id:
@@ -223,3 +529,105 @@ class MistralLoader:
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
    async def load_async(self) -> List[Document]:
        """
        Asynchronous OCR workflow execution with optimized performance.

        Workflow: upload file -> obtain signed URL -> run OCR -> convert the
        response into Documents. The uploaded file is always deleted in the
        `finally` block, even when an earlier step fails after the upload.

        Returns:
            A list of Document objects, one for each page processed; on
            failure, a single Document describing the error (never raises).
        """
        file_id = None
        start_time = time.time()
        try:
            async with self._get_session() as session:
                # 1. Upload file with streaming
                file_id = await self._upload_file_async(session)
                # 2. Get signed URL
                signed_url = await self._get_signed_url_async(session, file_id)
                # 3. Process OCR
                ocr_response = await self._process_ocr_async(session, signed_url)
                # 4. Process results
                documents = self._process_results(ocr_response)
                total_time = time.time() - start_time
                log.info(
                    f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
                )
                return documents
        except Exception as e:
            total_time = time.time() - start_time
            log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
            # Error is swallowed and reported as a Document so callers can
            # treat failures uniformly with successful loads.
            return [
                Document(
                    page_content=f"Error during OCR processing: {e}",
                    metadata={
                        "error": "processing_failed",
                        "file_name": self.file_name,
                    },
                )
            ]
        finally:
            # 5. Cleanup - always attempt file deletion.
            # A fresh session is opened here because the one used above may
            # already be closed when an exception unwound the `async with`.
            if file_id:
                try:
                    async with self._get_session() as session:
                        await self._delete_file_async(session, file_id)
                except Exception as cleanup_error:
                    log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
) -> List[List[Document]]:
"""
Process multiple files concurrently for maximum performance.
Args:
loaders: List of MistralLoader instances
Returns:
List of document lists, one for each loader
"""
if not loaders:
return []
log.info(f"Starting concurrent processing of {len(loaders)} files")
start_time = time.time()
# Process all files concurrently
tasks = [loader.load_async() for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
metadata={
"error": "batch_processing_failed",
"file_index": i,
},
)
]
)
else:
processed_results.append(result)
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
log.info(
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
)
return processed_results

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple
class BaseReranker(ABC):
    """Abstract interface for rerankers that score (query, passage) pairs."""

    @abstractmethod
    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        """Return one relevance score per (query, passage) pair, or None on failure."""
        pass

View File

@@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT:
class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -3,12 +3,14 @@ import requests
from typing import Optional, List, Tuple
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker:
class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,

View File

@@ -12,7 +12,7 @@ from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files

View File

@@ -1,30 +0,0 @@
from open_webui.config import VECTOR_DB
if VECTOR_DB == "milvus":
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
VECTOR_DB_CLIENT = PgvectorClient()
elif VECTOR_DB == "elasticsearch":
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
VECTOR_DB_CLIENT = ElasticsearchClient()
elif VECTOR_DB == "pinecone":
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
VECTOR_DB_CLIENT = PineconeClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
VECTOR_DB_CLIENT = ChromaClient()

View File

@@ -1,13 +1,12 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from pinecone import ServerlessSpec
from pinecone import Pinecone, ServerlessSpec
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts
from open_webui.retrieval.vector.main import (
VectorDBBase,
@@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase):
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone gRPC client for improved performance
self.client = PineconeGRPC(
api_key=self.api_key, environment=self.environment, cloud=self.cloud
)
# Initialize Pinecone client for improved performance
self.client = Pinecone(api_key=self.api_key)
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase):
metadatas = []
for match in matches:
metadata = match.get("metadata", {})
ids.append(match["id"])
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
@@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
return len(response.matches) > 0
matches = getattr(response, "matches", []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
@@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase):
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Perform a streaming upsert over gRPC for performance testing."""
if not items:
log.warning("No items to upsert via streaming")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Open a streaming upsert channel
stream = self.index.streaming_upsert()
try:
for point in points:
# send each point over the stream
stream.send(point)
# close the stream to finalize
stream.close()
log.info(
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
)
except Exception as e:
log.error(f"Error during streaming upsert: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
@@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
if not query_response.matches:
matches = getattr(query_response, "matches", []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
@@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase):
)
# Convert to GetResult format
get_result = self._result_to_get_result(query_response.matches)
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(match.score)
for match in query_response.matches
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
@@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase):
include_metadata=True,
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
@@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
@@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase):
raise
def close(self):
"""Shut down the gRPC channel and thread pool."""
"""Shut down resources."""
try:
self.client.close()
log.info("Pinecone gRPC channel closed.")
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to close Pinecone gRPC channel: {e}")
log.warning(f"Failed to clean up Pinecone resources: {e}")
self._executor.shutdown(wait=True)
def __enter__(self):

View File

@@ -0,0 +1,712 @@
import logging
from typing import Optional, Tuple
from urllib.parse import urlparse
import grpc
from open_webui.config import (
QDRANT_API_KEY,
QDRANT_GRPC_PORT,
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient(VectorDBBase):
    """
    Multi-tenant Qdrant client: many logical collections are multiplexed onto
    a small fixed set of physical collections, with per-tenant isolation
    enforced through a `tenant_id` payload field
    (see `_get_collection_and_tenant_id`).
    """

    def __init__(self):
        # Prefix shared by every physical collection owned by Open WebUI.
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        self.QDRANT_ON_DISK = QDRANT_ON_DISK
        self.PREFER_GRPC = QDRANT_PREFER_GRPC
        self.GRPC_PORT = QDRANT_GRPC_PORT
        # Without a URI the client stays disabled; every public method checks
        # `self.client` first, so the *_COLLECTION attributes skipped below
        # are never read in that case.
        if not self.QDRANT_URI:
            self.client = None
            return
        # Unified handling for either scheme
        parsed = urlparse(self.QDRANT_URI)
        host = parsed.hostname or self.QDRANT_URI
        http_port = parsed.port or 6333  # default REST port
        if self.PREFER_GRPC:
            self.client = Qclient(
                host=host,
                port=http_port,
                grpc_port=self.GRPC_PORT,
                prefer_grpc=self.PREFER_GRPC,
                api_key=self.QDRANT_API_KEY,
            )
        else:
            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult:
ids = []
documents = []
metadatas = []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
"""
Maps the traditional collection name to multi-tenant collection and tenant ID.
Returns:
tuple: (collection_name, tenant_id)
"""
# Check for user memory collections
tenant_id = collection_name
if collection_name.startswith("user-memory-"):
return self.MEMORY_COLLECTION, tenant_id
# Check for file collections
elif collection_name.startswith("file-"):
return self.FILE_COLLECTION, tenant_id
# Check for web search collections
elif collection_name.startswith("web-search-"):
return self.WEB_SEARCH_COLLECTION, tenant_id
# Handle hash-based collections (YouTube and web URLs)
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, tenant_id
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
    def _extract_error_message(self, exception):
        """
        Extract error message from either HTTP or gRPC exceptions

        Returns:
            tuple: (status_code, error_message) — status_code is None for
            non-HTTP/gRPC exceptions or when it cannot be determined.
        """
        # Check if it's an HTTP exception
        if isinstance(exception, UnexpectedResponse):
            try:
                error_data = exception.structured()
                error_msg = error_data.get("status", {}).get("error", "")
                return exception.status_code, error_msg
            except Exception as inner_e:
                # Body was not parseable JSON; fall back to the raw exception text
                log.error(f"Failed to parse HTTP error: {inner_e}")
                return exception.status_code, str(exception)
        # Check if it's a gRPC exception
        elif isinstance(exception, grpc.RpcError):
            # Extract status code from gRPC error
            status_code = None
            if hasattr(exception, "code") and callable(exception.code):
                # StatusCode.value is a (numeric_code, name) tuple; keep the number
                status_code = exception.code().value[0]
            # Extract error message
            error_msg = str(exception)
            if "details =" in error_msg:
                # Parse the details line which contains the actual error message
                try:
                    details_line = [
                        line.strip()
                        for line in error_msg.split("\n")
                        if "details =" in line
                    ][0]
                    error_msg = details_line.split("details =")[1].strip(' "')
                except (IndexError, AttributeError):
                    # Fall back to full message if parsing fails
                    pass
            return status_code, error_msg
        # For any other type of exception
        return None, str(exception)
def _is_collection_not_found_error(self, exception):
"""
Check if the exception is due to collection not found, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# HTTP error (404)
if (
status_code == 404
and "Collection" in error_msg
and "doesn't exist" in error_msg
):
return True
# gRPC error (NOT_FOUND status)
if (
isinstance(exception, grpc.RpcError)
and exception.code() == grpc.StatusCode.NOT_FOUND
):
return True
return False
def _is_dimension_mismatch_error(self, exception):
"""
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
"""
status_code, error_msg = self._extract_error_message(exception)
# Common patterns in both HTTP and gRPC
return (
"Vector dimension error" in error_msg
or "dimensions mismatch" in error_msg
or "invalid vector size" in error_msg
)
    def _create_multi_tenant_collection_if_not_exists(
        self, mt_collection_name: str, dimension: int = 384
    ):
        """
        Creates a collection with multi-tenancy configuration if it doesn't exist.

        Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
        When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.

        Creation is attempted unconditionally; an "already exists" response
        (HTTP 409 / gRPC ALREADY_EXISTS) is treated as success.
        """
        try:
            # Try to create the collection directly - will fail if it already exists
            self.client.create_collection(
                collection_name=mt_collection_name,
                vectors_config=models.VectorParams(
                    size=dimension,
                    distance=models.Distance.COSINE,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,  # Enable per-tenant indexing
                    m=0,  # m=0 disables the global HNSW graph (per Qdrant multitenancy guidance)
                    on_disk=self.QDRANT_ON_DISK,
                ),
            )
            # Create tenant ID payload index so tenant filters are indexed
            self.client.create_payload_index(
                collection_name=mt_collection_name,
                field_name="tenant_id",
                field_schema=models.KeywordIndexParams(
                    type=models.KeywordIndexType.KEYWORD,
                    is_tenant=True,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                wait=True,
            )
            log.info(
                f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            # Check for the specific error indicating collection already exists
            status_code, error_msg = self._extract_error_message(e)
            # HTTP status code 409 or gRPC ALREADY_EXISTS
            if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
                isinstance(e, grpc.RpcError)
                and e.code() == grpc.StatusCode.ALREADY_EXISTS
            ):
                if "already exists" in error_msg:
                    log.debug(f"Collection {mt_collection_name} already exists")
                    return
            # If it's not an already exists error, re-raise
            raise e
        except Exception as e:
            raise e
def _create_points(self, items: list[VectorItem], tenant_id: str):
"""
Create point structs from vector items with tenant ID.
"""
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"],
"tenant_id": tenant_id,
},
)
for item in items
]
    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a logical collection exists by checking for any points with the tenant ID.

        Never raises: returns False when the client is disabled, the physical
        collection is missing, or any error occurs.
        """
        if not self.client:
            return False
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        try:
            # Try directly querying - most of the time collection should exist
            response = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[tenant_filter]),
                limit=1,  # existence only needs a single point
            )
            # Collection exists with this tenant ID if there are points
            return len(response.points) > 0
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist")
                return False
            else:
                # For other API errors, log and return False
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error: {error_msg}")
                return False
        except Exception as e:
            # For any other errors, log and return False
            log.debug(f"Error checking collection {mt_collection}: {e}")
            return False
    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """
        Delete vectors by ID or filter from a collection with tenant isolation.

        Args:
            collection_name: Logical collection name (mapped to a tenant).
            ids: When given, points whose `metadata.id` matches ANY of these
                values are deleted (OR semantics via `should`).
            filter: Used only when `ids` is not given; points matching EVERY
                `metadata.<key> == value` pair are deleted (AND via `must`).

        Returns:
            The Qdrant update result, or None when the client is disabled or
            the physical collection does not exist.
        """
        if not self.client:
            return None
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        must_conditions = [tenant_filter]
        should_conditions = []
        if ids:
            for id_value in ids:
                should_conditions.append(
                    models.FieldCondition(
                        key="metadata.id",
                        match=models.MatchValue(value=id_value),
                    ),
                )
        elif filter:
            for key, value in filter.items():
                must_conditions.append(
                    models.FieldCondition(
                        key=f"metadata.{key}",
                        match=models.MatchValue(value=value),
                    ),
                )
        try:
            # Try to delete directly - most of the time collection should exist
            update_result = self.client.delete(
                collection_name=mt_collection,
                points_selector=models.FilterSelector(
                    filter=models.Filter(must=must_conditions, should=should_conditions)
                ),
            )
            return update_result
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, nothing to delete"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, re-raise
            raise
    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for the nearest neighbor items based on the vectors with tenant isolation.

        Only `vectors[0]` is used as the query vector; `limit` caps the
        number of hits. Returns None when the client is disabled, the
        collection is missing, or a non-Qdrant error occurs.
        """
        if not self.client:
            return None
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        # Get the vector dimension from the query vector
        dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
        try:
            # Try the search operation directly - most of the time collection should exist
            # Create tenant filter
            tenant_filter = models.FieldCondition(
                key="tenant_id", match=models.MatchValue(value=tenant_id)
            )
            # Ensure vector dimensions match the collection
            collection_dim = self.client.get_collection(
                mt_collection
            ).config.params.vectors.size
            if collection_dim != dimension:
                # NOTE(review): silently truncating/zero-padding query vectors
                # assumes embeddings are prefix-compatible across models; after
                # an embedding-model switch results may be meaningless — confirm.
                if collection_dim < dimension:
                    vectors = [vector[:collection_dim] for vector in vectors]
                else:
                    vectors = [
                        vector + [0] * (collection_dim - dimension)
                        for vector in vectors
                    ]
            # Search with tenant filter: prefetch narrows to this tenant's
            # points before scoring against the query vector
            prefetch_query = models.Prefetch(
                filter=models.Filter(must=[tenant_filter]),
                limit=NO_LIMIT,
            )
            query_response = self.client.query_points(
                collection_name=mt_collection,
                query=vectors[0],
                prefetch=prefetch_query,
                limit=limit,
            )
            get_result = self._result_to_get_result(query_response.points)
            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                # qdrant distance is [-1, 1], normalize to [0, 1]
                distances=[
                    [(point.score + 1.0) / 2.0 for point in query_response.points]
                ],
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, search returns None"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during search: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error searching collection '{collection_name}': {e}")
            return None
    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """
        Query points with filters and tenant isolation.

        Args:
            collection_name: Logical collection name (mapped to a tenant).
            filter: Equality constraints applied as `metadata.<key> == value`.
            limit: Maximum points to return; defaults to effectively unlimited.
        """
        if not self.client:
            return None
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        # Set default limit if not provided
        if limit is None:
            limit = NO_LIMIT
        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        # Create metadata filters
        field_conditions = []
        for key, value in filter.items():
            field_conditions.append(
                models.FieldCondition(
                    key=f"metadata.{key}", match=models.MatchValue(value=value)
                )
            )
        # Combine tenant filter with metadata filters (all must match)
        combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
        try:
            # Try the query directly - most of the time collection should exist
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=combined_filter,
                limit=limit,
            )
            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, query returns None"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during query: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error querying collection '{collection_name}': {e}")
            return None
    def get(self, collection_name: str) -> Optional[GetResult]:
        """
        Get all items in a collection with tenant isolation.

        Returns None when the client is disabled, the physical collection is
        missing, or a non-Qdrant error occurs.
        """
        if not self.client:
            return None
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        # Create tenant filter
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        try:
            # Try to get points directly - most of the time collection should exist
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[tenant_filter]),
                limit=NO_LIMIT,
            )
            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during get: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error getting collection '{collection_name}': {e}")
            return None
    def _handle_operation_with_error_retry(
        self, operation_name, mt_collection, points, dimension
    ):
        """
        Private helper to handle common error cases for insert and upsert operations.

        Two recoverable failures are retried exactly once each:
        - missing collection: create it with the vectors' dimension, retry;
        - dimension mismatch: truncate/zero-pad the vectors to the existing
          collection's size, retry.

        Args:
            operation_name: 'insert' or 'upsert'
            mt_collection: The multi-tenant collection name
            points: The vector points to insert/upsert
            dimension: The dimension of the vectors

        Returns:
            The operation result (for upsert) or None (for insert)
        """
        try:
            if operation_name == "insert":
                self.client.upload_points(mt_collection, points)
                return None
            else:  # upsert
                return self.client.upsert(mt_collection, points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            # Handle collection not found
            if self._is_collection_not_found_error(e):
                log.info(
                    f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
                )
                # Create collection with correct dimensions from our vectors
                self._create_multi_tenant_collection_if_not_exists(
                    mt_collection_name=mt_collection, dimension=dimension
                )
                # Try operation again - no need for dimension adjustment since we just created with correct dimensions
                if operation_name == "insert":
                    self.client.upload_points(mt_collection, points)
                    return None
                else:  # upsert
                    return self.client.upsert(mt_collection, points)
            # Handle dimension mismatch
            elif self._is_dimension_mismatch_error(e):
                # For dimension errors, the collection must exist, so get its configuration
                mt_collection_info = self.client.get_collection(mt_collection)
                existing_size = mt_collection_info.config.params.vectors.size
                log.info(
                    f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
                )
                # NOTE(review): truncation/padding assumes embeddings are
                # prefix-compatible; stored vectors from a different model
                # may not be meaningfully comparable — confirm.
                if existing_size < dimension:
                    # Truncate vectors to fit
                    log.info(
                        f"Truncating vectors from {dimension} to {existing_size} dimensions"
                    )
                    points = [
                        PointStruct(
                            id=point.id,
                            vector=point.vector[:existing_size],
                            payload=point.payload,
                        )
                        for point in points
                    ]
                elif existing_size > dimension:
                    # Pad vectors with zeros
                    log.info(
                        f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
                    )
                    points = [
                        PointStruct(
                            id=point.id,
                            vector=point.vector
                            + [0] * (existing_size - len(point.vector)),
                            payload=point.payload,
                        )
                        for point in points
                    ]
                # Try operation again with adjusted dimensions
                if operation_name == "insert":
                    self.client.upload_points(mt_collection, points)
                    return None
                else:  # upsert
                    return self.client.upsert(mt_collection, points)
            else:
                # Not a known error we can handle, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unhandled Qdrant error: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, re-raise
            raise
def insert(self, collection_name: str, items: list[VectorItem]):
"""
Insert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"insert", mt_collection, points, dimension
)
def upsert(self, collection_name: str, items: list[VectorItem]):
"""
Upsert items with tenant ID.
"""
if not self.client or not items:
return None
# Map to multi-tenant collection and tenant ID
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
# Get dimensions from the actual vectors
dimension = len(items[0]["vector"]) if items else None
# Create points with tenant ID
points = self._create_points(items, tenant_id)
# Handle the operation with error retry
return self._handle_operation_with_error_retry(
"upsert", mt_collection, points, dimension
)
def reset(self):
"""
Reset the database by deleting all collections.
"""
if not self.client:
return None
collection_names = self.client.get_collections().collections
for collection_name in collection_names:
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
    def delete_collection(self, collection_name: str):
        """
        Delete a collection.

        Removes this tenant's points from the shared physical collection; if
        that leaves the physical collection empty, drops it entirely.
        """
        if not self.client:
            return None
        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
        field_conditions = [tenant_filter]
        update_result = self.client.delete(
            collection_name=mt_collection,
            points_selector=models.FilterSelector(
                filter=models.Filter(must=field_conditions)
            ),
        )
        # NOTE(review): the emptiness check races with concurrent writers;
        # acceptable here since collections are recreated on demand.
        if self.client.get_collection(mt_collection).points_count == 0:
            self.client.delete_collection(mt_collection)
        return update_result

View File

@@ -0,0 +1,55 @@
from open_webui.retrieval.vector.main import VectorDBBase
from open_webui.retrieval.vector.type import VectorType
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
class Vector:
    """Factory for vector-database clients keyed by the VECTOR_DB setting."""

    @staticmethod
    def get_vector(vector_type: str) -> VectorDBBase:
        """
        Return a client instance for *vector_type*.

        Backend modules are imported lazily so only the selected backend's
        dependencies are loaded.

        Raises:
            ValueError: if *vector_type* is not a known backend.
        """
        if vector_type == VectorType.MILVUS:
            from open_webui.retrieval.vector.dbs.milvus import MilvusClient

            return MilvusClient()
        if vector_type == VectorType.QDRANT:
            if ENABLE_QDRANT_MULTITENANCY_MODE:
                from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
                    QdrantClient,
                )

                return QdrantClient()
            from open_webui.retrieval.vector.dbs.qdrant import QdrantClient

            return QdrantClient()
        if vector_type == VectorType.PINECONE:
            from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

            return PineconeClient()
        if vector_type == VectorType.OPENSEARCH:
            from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

            return OpenSearchClient()
        if vector_type == VectorType.PGVECTOR:
            from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

            return PgvectorClient()
        if vector_type == VectorType.ELASTICSEARCH:
            from open_webui.retrieval.vector.dbs.elasticsearch import (
                ElasticsearchClient,
            )

            return ElasticsearchClient()
        if vector_type == VectorType.CHROMA:
            from open_webui.retrieval.vector.dbs.chroma import ChromaClient

            return ChromaClient()
        raise ValueError(f"Unsupported vector type: {vector_type}")


VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)

View File

@@ -0,0 +1,11 @@
from enum import StrEnum
class VectorType(StrEnum):
    """
    Known vector-database backends; values match the `VECTOR_DB` environment
    setting and compare equal to it as plain strings (StrEnum).
    """

    MILVUS = "milvus"
    QDRANT = "qdrant"
    CHROMA = "chroma"
    PINECONE = "pinecone"
    ELASTICSEARCH = "elasticsearch"
    OPENSEARCH = "opensearch"
    PGVECTOR = "pgvector"

View File

@@ -42,7 +42,9 @@ def search_searchapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -42,7 +42,9 @@ def search_serpapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -25,7 +25,7 @@ from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external import ExternalLoader
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -39,7 +39,7 @@ from open_webui.config import (
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -515,7 +515,8 @@ class SafeWebBaseLoader(WebBaseLoader):
kwargs["ssl"] = False
async with session.get(
url, **(self.requests_kwargs | kwargs)
url,
**(self.requests_kwargs | kwargs),
) as response:
if self.raise_for_status:
response.raise_for_status()
@@ -628,7 +629,7 @@ def get_web_loader(
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
if WEB_LOADER_ENGINE.value == "external":
WebLoaderClass = ExternalLoader
WebLoaderClass = ExternalWebLoader
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value

View File

@@ -7,6 +7,9 @@ from functools import lru_cache
from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import aiohttp
import aiofiles
@@ -17,6 +20,7 @@ from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -38,6 +42,7 @@ from open_webui.config import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
ENV,
SRC_LOG_LEVELS,
@@ -49,7 +54,7 @@ from open_webui.env import (
router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
AZURE_MAX_FILE_SIZE_MB = 200
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
@@ -71,35 +76,50 @@ from pydub import AudioSegment
from pydub.utils import mediainfo
def get_audio_convert_format(file_path):
"""Check if the given file needs to be converted to a different format."""
def is_audio_conversion_required(file_path):
"""
Check if the given audio file needs conversion to mp3.
"""
SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
if not os.path.isfile(file_path):
log.error(f"File not found: {file_path}")
return False
try:
info = mediainfo(file_path)
codec_name = info.get("codec_name", "").lower()
codec_type = info.get("codec_type", "").lower()
codec_tag_string = info.get("codec_tag_string", "").lower()
if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
# File is AAC/mp4a audio, recommend mp3 conversion
return True
# If the codec name or file extension is in the supported formats
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
codec_name in SUPPORTED_FORMATS
or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS
):
return "mp4"
elif info.get("format_name") == "ogg":
return "ogg"
return False # Already supported
return True
except Exception as e:
log.error(f"Error getting audio format: {e}")
return False
return None
def convert_audio_to_wav(file_path, output_path, conversion_type):
"""Convert MP4/OGG audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format=conversion_type)
audio.export(output_path, format="wav")
log.info(f"Converted {file_path} to {output_path}")
def convert_audio_to_mp3(file_path):
    """Transcode the audio file at *file_path* to mp3.

    Returns the path of the resulting ``.mp3`` file (same directory and
    basename as the input), or ``None`` when conversion fails; failures
    are logged rather than raised.
    """
    try:
        base, _ = os.path.splitext(file_path)
        mp3_path = base + ".mp3"
        segment = AudioSegment.from_file(file_path)
        segment.export(mp3_path, format="mp3")
        log.info(f"Converted {file_path} to {mp3_path}")
        return mp3_path
    except Exception as e:
        log.error(f"Error converting audio file: {e}")
        return None
def set_faster_whisper_model(model: str, auto_update: bool = False):
@@ -326,6 +346,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -381,6 +402,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"Content-Type": "application/json",
"xi-api-key": request.app.state.config.TTS_API_KEY,
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -439,6 +461,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
"X-Microsoft-OutputFormat": output_format,
},
data=data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
r.raise_for_status()
@@ -507,12 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
def transcribe(request: Request, file_path):
log.info(f"transcribe: {file_path}")
def transcription_handler(request, file_path, metadata):
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
metadata = metadata or {}
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -524,7 +548,7 @@ def transcribe(request: Request, file_path):
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=WHISPER_LANGUAGE,
language=metadata.get("language") or WHISPER_LANGUAGE,
)
log.info(
"Detected language '%s' with probability %f"
@@ -542,19 +566,6 @@ def transcribe(request: Request, file_path):
log.debug(data)
return data
elif request.app.state.config.STT_ENGINE == "openai":
convert_format = get_audio_convert_format(file_path)
if convert_format:
ext = convert_format.split(".")[-1]
os.rename(file_path, file_path.replace(".{ext}", f".{convert_format}"))
# Convert unsupported audio file to WAV format
convert_audio_to_wav(
file_path.replace(".{ext}", f".{convert_format}"),
file_path,
convert_format,
)
r = None
try:
r = requests.post(
@@ -563,7 +574,14 @@ def transcribe(request: Request, file_path):
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
data={
"model": request.app.state.config.STT_MODEL,
**(
{"language": metadata.get("language")}
if metadata.get("language")
else {}
),
},
)
r.raise_for_status()
@@ -771,41 +789,135 @@ def transcribe(request: Request, file_path):
)
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
    """Transcribe an audio file, splitting it into size-limited chunks first.

    The file is converted to mp3 when required, compressed (best-effort) when
    over the size limit, split into chunks no larger than ``MAX_FILE_SIZE``,
    and each chunk is transcribed concurrently via ``transcription_handler``.
    Chunk texts are joined in submission order.

    Args:
        request: FastAPI request carrying the app STT configuration.
        file_path: Path of the audio file on disk.
        metadata: Optional transcription metadata (e.g. ``{"language": ...}``).

    Returns:
        dict with a single ``"text"`` key containing the joined transcription.

    Raises:
        HTTPException: 400 if the audio cannot be prepared or split,
            500 if transcription of any chunk fails.
    """
    log.info(f"transcribe: {file_path} {metadata}")

    if is_audio_conversion_required(file_path):
        converted_path = convert_audio_to_mp3(file_path)
        if converted_path is None:
            # convert_audio_to_mp3 logs the error and returns None on failure;
            # fail fast with a clear 400 instead of crashing in split_audio.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Audio conversion failed"),
            )
        file_path = converted_path

    try:
        file_path = compress_audio(file_path)
    except Exception as e:
        # Best-effort compression: fall back to the original file on failure.
        log.exception(e)

    # Always produce a list of chunk paths (a single entry if already small)
    try:
        chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
        log.debug(f"Chunk paths: {chunk_paths}")
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

    results = []
    try:
        with ThreadPoolExecutor() as executor:
            # Submit one transcription task per chunk; iterate futures in
            # submission order so the joined text preserves chunk order.
            futures = [
                executor.submit(transcription_handler, request, chunk_path, metadata)
                for chunk_path in chunk_paths
            ]
            for future in futures:
                try:
                    results.append(future.result())
                except Exception as transcribe_exc:
                    raise HTTPException(
                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                        detail=f"Error transcribing chunk: {transcribe_exc}",
                    )
    finally:
        # Clean up only the temporary chunks, never the original file
        for chunk_path in chunk_paths:
            if chunk_path != file_path and os.path.isfile(chunk_path):
                try:
                    os.remove(chunk_path)
                except Exception:
                    pass

    return {
        "text": " ".join([result["text"] for result in results]),
    }
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
id = os.path.splitext(os.path.basename(file_path))[
0
] # Handles names with multiple dots
file_dir = os.path.dirname(file_path)
audio = AudioSegment.from_file(file_path)
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
compressed_path = f"{file_dir}/{id}_compressed.opus"
audio.export(compressed_path, format="opus", bitrate="32k")
log.debug(f"Compressed audio to {compressed_path}")
if (
os.path.getsize(compressed_path) > MAX_FILE_SIZE
): # Still larger than MAX_FILE_SIZE after compression
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
audio.export(compressed_path, format="mp3", bitrate="32k")
# log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
return compressed_path
else:
return file_path
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
    """
    Split an audio file into chunks that each fit within ``max_bytes``.

    Returns the list of chunk file paths; if the file already fits, the
    list contains just the original path. Raises if a chunk cannot be
    reduced below the size limit.
    """
    total_bytes = os.path.getsize(file_path)
    if total_bytes <= max_bytes:
        return [file_path]  # Nothing to split

    audio = AudioSegment.from_file(file_path)
    duration_ms = len(audio)

    # Estimate chunk duration proportional to the size ratio, with a one
    # second safety margin, and never shorter than one second.
    approx_chunk_ms = max(int(duration_ms * (max_bytes / total_bytes)) - 1000, 1000)

    base, _ = os.path.splitext(file_path)
    chunk_paths = []
    cursor = 0
    index = 0

    while cursor < duration_ms:
        end = min(cursor + approx_chunk_ms, duration_ms)
        piece = audio[cursor:end]
        piece_path = f"{base}_chunk_{index}.{format}"
        piece.export(piece_path, format=format, bitrate=bitrate)

        # Halve the chunk duration until it fits, down to a 5-second floor
        while os.path.getsize(piece_path) > max_bytes and (end - cursor) > 5000:
            end = cursor + ((end - cursor) // 2)
            piece = audio[cursor:end]
            piece.export(piece_path, format=format, bitrate=bitrate)

        if os.path.getsize(piece_path) > max_bytes:
            os.remove(piece_path)
            raise Exception("Audio chunk cannot be reduced below max file size.")

        chunk_paths.append(piece_path)
        cursor = end
        index += 1

    return chunk_paths
@router.post("/transcriptions")
def transcription(
request: Request,
file: UploadFile = File(...),
language: Optional[str] = Form(None),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
supported_filetypes = (
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/x-m4a",
"audio/webm",
)
if not file.content_type.startswith(supported_filetypes):
SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types!
if not (
file.content_type.startswith("audio/")
or file.content_type in SUPPORTED_CONTENT_TYPES
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
@@ -826,19 +938,18 @@ def transcription(
f.write(contents)
try:
try:
file_path = compress_audio(file_path)
except Exception as e:
log.exception(e)
metadata = None
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
if language:
metadata = {"language": language}
result = transcribe(request, file_path, metadata)
return {
**result,
"filename": os.path.basename(file_path),
}
data = transcribe(request, file_path)
file_path = file_path.split("/")[-1]
return {**data, "filename": file_path}
except Exception as e:
log.exception(e)

View File

@@ -31,7 +31,7 @@ from open_webui.env import (
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from fastapi.responses import RedirectResponse, Response, JSONResponse
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel
@@ -51,7 +51,7 @@ from open_webui.utils.access_control import get_permissions
from typing import Optional, List
from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
if ENABLE_LDAP.value:
from ldap3 import Server, Connection, NONE, Tls
@@ -186,6 +186,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_VALIDATE_CERT = (
CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
)
LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS
@@ -197,7 +200,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
try:
tls = Tls(
validate=CERT_REQUIRED,
validate=LDAP_VALIDATE_CERT,
version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS,
@@ -478,10 +481,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
)
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(form_data.password.encode("utf-8")) > 72:
raise HTTPException(
@@ -541,6 +540,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
user.id, request.app.state.config.USER_PERMISSIONS
)
if user_count == 0:
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
return {
"token": token,
"token_type": "Bearer",
@@ -574,9 +577,14 @@ async def signout(request: Request, response: Response):
logout_url = openid_data.get("end_session_endpoint")
if logout_url:
response.delete_cookie("oauth_id_token")
return RedirectResponse(
return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
},
headers=response.headers,
url=f"{logout_url}?id_token_hint={oauth_id_token}",
)
else:
raise HTTPException(
@@ -591,12 +599,18 @@ async def signout(request: Request, response: Response):
)
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
return RedirectResponse(
return JSONResponse(
status_code=200,
content={
"status": True,
"redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
},
headers=response.headers,
url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
)
return {"status": True}
return JSONResponse(
status_code=200, content={"status": True}, headers=response.headers
)
############################
@@ -696,6 +710,9 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
}
@@ -713,6 +730,9 @@ class AdminConfig(BaseModel):
ENABLE_CHANNELS: bool
ENABLE_NOTES: bool
ENABLE_USER_WEBHOOKS: bool
PENDING_USER_OVERLAY_TITLE: Optional[str] = None
PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
RESPONSE_WATERMARK: Optional[str] = None
@router.post("/admin/config")
@@ -750,6 +770,15 @@ async def update_admin_config(
request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
form_data.PENDING_USER_OVERLAY_TITLE
)
request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
form_data.PENDING_USER_OVERLAY_CONTENT
)
request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"WEBUI_URL": request.app.state.config.WEBUI_URL,
@@ -764,6 +793,9 @@ async def update_admin_config(
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
}
@@ -779,6 +811,7 @@ class LdapServerConfig(BaseModel):
search_filters: str = ""
use_tls: bool = True
certificate_path: Optional[str] = None
validate_cert: bool = True
ciphers: Optional[str] = "ALL"
@@ -796,6 +829,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}
@@ -831,6 +865,7 @@ async def update_ldap_server(
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
return {
@@ -845,6 +880,7 @@ async def update_ldap_server(
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
"ciphers": request.app.state.config.LDAP_CIPHERS,
}

View File

@@ -74,13 +74,17 @@ class FeedbackUserResponse(FeedbackResponse):
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackUserResponse(
**feedback.model_dump(),
user=UserResponse(**Users.get_user_by_id(feedback.user_id).model_dump()),
feedback_list = []
for feedback in feedbacks:
user = Users.get_user_by_id(feedback.user_id)
feedback_list.append(
FeedbackUserResponse(
**feedback.model_dump(),
user=UserResponse(**user.model_dump()) if user else None,
)
)
for feedback in feedbacks
]
return feedback_list
@router.delete("/feedbacks/all")
@@ -92,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
async def get_all_feedbacks(user=Depends(get_admin_user)):
feedbacks = Feedbacks.get_all_feedbacks()
return [
FeedbackModel(
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
)
for feedback in feedbacks
]
return feedbacks
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])

View File

@@ -1,6 +1,7 @@
import logging
import os
import uuid
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import Optional
@@ -10,6 +11,7 @@ from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -84,17 +86,44 @@ def has_access_to_file(
def upload_file(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = None,
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
file_metadata = file_metadata if file_metadata else {}
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
)
file_metadata = metadata if metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_extension = os.path.splitext(filename)[1]
# Remove the leading dot from the file extension
file_extension = file_extension[1:] if file_extension else ""
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
]
if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(
f"File type {file_extension} is not allowed"
),
)
# replace filename with uuid
id = str(uuid.uuid4())
name = filename
@@ -125,33 +154,26 @@ def upload_file(
)
if process:
try:
if file.content_type:
if file.content_type.startswith("audio/") or file.content_type in {
"video/webm"
}:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path, file_metadata)
if file.content_type.startswith(
(
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/x-m4a",
"audio/webm",
"video/webm",
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=id), user=user)
else:
log.info(
f"File type {file.content_type} is not provided, but trying to process anyway"
)
):
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
elif file.content_type not in [
"image/png",
"image/jpeg",
"image/gif",
"video/mp4",
"video/ogg",
"video/quicktime",
]:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)

View File

@@ -262,11 +262,8 @@ async def get_function_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -290,11 +287,8 @@ async def update_function_valves_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -353,11 +347,8 @@ async def get_function_user_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -377,11 +368,8 @@ async def update_function_user_valves_by_id(
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves

View File

@@ -333,10 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
@@ -450,7 +451,7 @@ def load_url_image_data(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
def upload_image(request, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
@@ -459,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@@ -526,7 +527,7 @@ async def image_generations(
else:
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images

View File

@@ -10,7 +10,7 @@ from open_webui.models.knowledge import (
KnowledgeUserResponse,
)
from open_webui.models.files import Files, FileModel, FileMetadataResponse
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
process_file,
ProcessFileForm,

View File

@@ -4,7 +4,7 @@ import logging
from typing import Optional
from open_webui.models.memories import Memories, MemoryModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.utils.auth import get_verified_user
from open_webui.env import SRC_LOG_LEVELS

View File

@@ -9,6 +9,8 @@ import os
import random
import re
import time
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
@@ -300,6 +302,22 @@ async def update_config(
}
def merge_ollama_models_lists(model_lists):
    """Merge per-connection Ollama model lists into one deduplicated list.

    Each entry gains a ``"urls"`` key listing the indices of the
    connections that serve it (the first occurrence's dict is kept and
    mutated in place). ``None`` entries in *model_lists* are skipped.
    """
    combined = {}
    for source_idx, models in enumerate(model_lists):
        if models is None:
            continue
        for entry in models:
            key = entry["model"]
            existing = combined.get(key)
            if existing is None:
                entry["urls"] = [source_idx]
                combined[key] = entry
            else:
                existing["urls"].append(source_idx)
    return list(combined.values())
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
@@ -340,6 +358,8 @@ async def get_all_models(request: Request, user: UserModel = None):
), # Legacy support
)
connection_type = api_config.get("connection_type", "local")
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
model_ids = api_config.get("model_ids", [])
@@ -352,31 +372,18 @@ async def get_all_models(request: Request, user: UserModel = None):
)
)
if prefix_id:
for model in response.get("models", []):
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
if tags:
for model in response.get("models", []):
if tags:
model["tags"] = tags
def merge_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
if connection_type:
model["connection_type"] = connection_type
models = {
"models": merge_models_lists(
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
@@ -384,6 +391,19 @@ async def get_all_models(request: Request, user: UserModel = None):
)
}
loaded_models = await get_ollama_loaded_models(request, user=user)
expires_map = {
m["name"]: m["expires_at"]
for m in loaded_models["models"]
if "expires_at" in m
}
for m in models["models"]:
if m["name"] in expires_map:
# Parse ISO8601 datetime with offset, get unix timestamp as int
dt = datetime.fromisoformat(expires_map[m["name"]])
m["expires_at"] = int(dt.timestamp())
else:
models = {"models": []}
@@ -464,6 +484,68 @@ async def get_ollama_tags(
return models
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(f"{url}/api/ps", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
models = {
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
else:
models = {"models": []}
return models
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
@@ -537,36 +619,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
return {"version": False}
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [
send_get_request(
f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
class ModelNameForm(BaseModel):
    # Name of the Ollama model to operate on; a bare name without a ":" tag
    # is canonicalized to "<name>:latest" by the unload endpoint.
    name: str
@router.post("/api/unload")
async def unload_model(
    request: Request,
    form_data: ModelNameForm,
    user=Depends(get_admin_user),
):
    """Unload a model from memory on every Ollama node that serves it.

    Sends a generate request with ``keep_alive=0`` (and an empty prompt) to
    each connection listed for the model, which tells Ollama to evict it.

    Raises:
        HTTPException: 400 if the name is missing or unknown,
            500 if the unload fails on any node.
    """
    model_name = form_data.name
    if not model_name:
        raise HTTPException(
            status_code=400, detail="Missing 'name' of model to unload."
        )

    # Refresh/load models if needed, get mapping from name to URLs
    await get_all_models(request, user=user)
    models = request.app.state.OLLAMA_MODELS

    # Canonicalize model name (if not supplied with version)
    if ":" not in model_name:
        model_name = f"{model_name}:latest"

    if model_name not in models:
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
        )

    url_indices = models[model_name]["urls"]

    # Send unload to ALL url_indices
    results = []
    errors = []
    for idx in url_indices:
        url = request.app.state.config.OLLAMA_BASE_URLS[idx]
        api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
            str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
        )
        key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)

        # Strip this connection's prefix for the upstream request WITHOUT
        # mutating model_name, which is reused by subsequent loop iterations.
        prefix_id = api_config.get("prefix_id", None)
        upstream_name = model_name
        if prefix_id and upstream_name.startswith(f"{prefix_id}."):
            upstream_name = upstream_name[len(f"{prefix_id}.") :]

        # keep_alive=0 with an empty prompt instructs Ollama to evict the model
        payload = {"model": upstream_name, "keep_alive": 0, "prompt": ""}
        try:
            res = await send_post_request(
                url=f"{url}/api/generate",
                payload=json.dumps(payload),
                stream=False,
                key=key,
                user=user,
            )
            results.append({"url_idx": idx, "success": True, "response": res})
        except Exception as e:
            log.exception(f"Failed to unload model on node {idx}: {e}")
            errors.append({"url_idx": idx, "success": False, "error": str(e)})

    if len(errors) > 0:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
        )

    return {"status": True}
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
@@ -1585,7 +1705,9 @@ async def upload_model(
if url_idx is None:
url_idx = 0
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = os.path.join(UPLOAD_DIR, file.filename)
filename = os.path.basename(file.filename)
file_path = os.path.join(UPLOAD_DIR, filename)
os.makedirs(UPLOAD_DIR, exist_ok=True)
# --- P1: save file locally ---
@@ -1630,13 +1752,13 @@ async def upload_model(
os.remove(file_path)
# Create model in ollama
model_name, ext = os.path.splitext(file.filename)
model_name, ext = os.path.splitext(filename)
log.info(f"Created Model: {model_name}") # DEBUG
create_payload = {
"model": model_name,
# Reference the file by its original name => the uploaded blob's digest
"files": {file.filename: f"sha256:{file_hash}"},
"files": {filename: f"sha256:{file_hash}"},
}
log.info(f"Model Payload: {create_payload}") # DEBUG
@@ -1653,7 +1775,7 @@ async def upload_model(
done_msg = {
"done": True,
"blob": f"sha256:{file_hash}",
"name": file.filename,
"name": filename,
"model_created": model_name,
}
yield f"data: {json.dumps(done_msg)}\n\n"

View File

@@ -353,21 +353,22 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
), # Legacy support
)
connection_type = api_config.get("connection_type", "external")
prefix_id = api_config.get("prefix_id", None)
tags = api_config.get("tags", [])
if prefix_id:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
for model in (
response if isinstance(response, list) else response.get("data", [])
):
if prefix_id:
model["id"] = f"{prefix_id}.{model['id']}"
if tags:
for model in (
response if isinstance(response, list) else response.get("data", [])
):
if tags:
model["tags"] = tags
if connection_type:
model["connection_type"] = connection_type
log.debug(f"get_all_models:responses() {responses}")
return responses
@@ -415,6 +416,7 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"connection_type": model.get("connection_type", "external"),
"urlIdx": idx,
}
for model in models
@@ -461,60 +463,74 @@ async def get_models(
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(url_idx),
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
)
r = None
async with aiohttp.ClientSession(
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
response_data = await r.json()
if api_config.get("azure", False):
models = {
"data": api_config.get("model_ids", []) or [],
"object": "list",
}
else:
headers["Authorization"] = f"Bearer {key}"
# Check if we're calling OpenAI API based on the URL
if "api.openai.com" in url:
# Filter models according to the specified conditions
response_data["data"] = [
model
for model in response_data.get("data", [])
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
async with session.get(
f"{url}/models",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
models = response_data
response_data = await r.json()
# Check if we're calling OpenAI API based on the URL
if "api.openai.com" in url:
# Filter models according to the specified conditions
response_data["data"] = [
model
for model in response_data.get("data", [])
if not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
models = response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
@@ -536,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
url: str
key: str
config: Optional[dict] = None
@router.post("/verify")
async def verify_connection(
@@ -544,39 +562,64 @@ async def verify_connection(
url = form_data.url
key = form_data.key
api_config = form_data.config or {}
async with aiohttp.ClientSession(
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
response_data = await r.json()
return response_data
if api_config.get("azure", False):
headers["api-key"] = key
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
async with session.get(
url=f"{url}/openai/models?api-version={api_version}",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
else:
headers["Authorization"] = f"Bearer {key}"
async with session.get(
f"{url}/models",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
res = await r.json()
if "error" in res:
error_detail = f"External Error: {res['error']}"
raise Exception(error_detail)
response_data = await r.json()
return response_data
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
@@ -590,6 +633,63 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
def convert_to_azure_payload(
url,
payload: dict,
):
model = payload.get("model", "")
# Filter allowed parameters based on Azure OpenAI API
allowed_params = {
"messages",
"temperature",
"role",
"content",
"contentPart",
"contentPartImage",
"enhancements",
"dataSources",
"n",
"stream",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"function_call",
"functions",
"tools",
"tool_choice",
"top_p",
"log_probs",
"top_logprobs",
"response_format",
"seed",
"max_completion_tokens",
}
# Special handling for o-series models
if model.startswith("o") and model.endswith("-mini"):
# Convert max_tokens to max_completion_tokens for o-series models
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Remove temperature if not 1 for o-series models
if "temperature" in payload and payload["temperature"] != 1:
log.debug(
f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
)
del payload["temperature"]
# Filter out unsupported parameters
payload = {k: v for k, v in payload.items() if k in allowed_params}
url = f"{url}/openai/deployments/{model}"
return url, payload
@router.post("/chat/completions")
async def generate_chat_completion(
request: Request,
@@ -690,6 +790,38 @@ async def generate_chat_completion(
convert_logit_bias_input_to_json(payload["logit_bias"])
)
headers = {
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False):
request_url, payload = convert_to_azure_payload(url, payload)
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
headers["api-key"] = key
headers["api-version"] = api_version
request_url = f"{request_url}/chat/completions?api-version={api_version}"
else:
request_url = f"{url}/chat/completions"
headers["Authorization"] = f"Bearer {key}"
payload = json.dumps(payload)
r = None
@@ -704,30 +836,9 @@ async def generate_chat_completion(
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
url=request_url,
data=payload,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
@@ -783,31 +894,53 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
str(idx),
request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
), # Legacy support
)
r = None
session = None
streaming = False
try:
headers = {
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
}
if api_config.get("azure", False):
headers["api-key"] = key
headers["api-version"] = (
api_config.get("api_version", "") or "2023-03-15-preview"
)
payload = json.loads(body)
url, payload = convert_to_azure_payload(url, payload)
body = json.dumps(payload).encode()
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
else:
headers["Authorization"] = f"Bearer {key}"
request_url = f"{url}/{path}"
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=f"{url}/{path}",
url=request_url,
data=body,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
r.raise_for_status()

View File

@@ -18,7 +18,7 @@ from pydantic import BaseModel
from starlette.responses import FileResponse
from typing import Optional
from open_webui.env import SRC_LOG_LEVELS
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
@@ -69,7 +69,10 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
try:
urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -89,6 +92,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
@@ -118,7 +122,10 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
try:
urlIdx = int(urlIdx)
except:
continue
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
@@ -138,6 +145,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json=request_data,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
payload = await response.json()
response.raise_for_status()
@@ -197,8 +205,10 @@ async def upload_pipeline(
user=Depends(get_admin_user),
):
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
filename = os.path.basename(file.filename)
# Check if the uploaded file is a python file
if not (file.filename and file.filename.endswith(".py")):
if not (filename and filename.endswith(".py")):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
@@ -206,7 +216,7 @@ async def upload_pipeline(
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
file_path = os.path.join(upload_folder, filename)
r = None
try:

View File

@@ -36,7 +36,7 @@ from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
# Document loaders
from open_webui.retrieval.loaders.main import Loader
@@ -352,10 +352,13 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -371,6 +374,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
# File upload settings
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
# Integration settings
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
@@ -383,6 +387,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -435,6 +440,7 @@ class WebConfig(BaseModel):
WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
SEARXNG_QUERY_URL: Optional[str] = None
YACY_QUERY_URL: Optional[str] = None
YACY_USERNAME: Optional[str] = None
@@ -492,10 +498,14 @@ class ConfigForm(BaseModel):
# Content extraction settings
CONTENT_EXTRACTION_ENGINE: Optional[str] = None
PDF_EXTRACT_IMAGES: Optional[bool] = None
EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
TIKA_SERVER_URL: Optional[str] = None
DOCLING_SERVER_URL: Optional[str] = None
DOCLING_OCR_ENGINE: Optional[str] = None
DOCLING_OCR_LANG: Optional[str] = None
DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -514,6 +524,7 @@ class ConfigForm(BaseModel):
# File upload settings
FILE_MAX_SIZE: Optional[int] = None
FILE_MAX_COUNT: Optional[int] = None
ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None
# Integration settings
ENABLE_GOOGLE_DRIVE_INTEGRATION: Optional[bool] = None
@@ -581,6 +592,16 @@ async def update_rag_config(
if form_data.PDF_EXTRACT_IMAGES is not None
else request.app.state.config.PDF_EXTRACT_IMAGES
)
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = (
form_data.EXTERNAL_DOCUMENT_LOADER_URL
if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None
else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
)
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = (
form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY
if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None
else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY
)
request.app.state.config.TIKA_SERVER_URL = (
form_data.TIKA_SERVER_URL
if form_data.TIKA_SERVER_URL is not None
@@ -601,6 +622,13 @@ async def update_rag_config(
if form_data.DOCLING_OCR_LANG is not None
else request.app.state.config.DOCLING_OCR_LANG
)
request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = (
form_data.DOCLING_DO_PICTURE_DESCRIPTION
if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None
else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -688,6 +716,11 @@ async def update_rag_config(
if form_data.FILE_MAX_COUNT is not None
else request.app.state.config.FILE_MAX_COUNT
)
request.app.state.config.ALLOWED_FILE_EXTENSIONS = (
form_data.ALLOWED_FILE_EXTENSIONS
if form_data.ALLOWED_FILE_EXTENSIONS is not None
else request.app.state.config.ALLOWED_FILE_EXTENSIONS
)
# Integration settings
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
@@ -720,6 +753,9 @@ async def update_rag_config(
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
)
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
@@ -809,10 +845,13 @@ async def update_rag_config(
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -828,6 +867,7 @@ async def update_rag_config(
# File upload settings
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
# Integration settings
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
@@ -840,6 +880,7 @@ async def update_rag_config(
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -1129,10 +1170,13 @@ def process_file(
file_path = Storage.get_file(file_path)
loader = Loader(
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -1640,13 +1684,29 @@ async def process_web_search(
)
try:
loader = get_web_loader(
urls,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
docs = [
Document(
page_content=result.snippet,
metadata={
"source": result.link,
"title": result.title,
"snippet": result.snippet,
"link": result.link,
},
)
for result in search_results
if hasattr(result, "snippet")
]
else:
loader = get_web_loader(
urls,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
urls = [
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
] # only keep the urls returned by the loader

View File

@@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.constants import TASKS
from open_webui.routers.pipelines import process_pipeline_inlet_filter
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
)
from open_webui.utils.task import get_task_model_id
from open_webui.config import (
@@ -195,15 +192,19 @@ async def generate_title(
},
)
max_tokens = (
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
**(
{"max_tokens": 1000}
{"max_tokens": max_tokens}
if models[task_model_id].get("owned_by") == "ollama"
else {
"max_completion_tokens": 1000,
"max_completion_tokens": max_tokens,
}
),
"metadata": {

View File

@@ -13,6 +13,8 @@ import pytz
from pytz import UTC
from typing import Optional, Union, List, Dict
from opentelemetry import trace
from open_webui.models.users import Users
from open_webui.constants import ERROR_MESSAGES
@@ -194,7 +196,17 @@ def get_current_user(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
return get_current_user_by_api_key(token)
user = get_current_user_by_api_key(token)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
return user
# auth by jwt token
try:
@@ -213,6 +225,14 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "jwt")
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
if background_tasks:
@@ -234,6 +254,14 @@ def get_current_user_by_api_key(api_key: str):
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id)
return user

View File

@@ -309,6 +309,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
metadata = {
"chat_id": data["chat_id"],
"message_id": data["id"],
"filter_ids": data.get("filter_ids", []),
"session_id": data["session_id"],
"user_id": user.id,
}
@@ -330,7 +331,9 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
result, _ = await process_filter_functions(
@@ -389,11 +392,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
}
)
if action_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)

View File

@@ -9,7 +9,18 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_sorted_filter_ids(model: dict):
def get_function_module(request, function_id):
"""
Get the function module by its ID.
"""
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
return function_module
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None:
@@ -21,14 +32,23 @@ def get_sorted_filter_ids(model: dict):
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
active_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
for filter_id in active_filter_ids:
function_module = get_function_module(request, filter_id)
if getattr(function_module, "toggle", None) and (
filter_id not in enabled_filter_ids
):
active_filter_ids.remove(filter_id)
continue
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority)
return filter_ids
@@ -43,12 +63,7 @@ async def process_filter_functions(
if not filter:
continue
if filter_id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
request.app.state.FUNCTIONS[filter_id] = function_module
function_module = get_function_module(request, filter_id)
# Prepare handler function
handler = getattr(function_module, filter_type, None)
if not handler:

View File

@@ -43,6 +43,7 @@ from open_webui.routers.pipelines import (
process_pipeline_outlet_filter,
)
from open_webui.routers.files import upload_file
from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
@@ -253,7 +254,12 @@ async def chat_completion_tools_handler(
"name": (f"TOOL:{tool_name}"),
},
"document": [tool_result],
"metadata": [{"source": (f"TOOL:{tool_name}")}],
"metadata": [
{
"source": (f"TOOL:{tool_name}"),
"parameters": tool_function_params,
}
],
}
)
else:
@@ -292,6 +298,38 @@ async def chat_completion_tools_handler(
return body, {"sources": sources}
async def chat_memory_handler(
request: Request, form_data: dict, extra_params: dict, user
):
results = await query_memory(
request,
QueryMemoryForm(
**{"content": get_last_user_message(form_data["messages"]), "k": 3}
),
user,
)
user_context = ""
if results and hasattr(results, "documents"):
if results.documents and len(results.documents) > 0:
for doc_idx, doc in enumerate(results.documents[0]):
created_at_date = "Unknown Date"
if results.metadatas[0][doc_idx].get("created_at"):
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
created_at_date = time.strftime(
"%Y-%m-%d", time.localtime(created_at_timestamp)
)
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
form_data["messages"] = add_or_update_system_message(
f"User Context:\n{user_context}\n", form_data["messages"], append=True
)
return form_data
async def chat_web_search_handler(
request: Request, form_data: dict, extra_params: dict, user
):
@@ -342,6 +380,11 @@ async def chat_web_search_handler(
log.exception(e)
queries = [user_message]
# Check if generated queries are empty
if len(queries) == 1 and queries[0].strip() == "":
queries = [user_message]
# Check if queries are not found
if len(queries) == 0:
await event_emitter(
{
@@ -653,7 +696,7 @@ def apply_params_to_form_data(form_data, model):
convert_logit_bias_input_to_json(params["logit_bias"])
)
except Exception as e:
print(f"Error parsing logit_bias: {e}")
log.exception(f"Error parsing logit_bias: {e}")
return form_data
@@ -751,9 +794,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
raise e
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
form_data, flags = await process_filter_functions(
@@ -768,6 +814,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
features = form_data.pop("features", None)
if features:
if "memory" in features and features["memory"]:
form_data = await chat_memory_handler(
request, form_data, extra_params, user
)
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
@@ -870,6 +921,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for doc_context, doc_meta in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
citation_id = (
doc_meta.get("source", None)
or source.get("source", {}).get("id", None)
@@ -877,7 +929,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
)
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
context_string += (
f'<source id="{citation_idx[citation_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{doc_context}</source>\n"
)
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -944,21 +1000,36 @@ async def process_chat_response(
message = message_map.get(metadata["message_id"]) if message_map else None
if message:
messages = get_message_list(message_map, message.get("id"))
message_list = get_message_list(message_map, message.get("id"))
# Remove reasoning details and files from the messages.
# Remove details tags and files from the messages.
# as get_message_list creates a new list, it does not affect
# the original messages outside of this handler
for message in messages:
message["content"] = re.sub(
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
"",
message["content"],
flags=re.S,
).strip()
if message.get("files"):
message["files"] = []
messages = []
for message in message_list:
content = message.get("content", "")
if isinstance(content, list):
for item in content:
if item.get("type") == "text":
content = item["text"]
break
if isinstance(content, str):
content = re.sub(
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
"",
content,
flags=re.S | re.I,
).strip()
messages.append(
{
**message,
"role": message["role"],
"content": content,
}
)
if tasks and messages:
if TASKS.TITLE_GENERATION in tasks:
@@ -1171,7 +1242,9 @@ async def process_chat_response(
}
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(model)
for filter_id in get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
# Streaming response

View File

@@ -130,7 +130,9 @@ def prepend_to_first_user_message_content(
return messages
def add_or_update_system_message(content: str, messages: list[dict]):
def add_or_update_system_message(
content: str, messages: list[dict], append: bool = False
):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
@@ -141,7 +143,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
if append:
messages[0]["content"] = f"{messages[0]['content']}\n{content}"
else:
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})

View File

@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"connection_type": model.get("connection_type", "local"),
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
@@ -110,6 +111,14 @@ async def get_all_models(request, user: UserModel = None):
for function in Functions.get_functions_by_type("action", active_only=True)
]
global_filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id is None:
@@ -125,13 +134,20 @@ async def get_all_models(request, user: UserModel = None):
model["name"] = custom_model.name
model["info"] = custom_model.model_dump()
# Set action_ids and filter_ids
action_ids = []
filter_ids = []
if "info" in model and "meta" in model["info"]:
action_ids.extend(
model["info"]["meta"].get("actionIds", [])
)
filter_ids.extend(
model["info"]["meta"].get("filterIds", [])
)
model["action_ids"] = action_ids
model["filter_ids"] = filter_ids
else:
models.remove(model)
@@ -140,7 +156,9 @@ async def get_all_models(request, user: UserModel = None):
):
owned_by = "openai"
pipe = None
action_ids = []
filter_ids = []
for model in models:
if (
@@ -154,9 +172,13 @@ async def get_all_models(request, user: UserModel = None):
if custom_model.meta:
meta = custom_model.meta.model_dump()
if "actionIds" in meta:
action_ids.extend(meta["actionIds"])
if "filterIds" in meta:
filter_ids.extend(meta["filterIds"])
models.append(
{
"id": f"{custom_model.id}",
@@ -168,6 +190,7 @@ async def get_all_models(request, user: UserModel = None):
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
"action_ids": action_ids,
"filter_ids": filter_ids,
}
)
@@ -181,8 +204,11 @@ async def get_all_models(request, user: UserModel = None):
"id": f"{function.id}.{action['id']}",
"name": action.get("name", f"{function.name} ({action['id']})"),
"description": function.meta.description,
"icon_url": action.get(
"icon_url", function.meta.manifest.get("icon_url", None)
"icon": action.get(
"icon_url",
function.meta.manifest.get("icon_url", None)
or getattr(module, "icon_url", None)
or getattr(module, "icon", None),
),
}
for action in actions
@@ -193,16 +219,28 @@ async def get_all_models(request, user: UserModel = None):
"id": function.id,
"name": function.name,
"description": function.meta.description,
"icon_url": function.meta.manifest.get("icon_url", None),
"icon": function.meta.manifest.get("icon_url", None)
or getattr(module, "icon_url", None)
or getattr(module, "icon", None),
}
]
# Process filter_ids to get the filters
def get_filter_items_from_module(function, module):
    """Describe a filter function for inclusion in a model's filter list."""
    # Icon precedence: the manifest's icon_url, then the module's own
    # icon_url / icon attributes — first truthy value wins.
    icon = function.meta.manifest.get("icon_url", None)
    if not icon:
        icon = getattr(module, "icon_url", None) or getattr(module, "icon", None)
    entry = {
        "id": function.id,
        "name": function.name,
        "description": function.meta.description,
        "icon": icon,
    }
    return [entry]
def get_function_module_by_id(function_id):
    """Load the function's module from source and refresh the app-state cache.

    The module is always reloaded (the cache is written, not consulted) so
    edits to a function take effect immediately; the fresh module is stored
    in request.app.state.FUNCTIONS for consumers that read the cache.
    """
    # The former cache-lookup branch was dead code: its result was
    # unconditionally overwritten by the reload below before being returned.
    function_module, _, _ = load_function_module_by_id(function_id)
    request.app.state.FUNCTIONS[function_id] = function_module
    return function_module
for model in models:
@@ -211,6 +249,11 @@ async def get_all_models(request, user: UserModel = None):
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
if action_id in enabled_action_ids
]
filter_ids = [
filter_id
for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
if filter_id in enabled_filter_ids
]
model["actions"] = []
for action_id in action_ids:
@@ -222,6 +265,20 @@ async def get_all_models(request, user: UserModel = None):
model["actions"].extend(
get_action_items_from_module(action_function, function_module)
)
model["filters"] = []
for filter_id in filter_ids:
filter_function = Functions.get_function_by_id(filter_id)
if filter_function is None:
raise Exception(f"Filter not found: {filter_id}")
function_module = get_function_module_by_id(filter_id)
if getattr(function_module, "toggle", None):
model["filters"].extend(
get_filter_items_from_module(filter_function, function_module)
)
log.debug(f"get_all_models() returned {len(models)} models")
request.app.state.MODELS = {model["id"]: model for model in models}

View File

@@ -41,6 +41,7 @@ from open_webui.config import (
)
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
WEBUI_NAME,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
@@ -305,8 +306,10 @@ class OAuthManager:
get_kwargs["headers"] = {
"Authorization": f"Bearer {access_token}",
}
async with aiohttp.ClientSession() as session:
async with session.get(picture_url, **get_kwargs) as resp:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as resp:
if resp.ok:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
@@ -371,7 +374,9 @@ class OAuthManager:
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
"https://api.github.com/user/emails", headers=headers
"https://api.github.com/user/emails",
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as resp:
if resp.ok:
emails = await resp.json()
@@ -531,5 +536,10 @@ class OAuthManager:
secure=WEBUI_AUTH_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
if redirect_base_url.endswith("/"):
redirect_base_url = redirect_base_url[:-1]
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
return RedirectResponse(url=redirect_url, headers=response.headers)

View File

@@ -57,6 +57,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
mappings = {
"temperature": float,
"top_p": float,
"min_p": float,
"max_tokens": int,
"frequency_penalty": float,
"presence_penalty": float,

View File

@@ -22,7 +22,7 @@ def get_task_model_id(
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id].get("owned_by") == "ollama":
if models[task_model_id].get("connection_type") == "local":
if task_model and task_model in models:
task_model_id = task_model
else:

View File

@@ -37,6 +37,7 @@ from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id
from open_webui.env import (
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
)
@@ -44,6 +45,7 @@ from open_webui.env import (
import copy
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
def get_async_tool_function_and_apply_extra_params(
@@ -158,7 +160,7 @@ def get_tools(
# TODO: Fix hack for OpenAI API
# Some times breaks OpenAI but others don't. Leaving the comment
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
if val.get("type") == "str":
val["type"] = "string"
# Remove internal reserved parameters (e.g. __id__, __user__)
@@ -477,7 +479,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
"specs": convert_openapi_to_tool_payload(res),
}
print("Fetched data:", data)
log.info("Fetched data:", data)
return data
@@ -510,7 +512,7 @@ async def get_tool_servers_data(
results = []
for (idx, server, url, _), response in zip(server_entries, responses):
if isinstance(response, Exception):
print(f"Failed to connect to {url} OpenAPI tool server")
log.error(f"Failed to connect to {url} OpenAPI tool server")
continue
results.append(
@@ -620,5 +622,5 @@ async def execute_tool_server(
except Exception as err:
error = str(err)
print("API Request Error:", error)
log.exception("API Request Error:", error)
return {"error": error}

View File

@@ -15,7 +15,7 @@ aiofiles
sqlalchemy==2.0.38
alembic==1.14.0
peewee==3.17.9
peewee==3.18.1
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
pgvector==0.4.0
@@ -37,7 +37,8 @@ asgiref==3.8.1
# AI libraries
openai
anthropic
google-generativeai==0.8.4
google-genai==1.15.0
google-generativeai==0.8.5
tiktoken
langchain==0.3.24
@@ -98,7 +99,7 @@ pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=8.0.0
duckduckgo-search==8.0.2
## Google Drive
google-api-python-client