Merge remote-tracking branch 'upstream/dev' into Individual-RAG-Config

This commit is contained in:
Maytown
2025-05-02 20:38:35 +02:00
161 changed files with 6412 additions and 2966 deletions

View File

@@ -76,7 +76,7 @@ def serve(
from open_webui.env import UVICORN_WORKERS # Import the workers setting
uvicorn.run(
open_webui.main.app,
"open_webui.main:app",
host=host,
port=port,
forwarded_allow_ips="*",
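
Note: the app is now passed as the "module:attribute" import string rather than the app object. Uvicorn requires the import string form to spawn multiple worker processes, since each worker re-imports the application in its own interpreter. A minimal sketch, with host, port, and worker count assumed for illustration:

import uvicorn

# Hedged sketch: the import string is required when workers > 1.
if __name__ == "__main__":
    uvicorn.run(
        "open_webui.main:app",  # import string, not the app object
        host="0.0.0.0",         # assumed for illustration
        port=8080,
        workers=4,              # e.g. from UVICORN_WORKERS
    )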

View File

@@ -509,6 +509,12 @@ ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
)
ENABLE_OAUTH_GROUP_CREATION = PersistentConfig(
"ENABLE_OAUTH_GROUP_CREATION",
"oauth.enable_group_creation",
os.environ.get("ENABLE_OAUTH_GROUP_CREATION", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
@@ -952,10 +958,15 @@ DEFAULT_MODELS = PersistentConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
"DEFAULT_PROMPT_SUGGESTIONS",
"ui.prompt_suggestions",
[
try:
default_prompt_suggestions = json.loads(
os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")
)
except Exception as e:
log.exception(f"Error loading DEFAULT_PROMPT_SUGGESTIONS: {e}")
default_prompt_suggestions = []
if default_prompt_suggestions == []:
default_prompt_suggestions = [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@@ -983,7 +994,11 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
],
]
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
"DEFAULT_PROMPT_SUGGESTIONS",
"ui.prompt_suggestions",
default_prompt_suggestions,
)
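
Note: the suggestions are now parsed from the environment as JSON before the PersistentConfig is built, with the hard-coded list kept as a fallback on parse errors or an empty value. The same pattern, reduced to a reusable sketch (the helper name is hypothetical):

import json
import logging
import os

log = logging.getLogger(__name__)

def json_list_env(name: str, fallback: list) -> list:
    """Parse an env var as a JSON list, falling back on error or empty input."""
    try:
        value = json.loads(os.environ.get(name, "[]"))
    except Exception as e:
        log.exception(f"Error loading {name}: {e}")
        value = []
    return value if value else fallback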
MODEL_ORDER_LIST = PersistentConfig(
@@ -1062,6 +1077,14 @@ USER_PERMISSIONS_CHAT_EDIT = (
os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_SHARE = (
os.environ.get("USER_PERMISSIONS_CHAT_SHARE", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_EXPORT = (
os.environ.get("USER_PERMISSIONS_CHAT_EXPORT", "True").lower() == "true"
)
USER_PERMISSIONS_CHAT_STT = (
os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true"
)
@@ -1126,6 +1149,8 @@ DEFAULT_USER_PERMISSIONS = {
"file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
"delete": USER_PERMISSIONS_CHAT_DELETE,
"edit": USER_PERMISSIONS_CHAT_EDIT,
"share": USER_PERMISSIONS_CHAT_SHARE,
"export": USER_PERMISSIONS_CHAT_EXPORT,
"stt": USER_PERMISSIONS_CHAT_STT,
"tts": USER_PERMISSIONS_CHAT_TTS,
"call": USER_PERMISSIONS_CHAT_CALL,
@@ -1153,6 +1178,11 @@ ENABLE_CHANNELS = PersistentConfig(
os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
)
ENABLE_NOTES = PersistentConfig(
"ENABLE_NOTES",
"notes.enable",
os.environ.get("ENABLE_NOTES", "True").lower() == "true",
)
ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
"ENABLE_EVALUATION_ARENA_MODELS",
@@ -1203,6 +1233,9 @@ ENABLE_USER_WEBHOOKS = PersistentConfig(
os.environ.get("ENABLE_USER_WEBHOOKS", "True").lower() == "true",
)
# FastAPI / AnyIO settings
THREAD_POOL_SIZE = int(os.getenv("THREAD_POOL_SIZE", "0"))
def validate_cors_origins(origins):
for origin in origins:
@@ -1229,7 +1262,9 @@ def validate_cors_origin(origin):
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
CORS_ALLOW_ORIGIN = os.environ.get(
"CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080"
).split(";")
if "*" in CORS_ALLOW_ORIGIN:
log.warning(
@@ -1693,6 +1728,9 @@ MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)
# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
@@ -1724,6 +1762,14 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)
# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "open-webui-index")
PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536)) # or 3072, 1024, 768
PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine")
PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws") # or "gcp" or "azure"
####################################
# Information Retrieval (RAG)
####################################
@@ -1760,6 +1806,13 @@ ONEDRIVE_CLIENT_ID = PersistentConfig(
os.environ.get("ONEDRIVE_CLIENT_ID", ""),
)
ONEDRIVE_SHAREPOINT_URL = PersistentConfig(
"ONEDRIVE_SHAREPOINT_URL",
"onedrive.sharepoint_url",
os.environ.get("ONEDRIVE_SHAREPOINT_URL", ""),
)
# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
@@ -2092,6 +2145,24 @@ SEARXNG_QUERY_URL = PersistentConfig(
os.getenv("SEARXNG_QUERY_URL", ""),
)
YACY_QUERY_URL = PersistentConfig(
"YACY_QUERY_URL",
"rag.web.search.yacy_query_url",
os.getenv("YACY_QUERY_URL", ""),
)
YACY_USERNAME = PersistentConfig(
"YACY_USERNAME",
"rag.web.search.yacy_username",
os.getenv("YACY_USERNAME", ""),
)
YACY_PASSWORD = PersistentConfig(
"YACY_PASSWORD",
"rag.web.search.yacy_password",
os.getenv("YACY_PASSWORD", ""),
)
GOOGLE_PSE_API_KEY = PersistentConfig(
"GOOGLE_PSE_API_KEY",
"rag.web.search.google_pse_api_key",
@@ -2256,6 +2327,29 @@ FIRECRAWL_API_BASE_URL = PersistentConfig(
os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
)
EXTERNAL_WEB_SEARCH_URL = PersistentConfig(
"EXTERNAL_WEB_SEARCH_URL",
"rag.web.search.external_web_search_url",
os.environ.get("EXTERNAL_WEB_SEARCH_URL", ""),
)
EXTERNAL_WEB_SEARCH_API_KEY = PersistentConfig(
"EXTERNAL_WEB_SEARCH_API_KEY",
"rag.web.search.external_web_search_api_key",
os.environ.get("EXTERNAL_WEB_SEARCH_API_KEY", ""),
)
EXTERNAL_WEB_LOADER_URL = PersistentConfig(
"EXTERNAL_WEB_LOADER_URL",
"rag.web.loader.external_web_loader_url",
os.environ.get("EXTERNAL_WEB_LOADER_URL", ""),
)
EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
"EXTERNAL_WEB_LOADER_API_KEY",
"rag.web.loader.external_web_loader_api_key",
os.environ.get("EXTERNAL_WEB_LOADER_API_KEY", ""),
)
####################################
# Images
@@ -2566,6 +2660,18 @@ AUDIO_STT_AZURE_LOCALES = PersistentConfig(
os.getenv("AUDIO_STT_AZURE_LOCALES", ""),
)
AUDIO_STT_AZURE_BASE_URL = PersistentConfig(
"AUDIO_STT_AZURE_BASE_URL",
"audio.stt.azure.base_url",
os.getenv("AUDIO_STT_AZURE_BASE_URL", ""),
)
AUDIO_STT_AZURE_MAX_SPEAKERS = PersistentConfig(
"AUDIO_STT_AZURE_MAX_SPEAKERS",
"audio.stt.azure.max_speakers",
os.getenv("AUDIO_STT_AZURE_MAX_SPEAKERS", "3"),
)
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",

View File

@@ -354,6 +354,10 @@ BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)
WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get(
"WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None
)
####################################
# WEBUI_SECRET_KEY
####################################
@@ -409,6 +413,11 @@ else:
except Exception:
AIOHTTP_CLIENT_TIMEOUT = 300
AIOHTTP_CLIENT_SESSION_SSL = (
os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true"
)
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
"AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),
@@ -437,6 +446,56 @@ else:
except Exception:
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
)
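
Note: AIOHTTP_CLIENT_SESSION_SSL and its tool-server counterpart gate certificate verification on outbound aiohttp calls. A minimal sketch of how such a flag is typically applied (URL and usage assumed for illustration):

import aiohttp

async def fetch_json(url: str, verify_ssl: bool = True) -> dict:
    # ssl=False skips certificate verification, e.g. for self-signed
    # internal endpoints; ssl=True keeps the default behavior.
    async with aiohttp.ClientSession() as session:
        async with session.get(url, ssl=verify_ssl) as resp:
            resp.raise_for_status()
            return await resp.json()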
####################################
# SENTENCE TRANSFORMERS
####################################
SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
SENTENCE_TRANSFORMERS_BACKEND = "torch"
SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
"SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
try:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
)
except Exception:
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
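
Note: these settings select the inference backend and forward extra model kwargs to the embedding and cross-encoder models. A hedged sketch of how they could reach SentenceTransformer (model name and kwargs are illustrative; the backend parameter assumes sentence-transformers >= 3.2):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",   # illustrative model
    backend="torch",                            # SENTENCE_TRANSFORMERS_BACKEND
    model_kwargs={"torch_dtype": "float16"},    # parsed env JSON, assumed value
)
embeddings = model.encode(["hello world"])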
####################################
# OFFLINE_MODE
####################################
@@ -446,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"
if OFFLINE_MODE:
os.environ["HF_HUB_OFFLINE"] = "1"
####################################
# AUDIT LOGGING
####################################
@@ -467,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
# OPENTELEMETRY
####################################

View File

@@ -17,6 +17,7 @@ from sqlalchemy import text
from typing import Optional
from aiocache import cached
import aiohttp
import anyio.to_thread
import requests
@@ -100,11 +101,14 @@ from open_webui.config import (
# OpenAI
ENABLE_OPENAI_API,
ONEDRIVE_CLIENT_ID,
ONEDRIVE_SHAREPOINT_URL,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
# Direct Connections
ENABLE_DIRECT_CONNECTIONS,
# Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE,
# Tool Server Configs
TOOL_SERVER_CONNECTIONS,
# Code Execution
@@ -151,6 +155,8 @@ from open_webui.config import (
AUDIO_STT_AZURE_API_KEY,
AUDIO_STT_AZURE_REGION,
AUDIO_STT_AZURE_LOCALES,
AUDIO_STT_AZURE_BASE_URL,
AUDIO_STT_AZURE_MAX_SPEAKERS,
AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
@@ -219,6 +225,9 @@ from open_webui.config import (
SERPAPI_API_KEY,
SERPAPI_ENGINE,
SEARXNG_QUERY_URL,
YACY_QUERY_URL,
YACY_USERNAME,
YACY_PASSWORD,
SERPER_API_KEY,
SERPLY_API_KEY,
SERPSTACK_API_KEY,
@@ -240,12 +249,17 @@ from open_webui.config import (
GOOGLE_DRIVE_CLIENT_ID,
GOOGLE_DRIVE_API_KEY,
ONEDRIVE_CLIENT_ID,
ONEDRIVE_SHAREPOINT_URL,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_WEB_LOADER_SSL_VERIFICATION,
ENABLE_GOOGLE_DRIVE_INTEGRATION,
ENABLE_ONEDRIVE_INTEGRATION,
UPLOAD_DIR,
EXTERNAL_WEB_SEARCH_URL,
EXTERNAL_WEB_SEARCH_API_KEY,
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
# WebUI
WEBUI_AUTH,
WEBUI_NAME,
@@ -260,6 +274,7 @@ from open_webui.config import (
ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
API_KEY_ALLOWED_ENDPOINTS,
ENABLE_CHANNELS,
ENABLE_NOTES,
ENABLE_COMMUNITY_SHARING,
ENABLE_MESSAGE_RATING,
ENABLE_USER_WEBHOOKS,
@@ -341,6 +356,7 @@ from open_webui.env import (
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
ENABLE_WEBSOCKET_SUPPORT,
BYPASS_MODEL_ACCESS_CONTROL,
RESET_CONFIG_ON_START,
@@ -370,6 +386,7 @@ from open_webui.utils.auth import (
get_admin_user,
get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
@@ -432,7 +449,18 @@ async def lifespan(app: FastAPI):
if LICENSE_KEY:
get_license_data(app, LICENSE_KEY)
    # This should be blocking (sync) so functions are not deactivated on the first
    # /get_models call when the first user lands on the / route.
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
pool_size = THREAD_POOL_SIZE
if pool_size and pool_size > 0:
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = pool_size
asyncio.create_task(periodic_usage_pool_cleanup())
yield
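
Note: THREAD_POOL_SIZE resizes the AnyIO thread limiter that FastAPI/Starlette use for sync routes and run_in_threadpool; the AnyIO default is 40 threads. The operative lines, isolated as a sketch:

import anyio.to_thread

async def apply_thread_pool_size(pool_size: int) -> None:
    # Raising total_tokens allows more sync handlers to run concurrently.
    if pool_size and pool_size > 0:
        limiter = anyio.to_thread.current_default_thread_limiter()
        limiter.total_tokens = pool_size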
@@ -543,6 +571,7 @@ app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
app.state.config.ENABLE_NOTES = ENABLE_NOTES
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_USER_WEBHOOKS = ENABLE_USER_WEBHOOKS
@@ -576,6 +605,7 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL
app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL
app.state.USER_COUNT = None
@@ -646,6 +676,9 @@ app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
app.state.config.YACY_USERNAME = YACY_USERNAME
app.state.config.YACY_PASSWORD = YACY_PASSWORD
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
@@ -668,6 +701,10 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY
app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL
@@ -796,6 +833,8 @@ app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
app.state.config.AUDIO_STT_AZURE_REGION = AUDIO_STT_AZURE_REGION
app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
@@ -869,7 +908,8 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Check for the specific watch path and the presence of 'v' parameter
if path.endswith("/watch") and "v" in query_params:
video_id = query_params["v"][0] # Extract the first 'v' parameter
# Extract the first 'v' parameter
video_id = query_params["v"][0]
encoded_video_id = urlencode({"youtube": video_id})
redirect_url = f"/?{encoded_video_id}"
return RedirectResponse(url=redirect_url)
@@ -1283,6 +1323,7 @@ async def get_app_config(request: Request):
{
"enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
"enable_channels": app.state.config.ENABLE_CHANNELS,
"enable_notes": app.state.config.ENABLE_NOTES,
"enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
"enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION,
"enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,
@@ -1327,7 +1368,10 @@ async def get_app_config(request: Request):
"client_id": GOOGLE_DRIVE_CLIENT_ID.value,
"api_key": GOOGLE_DRIVE_API_KEY.value,
},
"onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
"onedrive": {
"client_id": ONEDRIVE_CLIENT_ID.value,
"sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
},
"license_metadata": app.state.LICENSE_METADATA,
**(
{
@@ -1439,7 +1483,7 @@ async def get_manifest_json():
"start_url": "/",
"display": "standalone",
"background_color": "#343541",
"orientation": "natural",
"orientation": "any",
"icons": [
{
"src": "/static/logo.png",

View File

@@ -10,6 +10,8 @@ from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import or_
####################
# User DB Schema
@@ -67,6 +69,11 @@ class UserModel(BaseModel):
####################
class UserListResponse(BaseModel):
users: list[UserModel]
total: int
class UserResponse(BaseModel):
id: str
name: str
@@ -160,11 +167,63 @@ class UsersTable:
return None
def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[UserModel]:
self,
filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> UserListResponse:
with get_db() as db:
query = db.query(User)
query = db.query(User).order_by(User.created_at.desc())
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
)
)
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "email":
if direction == "asc":
query = query.order_by(User.email.asc())
else:
query = query.order_by(User.email.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(User.created_at.asc())
else:
query = query.order_by(User.created_at.desc())
elif order_by == "last_active_at":
if direction == "asc":
query = query.order_by(User.last_active_at.asc())
else:
query = query.order_by(User.last_active_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(User.updated_at.asc())
else:
query = query.order_by(User.updated_at.desc())
elif order_by == "role":
if direction == "asc":
query = query.order_by(User.role.asc())
else:
query = query.order_by(User.role.desc())
else:
query = query.order_by(User.created_at.desc())
if skip:
query = query.offset(skip)
@@ -172,8 +231,10 @@ class UsersTable:
query = query.limit(limit)
users = query.all()
return [UserModel.model_validate(user) for user in users]
return {
"users": [UserModel.model_validate(user) for user in users],
"total": db.query(User).count(),
}
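
Note: get_users now returns a dict carrying the page of users plus the total user count, and accepts a filter for search and ordering. A hypothetical call for the second page of a name/email search, most recently active first:

result = Users.get_users(
    filter={"query": "ana", "order_by": "last_active_at", "direction": "desc"},
    skip=50,
    limit=50,
)
print(result["total"], [user.email for user in result["users"]])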
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:

View File

@@ -0,0 +1,53 @@
import requests
import logging
from typing import Iterator, List, Union
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalLoader(BaseLoader):
def __init__(
self,
web_paths: Union[str, List[str]],
external_url: str,
external_api_key: str,
continue_on_failure: bool = True,
**kwargs,
) -> None:
self.external_url = external_url
self.external_api_key = external_api_key
self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
self.continue_on_failure = continue_on_failure
def lazy_load(self) -> Iterator[Document]:
batch_size = 20
for i in range(0, len(self.urls), batch_size):
urls = self.urls[i : i + batch_size]
try:
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {self.external_api_key}",
},
json={
"urls": urls,
},
)
response.raise_for_status()
results = response.json()
for result in results:
yield Document(
page_content=result.get("page_content", ""),
metadata=result.get("metadata", {}),
)
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}")
else:
raise e
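
Note: a hypothetical invocation of the new loader. The service URL and key are placeholders, and the endpoint is assumed to return a JSON list of {page_content, metadata} objects, as the code above expects:

loader = ExternalLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    external_url="https://extractor.internal.example/load",  # placeholder
    external_api_key="sk-placeholder",
)
for doc in loader.lazy_load():
    print(doc.metadata.get("source"), len(doc.page_content))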

View File

@@ -207,7 +207,7 @@ def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
for distance, document, metadata in zip(distances, documents, metadatas):
if isinstance(document, str):
doc_hash = hashlib.md5(
doc_hash = hashlib.sha256(
document.encode()
).hexdigest() # Compute a hash for uniqueness
@@ -260,23 +260,47 @@ def query_collection(
k: int,
) -> dict:
results = []
for query in queries:
log.debug(f"query_collection:query {query}")
query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
for collection_name in collection_names:
error = False
def process_query_collection(collection_name, query_embedding):
try:
if collection_name:
try:
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
results.append(result.model_dump())
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
else:
pass
result = query_doc(
collection_name=collection_name,
k=k,
query_embedding=query_embedding,
)
if result is not None:
return result.model_dump(), None
return None, None
except Exception as e:
log.exception(f"Error when querying the collection: {e}")
return None, e
# Generate all query embeddings (in one call)
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
)
with ThreadPoolExecutor() as executor:
future_results = []
for query_embedding in query_embeddings:
for collection_name in collection_names:
result = executor.submit(
process_query_collection, collection_name, query_embedding
)
future_results.append(result)
task_results = [future.result() for future in future_results]
for result, err in task_results:
if err is not None:
error = True
elif result is not None:
results.append(result)
if error and not results:
log.warning("All collection queries failed. No results returned.")
return merge_and_sort_query_results(results, k=k)
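
Note: the refactor batches all query embeddings into one embedding_function call, then fans the (embedding, collection) pairs out over a thread pool. The pattern, reduced to a sketch with assumed helper names:

from concurrent.futures import ThreadPoolExecutor

def fan_out_queries(queries, collection_names, embed, search_one):
    # One batched embedding call instead of one call per query.
    embeddings = embed(queries)
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(search_one, name, embedding)
            for embedding in embeddings
            for name in collection_names
        ]
        return [future.result() for future in futures]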

View File

@@ -20,6 +20,10 @@ elif VECTOR_DB == "elasticsearch":
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
VECTOR_DB_CLIENT = ElasticsearchClient()
elif VECTOR_DB == "pinecone":
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
VECTOR_DB_CLIENT = PineconeClient()
else:
from open_webui.retrieval.vector.dbs.chroma import ChromaClient

View File

@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ChromaClient:
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,

View File

@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
)
class ElasticsearchClient:
class ElasticsearchClient(VectorDBBase):
"""
Important:
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

View File

@@ -4,7 +4,12 @@ import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class MilvusClient:
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
if MILVUS_TOKEN is None:

View File

@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
)
class OpenSearchClient:
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.client = OpenSearch(

View File

@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient:
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
@@ -136,9 +141,8 @@ class PgvectorClient:
# Pad the vector with zeros
vector += [0.0] * (VECTOR_LENGTH - current_length)
elif current_length > VECTOR_LENGTH:
raise Exception(
f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
)
# Truncate the vector to VECTOR_LENGTH
vector = vector[:VECTOR_LENGTH]
return vector
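
Note: over-long vectors are now truncated to the index dimension instead of raising. The adjusted rule, isolated as a sketch (VECTOR_LENGTH value assumed):

VECTOR_LENGTH = 1536  # assumed index dimension

def adjust_vector_length(vector: list[float]) -> list[float]:
    if len(vector) < VECTOR_LENGTH:
        vector = vector + [0.0] * (VECTOR_LENGTH - len(vector))  # zero-pad
    return vector[:VECTOR_LENGTH]  # truncate rather than raise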
def insert(self, collection_name: str, items: List[VectorItem]) -> None:

View File

@@ -0,0 +1,412 @@
from typing import Optional, List, Dict, Any, Union
import logging
from pinecone import Pinecone, ServerlessSpec
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_ENVIRONMENT,
PINECONE_INDEX_NAME,
PINECONE_DIMENSION,
PINECONE_METRIC,
PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
# Validate required configuration
self._validate_config()
# Store configuration values
self.api_key = PINECONE_API_KEY
self.environment = PINECONE_ENVIRONMENT
self.index_name = PINECONE_INDEX_NAME
self.dimension = PINECONE_DIMENSION
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client
self.client = Pinecone(api_key=self.api_key)
# Create index if it doesn't exist
self._initialize_index()
def _validate_config(self) -> None:
"""Validate that all required configuration variables are set."""
missing_vars = []
if not PINECONE_API_KEY:
missing_vars.append("PINECONE_API_KEY")
if not PINECONE_ENVIRONMENT:
missing_vars.append("PINECONE_ENVIRONMENT")
if not PINECONE_INDEX_NAME:
missing_vars.append("PINECONE_INDEX_NAME")
if not PINECONE_DIMENSION:
missing_vars.append("PINECONE_DIMENSION")
if not PINECONE_CLOUD:
missing_vars.append("PINECONE_CLOUD")
if missing_vars:
raise ValueError(
f"Required configuration missing: {', '.join(missing_vars)}"
)
def _initialize_index(self) -> None:
"""Initialize the Pinecone index."""
try:
# Check if index exists
if self.index_name not in self.client.list_indexes().names():
log.info(f"Creating Pinecone index '{self.index_name}'...")
self.client.create_index(
name=self.index_name,
dimension=self.dimension,
metric=self.metric,
spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
)
log.info(f"Successfully created Pinecone index '{self.index_name}'")
else:
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(self.index_name)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
"""Convert VectorItem objects to Pinecone point format."""
points = []
for item in items:
# Start with any existing metadata or an empty dict
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
# Add text to metadata if available
if "text" in item:
metadata["text"] = item["text"]
# Always add collection_name to metadata for filtering
metadata["collection_name"] = collection_name_with_prefix
point = {
"id": item["id"],
"values": item["vector"],
"metadata": metadata,
}
points.append(point)
return points
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
"""Get the collection name with prefix."""
return f"{self.collection_prefix}_{collection_name}"
def _normalize_distance(self, score: float) -> float:
"""Normalize distance score based on the metric used."""
if self.metric.lower() == "cosine":
# Cosine similarity ranges from -1 to 1, normalize to 0 to 1
return (score + 1.0) / 2.0
elif self.metric.lower() in ["euclidean", "dotproduct"]:
# These are already suitable for ranking (smaller is better for Euclidean)
return score
else:
# For other metrics, use as is
return score
def _result_to_get_result(self, matches: list) -> GetResult:
"""Convert Pinecone matches to GetResult format."""
ids = []
documents = []
metadatas = []
for match in matches:
metadata = match.get("metadata", {})
ids.append(match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def has_collection(self, collection_name: str) -> bool:
"""Check if a collection exists by searching for at least one item."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Search for at least 1 item with this collection name in metadata
response = self.index.query(
vector=[0.0] * self.dimension, # dummy vector
top_k=1,
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
return len(response.matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
)
return False
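
Note: Pinecone has no native collections, so the client scopes every vector with a collection_name metadata field and probes existence with a zero ("dummy") vector query. A sketch of that probe against the raw SDK (index name, dimension, and collection are assumed):

from pinecone import Pinecone

pc = Pinecone(api_key="pc-placeholder")      # PINECONE_API_KEY
index = pc.Index("open-webui-index")         # PINECONE_INDEX_NAME
dummy = [0.0] * 1536                         # PINECONE_DIMENSION

response = index.query(
    vector=dummy,
    top_k=1,
    filter={"collection_name": "open-webui_docs"},  # hypothetical collection
    include_metadata=False,
)
print(len(response.matches) > 0)  # True if the "collection" has any vectors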
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection by removing all vectors with the collection name in metadata."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
self.index.delete(filter={"collection_name": collection_name_with_prefix})
log.info(
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
)
except Exception as e:
log.warning(
f"Failed to delete collection '{collection_name_with_prefix}': {e}"
)
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert vectors into a collection."""
if not items:
log.warning("No items to insert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Insert in batches for better performance and reliability
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
try:
self.index.upsert(vectors=batch)
log.debug(
f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
)
except Exception as e:
log.error(
f"Error inserting batch into '{collection_name_with_prefix}': {e}"
)
raise
log.info(
f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Upsert (insert or update) vectors into a collection."""
if not items:
log.warning("No items to upsert")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Upsert in batches
for i in range(0, len(points), BATCH_SIZE):
batch = points[i : i + BATCH_SIZE]
try:
self.index.upsert(vectors=batch)
log.debug(
f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
)
except Exception as e:
log.error(
f"Error upserting batch into '{collection_name_with_prefix}': {e}"
)
raise
log.info(
f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
if not vectors or not vectors[0]:
log.warning("No vectors provided for search")
return None
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Search using the first vector (assuming this is the intended behavior)
query_vector = vectors[0]
# Perform the search
query_response = self.index.query(
vector=query_vector,
top_k=limit,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
if not query_response.matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
documents=[[]],
metadatas=[[]],
distances=[[]],
)
# Convert to GetResult format
get_result = self._result_to_get_result(query_response.matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(match.score)
for match in query_response.matches
]
]
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=distances,
)
except Exception as e:
log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors by metadata filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
if limit is None or limit <= 0:
limit = NO_LIMIT
try:
# Create a zero vector for the dimension as Pinecone requires a vector
zero_vector = [0.0] * self.dimension
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Perform metadata-only query
query_response = self.index.query(
vector=zero_vector,
filter=pinecone_filter,
top_k=limit,
include_metadata=True,
)
return self._result_to_get_result(query_response.matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
return None
def get(self, collection_name: str) -> Optional[GetResult]:
"""Get all vectors in a collection."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
# Use a zero vector for fetching all entries
zero_vector = [0.0] * self.dimension
# Add filter to only get vectors for this collection
query_response = self.index.query(
vector=zero_vector,
top_k=NO_LIMIT,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
)
return self._result_to_get_result(query_response.matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
return None
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by IDs or filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
try:
if ids:
# Delete by IDs (in batches for large deletions)
for i in range(0, len(ids), BATCH_SIZE):
batch_ids = ids[i : i + BATCH_SIZE]
# Note: When deleting by ID, we can't filter by collection_name
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
)
elif filter:
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Delete by metadata filter
self.index.delete(filter=pinecone_filter)
log.info(
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
)
else:
log.warning("No ids or filter provided for delete operation")
except Exception as e:
log.error(f"Error deleting from collection '{collection_name}': {e}")
raise
def reset(self) -> None:
"""Reset the database by deleting all collections."""
try:
self.index.delete(delete_all=True)
log.info("All vectors successfully deleted from the index.")
except Exception as e:
log.error(f"Failed to reset Pinecone index: {e}")
raise

View File

@@ -1,12 +1,24 @@
from typing import Optional
import logging
from urllib.parse import urlparse
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import QDRANT_URI, QDRANT_API_KEY
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
)
from open_webui.env import SRC_LOG_LEVELS
NO_LIMIT = 999999999
@@ -15,16 +27,34 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class QdrantClient:
class QdrantClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.QDRANT_URI = QDRANT_URI
self.QDRANT_API_KEY = QDRANT_API_KEY
self.client = (
Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
if self.QDRANT_URI
else None
)
self.QDRANT_ON_DISK = QDRANT_ON_DISK
self.PREFER_GRPC = QDRANT_PREFER_GRPC
self.GRPC_PORT = QDRANT_GRPC_PORT
if not self.QDRANT_URI:
self.client = None
return
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC:
self.client = Qclient(
host=host,
port=http_port,
grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY,
)
else:
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
def _result_to_get_result(self, points) -> GetResult:
ids = []
@@ -50,7 +80,9 @@ class QdrantClient:
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
size=dimension,
distance=models.Distance.COSINE,
on_disk=self.QDRANT_ON_DISK,
),
)
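
Note: with QDRANT_PREFER_GRPC the client now connects over gRPC (default port 6334) instead of REST, and QDRANT_ON_DISK stores vectors on disk rather than in memory. A sketch of the gRPC path with an assumed URI:

from urllib.parse import urlparse
from qdrant_client import QdrantClient

uri = "http://qdrant.internal.example:6333"  # assumed QDRANT_URI
parsed = urlparse(uri)

client = QdrantClient(
    host=parsed.hostname or uri,
    port=parsed.port or 6333,  # REST port
    grpc_port=6334,            # QDRANT_GRPC_PORT
    prefer_grpc=True,          # QDRANT_PREFER_GRPC
    api_key=None,              # QDRANT_API_KEY
)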

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel
from typing import Optional, List, Any
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):
class SearchResult(GetResult):
distances: Optional[List[List[float | int]]]
class VectorDBBase(ABC):
"""
Abstract base class for all vector database backends.
Implementations of this class provide methods for collection management,
vector insertion, deletion, similarity search, and metadata filtering.
Any custom vector database integration must inherit from this class and
implement all abstract methods.
"""
@abstractmethod
def has_collection(self, collection_name: str) -> bool:
"""Check if the collection exists in the vector DB."""
pass
@abstractmethod
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection from the vector DB."""
pass
@abstractmethod
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert a list of vector items into a collection."""
pass
@abstractmethod
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert or update vector items in a collection."""
pass
@abstractmethod
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
pass
@abstractmethod
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
"""Query vectors from a collection using metadata filter."""
pass
@abstractmethod
def get(self, collection_name: str) -> Optional[GetResult]:
"""Retrieve all vectors from a collection."""
pass
@abstractmethod
def delete(
self,
collection_name: str,
ids: Optional[List[str]] = None,
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by ID or filter from a collection."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the vector database by removing all collections or those matching a condition."""
pass
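
Note: any new backend must subclass VectorDBBase and implement all nine methods; the connector then instantiates it based on VECTOR_DB. A deliberately naive in-memory sketch of the required surface (not a real backend; items are handled as dicts with "id", "text", and "metadata" keys, mirroring the built-in clients):

from typing import Dict, List, Optional, Union

class InMemoryClient(VectorDBBase):
    """Toy illustration of the abstract surface; not suitable for production."""

    def __init__(self) -> None:
        self.collections: Dict[str, Dict[str, dict]] = {}

    def has_collection(self, collection_name: str) -> bool:
        return collection_name in self.collections

    def delete_collection(self, collection_name: str) -> None:
        self.collections.pop(collection_name, None)

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        self.upsert(collection_name, items)

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        bucket = self.collections.setdefault(collection_name, {})
        for item in items:
            bucket[item["id"]] = item

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        return None  # similarity scoring elided in this sketch

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        return None  # metadata filtering elided in this sketch

    def get(self, collection_name: str) -> Optional[GetResult]:
        items = list(self.collections.get(collection_name, {}).values())
        return GetResult(
            ids=[[i["id"] for i in items]],
            documents=[[i.get("text", "") for i in items]],
            metadatas=[[i.get("metadata", {}) for i in items]],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        bucket = self.collections.get(collection_name, {})
        for id_ in ids or []:  # filter-based deletion elided in this sketch
            bucket.pop(id_, None)

    def reset(self) -> None:
        self.collections.clear()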

View File

@@ -0,0 +1,47 @@
import logging
from typing import Optional, List
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_external(
external_url: str,
external_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
response = requests.post(
external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
},
json={
"query": query,
"count": count,
},
)
response.raise_for_status()
results = response.json()
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("link"),
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []
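
Note: the contract assumed by search_external is that the endpoint receives {query, count} with a Bearer token and returns a JSON list of objects carrying link, title, and snippet keys. A hypothetical call (URL and key are placeholders):

results = search_external(
    external_url="https://search.internal.example/query",  # placeholder
    external_api_key="sk-placeholder",
    query="vector databases",
    count=3,
)
for result in results:
    print(result.link, result.title)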

View File

@@ -0,0 +1,49 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_firecrawl(
firecrawl_url: str,
firecrawl_api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
try:
firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
response = requests.post(
firecrawl_search_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {firecrawl_api_key}",
},
json={
"query": query,
"limit": count,
},
)
response.raise_for_status()
results = response.json().get("data", [])
if filter_list:
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
return results
except Exception as e:
log.error(f"Error in External search: {e}")
return []

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
import requests
from open_webui.retrieval.web.main import SearchResult
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -21,18 +21,25 @@ def search_tavily(
Args:
api_key (str): A Tavily Search API key
query (str): The query to search for
count (int): The maximum number of results to return
Returns:
list[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
data = {"query": query, "max_results": count}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
json_response = response.json()
raw_search_results = json_response.get("results", [])
results = json_response.get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
@@ -40,5 +47,5 @@ def search_tavily(
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in raw_search_results[:count]
for result in results
]
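
Note: the Tavily request now sends the API key as a Bearer header and caps results with max_results in the body, instead of posting the key in the JSON payload. The updated request shape, isolated (the key is a placeholder):

import requests

response = requests.post(
    "https://api.tavily.com/search",
    headers={
        "Content-Type": "application/json",
        "Authorization": "Bearer tvly-placeholder",
    },
    json={"query": "open webui", "max_results": 3},
)
response.raise_for_status()
print([item["title"] for item in response.json().get("results", [])])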

View File

@@ -25,6 +25,7 @@ from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external import ExternalLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -35,6 +36,8 @@ from open_webui.config import (
FIRECRAWL_API_KEY,
TAVILY_API_KEY,
TAVILY_EXTRACT_DEPTH,
EXTERNAL_WEB_LOADER_URL,
EXTERNAL_WEB_LOADER_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -167,7 +170,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
continue_on_failure: bool = True,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
mode: Literal["crawl", "scrape", "map"] = "crawl",
mode: Literal["crawl", "scrape", "map"] = "scrape",
proxy: Optional[Dict[str, str]] = None,
params: Optional[Dict] = None,
):
@@ -225,7 +228,10 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
mode=self.mode,
params=self.params,
)
yield from loader.lazy_load()
for document in loader.lazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
@@ -245,6 +251,8 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
params=self.params,
)
async for document in loader.alazy_load():
if not document.metadata.get("source"):
document.metadata["source"] = document.metadata.get("sourceURL")
yield document
except Exception as e:
if self.continue_on_failure:
@@ -619,6 +627,11 @@ def get_web_loader(
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
if WEB_LOADER_ENGINE.value == "external":
WebLoaderClass = ExternalLoader
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
if WebLoaderClass:
web_loader = WebLoaderClass(**web_loader_args)

View File

@@ -0,0 +1,85 @@
import logging
from typing import Optional
import requests
from requests.auth import HTTPDigestAuth
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_yacy(
query_url: str,
username: Optional[str],
password: Optional[str],
query: str,
count: int,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""
    Search a YaCy instance for a given query and return the results as a list of SearchResult objects.
    The function accepts an optional username and password for authenticating to YaCy.
    Args:
        query_url (str): The base URL of the YaCy server.
        username (str): Optional YaCy username.
        password (str): Optional YaCy password.
        query (str): The search term or question to find in the YaCy database.
        count (int): The maximum number of results to retrieve from the search.
    Returns:
        list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
    Raises:
        requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Use authentication if either username or password is set
yacy_auth = None
if username or password:
yacy_auth = HTTPDigestAuth(username, password)
params = {
"query": query,
"contentdom": "text",
"resource": "global",
"maximumRecords": count,
"nav": "none",
}
    # Check whether a JSON API URL was provided
    if not query_url.endswith("yacysearch.json"):
        # Strip any trailing slash and append the JSON API endpoint
        query_url = query_url.rstrip("/") + "/yacysearch.json"
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
auth=yacy_auth,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("channels", [{}])[0].get("items", [])
sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["link"], title=result.get("title"), snippet=result.get("description")
)
for result in sorted_results[:count]
]
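
Note: a hypothetical call against a self-hosted YaCy node; /yacysearch.json is appended automatically, and the digest credentials are optional:

results = search_yacy(
    query_url="http://localhost:8090",  # base URL of the YaCy server
    username="admin",                   # optional HTTP digest credentials
    password="placeholder",
    query="open source search",
    count=5,
)
for result in results:
    print(result.link, result.title)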

View File

@@ -150,7 +150,8 @@ class STTConfigForm(BaseModel):
AZURE_API_KEY: str
AZURE_REGION: str
AZURE_LOCALES: str
AZURE_BASE_URL: str
AZURE_MAX_SPEAKERS: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
@@ -181,6 +182,8 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
},
}
@@ -210,6 +213,8 @@ async def update_audio_config(
request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = form_data.stt.AZURE_MAX_SPEAKERS
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -238,6 +243,8 @@ async def update_audio_config(
"AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
"AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
"AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
"AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
"AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
},
}
@@ -641,6 +648,8 @@ def transcribe(request: Request, file_path):
api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
region = request.app.state.config.AUDIO_STT_AZURE_REGION
locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS
# IF NO LOCALES, USE DEFAULTS
if len(locales) < 2:
@@ -664,7 +673,13 @@ def transcribe(request: Request, file_path):
if not api_key or not region:
raise HTTPException(
status_code=400,
detail="Azure API key and region are required for Azure STT",
detail="Azure API key is required for Azure STT",
)
if not base_url and not region:
raise HTTPException(
status_code=400,
detail="Azure region or base url is required for Azure STT",
)
r = None
@@ -674,13 +689,14 @@ def transcribe(request: Request, file_path):
"definition": json.dumps(
{
"locales": locales.split(","),
"diarization": {"maxSpeakers": 3, "enabled": True},
"diarization": {"maxSpeakers": max_speakers, "enabled": True},
}
if locales
else {}
)
}
url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
url = base_url or f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
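
Note: the endpoint is now overridable. AUDIO_STT_AZURE_BASE_URL wins when set, otherwise the URL is derived from the region, and the diarization speaker cap comes from config instead of being hardcoded to 3. A sketch of the assembled request pieces (values assumed for illustration):

import json

region = "westeurope"  # AUDIO_STT_AZURE_REGION, assumed
base_url = ""          # AUDIO_STT_AZURE_BASE_URL; empty falls back to region
max_speakers = 3       # AUDIO_STT_AZURE_MAX_SPEAKERS

url = base_url or (
    f"https://{region}.api.cognitive.microsoft.com"
    "/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
)
data = {
    "definition": json.dumps(
        {
            "locales": ["en-US", "de-DE"],
            "diarization": {"maxSpeakers": max_speakers, "enabled": True},
        }
    )
}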
# Use context manager to ensure file is properly closed
with open(file_path, "rb") as audio_file:

View File

@@ -27,20 +27,24 @@ from open_webui.env import (
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel
from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
decode_token,
create_api_key,
create_token,
get_admin_user,
get_verified_user,
get_current_user,
get_password_hash,
get_http_authorization_cred,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions
@@ -72,27 +76,29 @@ class SessionUserResponse(Token, UserResponse):
async def get_session_user(
request: Request, response: Response, user=Depends(get_current_user)
):
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=expires_delta,
)
auth_header = request.headers.get("Authorization")
auth_token = get_http_authorization_cred(auth_header)
token = auth_token.credentials
data = decode_token(token)
datetime_expires_at = (
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
)
expires_at = data.get("exp")
if (expires_at is not None) and int(time.time()) > expires_at:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=datetime_expires_at,
expires=(
datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
if expires_at
else None
),
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
@@ -288,18 +294,30 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
user = Auths.authenticate_user_by_trusted_header(email)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
expires_at = None
if expires_delta:
expires_at = int(time.time()) + int(expires_delta.total_seconds())
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(
request.app.state.config.JWT_EXPIRES_IN
),
expires_delta=expires_delta,
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
expires=(
datetime.datetime.fromtimestamp(
expires_at, datetime.timezone.utc
)
if expires_at
else None
),
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
@@ -309,6 +327,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
return {
"token": token,
"token_type": "Bearer",
"expires_at": expires_at,
"id": user.id,
"email": user.email,
"name": user.name,
@@ -566,6 +585,12 @@ async def signout(request: Request, response: Response):
detail="Failed to sign out from the OpenID provider.",
)
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
return RedirectResponse(
headers=response.headers,
url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
)
return {"status": True}
@@ -664,6 +689,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
}
@@ -680,6 +706,7 @@ class AdminConfig(BaseModel):
ENABLE_COMMUNITY_SHARING: bool
ENABLE_MESSAGE_RATING: bool
ENABLE_CHANNELS: bool
ENABLE_NOTES: bool
ENABLE_USER_WEBHOOKS: bool
@@ -700,6 +727,7 @@ async def update_admin_config(
)
request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -724,11 +752,12 @@ async def update_admin_config(
"ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
"ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
"API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
"ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
}

View File

@@ -638,8 +638,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if not has_permission(
user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)

View File

@@ -19,6 +19,8 @@ from fastapi import (
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.users import Users
from open_webui.models.files import (
FileForm,
FileModel,
@@ -83,10 +85,12 @@ def upload_file(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = {},
file_metadata: dict = None,
process: bool = Query(True),
):
log.info(f"file.content_type: {file.content_type}")
file_metadata = file_metadata if file_metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
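
The `dict = None` change above avoids Python's shared-mutable-default pitfall; the guard on the following line then allocates a fresh dict per call. In isolation:

    def bad(metadata: dict = {}):  # one dict shared across *all* calls
        metadata["hit"] = True
        return metadata

    def good(metadata: dict = None):
        metadata = metadata if metadata else {}  # fresh dict per call, as in the hunk
        return metadata

    assert bad() is bad()        # same object every time
    assert good() is not good()  # independent objects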
@@ -95,7 +99,13 @@ def upload_file(
id = str(uuid.uuid4())
name = filename
filename = f"{id}_{filename}"
contents, file_path = Storage.upload_file(file.file, filename)
tags = {
"OpenWebUI-User-Email": user.email,
"OpenWebUI-User-Id": user.id,
"OpenWebUI-User-Name": user.name,
"OpenWebUI-File-Id": id,
}
contents, file_path = Storage.upload_file(file.file, filename, tags)
file_item = Files.insert_new_file(
user.id,
@@ -129,7 +139,15 @@ def upload_file(
ProcessFileForm(file_id=id, content=result.get("text", "")),
user=user,
)
elif file.content_type not in ["image/png", "image/jpeg", "image/gif"]:
elif file.content_type not in [
"image/png",
"image/jpeg",
"image/gif",
"video/mp4",
"video/ogg",
"video/quicktime",
"video/webm",
]:
process_file(request, ProcessFileForm(file_id=id), user=user)
file_item = Files.get_file_by_id(id=id)
@@ -173,7 +191,8 @@ async def list_files(user=Depends(get_verified_user), content: bool = Query(True
if not content:
for file in files:
del file.data["content"]
if "content" in file.data:
del file.data["content"]
return files
@@ -214,7 +233,8 @@ async def search_files(
if not content:
for file in matching_files:
del file.data["content"]
if "content" in file.data:
del file.data["content"]
return matching_files
@@ -431,6 +451,13 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.NOT_FOUND,
)
file_user = Users.get_user_by_id(file.user_id)
if file_user is None or file_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
file.user_id == user.id
or user.role == "admin"

View File

@@ -500,7 +500,11 @@ async def image_generations(
if form_data.size
else request.app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
**(
{"response_format": "b64_json"}
if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
else {}
),
}
# Use asyncio.to_thread for the requests.post call
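
A sketch of the conditional key construction above; the dict unpacking adds `response_format` only when the configured model matches (model name and prompt are illustrative):

    model = "gpt-image-1"
    payload = {
        "model": model,
        "prompt": "a red fox in watercolor",
        **({"response_format": "b64_json"} if "gpt-image-1" in model else {}),
    }
    # payload carries response_format only for gpt-image-1-style models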

View File

@@ -10,7 +10,7 @@ from open_webui.models.knowledge import (
KnowledgeUserResponse,
RAGConfigForm
)
from open_webui.models.files import Files, FileModel
from open_webui.models.files import Files, FileModel, FileMetadataResponse
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
process_file,
@@ -179,10 +179,26 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
for knowledge_base in knowledge_bases:
try:
files = Files.get_files_by_ids(knowledge_base.data.get("file_ids", []))
deleted_knowledge_bases = []
for knowledge_base in knowledge_bases:
# -- Robust error handling for missing or invalid data
if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
log.warning(
f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
)
try:
Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
deleted_knowledge_bases.append(knowledge_base.id)
except Exception as e:
log.error(
f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
)
continue
try:
file_ids = knowledge_base.data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids)
try:
if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
VECTOR_DB_CLIENT.delete_collection(
@@ -190,10 +206,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
)
except Exception as e:
log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error deleting vector DB collection",
)
continue # Skip, don't raise
failed_files = []
for file in files:
@@ -214,10 +227,8 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
except Exception as e:
log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing knowledge base",
)
# Don't raise, just continue
continue
if failed_files:
log.warning(
@@ -226,7 +237,9 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
for failed in failed_files:
log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
log.info("Reindexing completed successfully")
log.info(
f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
)
return True
@@ -236,7 +249,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
class KnowledgeFilesResponse(KnowledgeResponse):
files: list[FileModel]
files: list[FileMetadataResponse]
@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
@@ -252,7 +265,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -380,7 +393,7 @@ def add_file_to_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -457,7 +470,7 @@ def update_file_from_knowledge_by_id(
data = knowledge.data or {}
file_ids = data.get("file_ids", [])
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -539,7 +552,7 @@ def remove_file_from_knowledge_by_id(
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
if knowledge:
files = Files.get_files_by_ids(file_ids)
files = Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@@ -735,7 +748,7 @@ def add_files_to_knowledge_batch(
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids),
files=Files.get_file_metadatas_by_ids(existing_file_ids),
warnings={
"message": "Some files failed to process",
"errors": error_details,
@@ -743,5 +756,6 @@ def add_files_to_knowledge_batch(
)
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
**knowledge.model_dump(),
files=Files.get_file_metadatas_by_ids(existing_file_ids),
)

View File

@@ -54,6 +54,7 @@ from open_webui.config import (
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
@@ -91,6 +92,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
return await response.json()
except Exception as e:
@@ -141,6 +143,7 @@ async def send_post_request(
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
r.raise_for_status()
@@ -216,7 +219,8 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
@@ -234,6 +238,7 @@ async def verify_connection(
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
detail = f"HTTP Error: {r.status}"
@@ -1006,7 +1011,7 @@ class GenerateCompletionForm(BaseModel):
prompt: str
suffix: Optional[str] = None
images: Optional[list[str]] = None
format: Optional[str] = None
format: Optional[Union[dict, str]] = None
options: Optional[dict] = None
system: Optional[str] = None
template: Optional[str] = None
@@ -1482,7 +1487,9 @@ async def download_file_stream(
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(file_url, headers=headers) as response:
async with session.get(
file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
@@ -1497,7 +1504,8 @@ async def download_file_stream(
if done:
file.seek(0)
hashed = calculate_sha256(file)
chunk_size = 1024 * 1024 * 2
hashed = calculate_sha256(file, chunk_size)
file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"

View File

@@ -21,6 +21,7 @@ from open_webui.config import (
CACHE_DIR,
)
from open_webui.env import (
AIOHTTP_CLIENT_SESSION_SSL,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
ENABLE_FORWARD_USER_INFO_HEADERS,
@@ -74,6 +75,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
return await response.json()
except Exception as e:
@@ -92,20 +94,19 @@ async def cleanup_response(
await session.close()
def openai_o1_o3_handler(payload):
def openai_o_series_handler(payload):
"""
Handle o1, o3 specific parameters
Handle "o" series specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
# Convert "max_tokens" to "max_completion_tokens" for all o-series models
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: o1 and o3 do not support the "system" role directly.
# For older models like "o1-mini" or "o1-preview", use role "user".
# For newer o1/o3 models, replace "system" with "developer".
# Handle system role conversion based on model type
if payload["messages"][0]["role"] == "system":
model_lower = payload["model"].lower()
# Legacy models use "user" role instead of "system"
if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
payload["messages"][0]["role"] = "user"
else:
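
The hunk above is truncated before the `else:` body; based on the surrounding comments, a self-contained sketch of the intended transformation (an assumption, not the shipped function):

    def o_series_handler_sketch(payload: dict) -> dict:
        # Sketch of openai_o_series_handler as described by the hunk's comments.
        if "max_tokens" in payload:
            # all o-series models take max_completion_tokens instead of max_tokens
            payload["max_completion_tokens"] = payload.pop("max_tokens")
        if payload["messages"][0]["role"] == "system":
            model_lower = payload["model"].lower()
            if model_lower.startswith(("o1-mini", "o1-preview")):
                payload["messages"][0]["role"] = "user"       # legacy models
            else:
                payload["messages"][0]["role"] = "developer"  # newer o-series
        return payload

    # {"model": "o1-mini", "max_tokens": 512, "messages": [{"role": "system", ...}]}
    #   -> max_completion_tokens: 512, first message role becomes "user"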
@@ -462,7 +463,8 @@ async def get_models(
r = None
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
@@ -481,6 +483,7 @@ async def get_models(
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
@@ -542,7 +545,8 @@ async def verify_connection(
key = form_data.key
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
trust_env=True,
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
) as session:
try:
async with session.get(
@@ -561,6 +565,7 @@ async def verify_connection(
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as r:
if r.status != 200:
# Extract response error details if available
@@ -666,10 +671,10 @@ async def generate_chat_completion(
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: o1,o3 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
if is_o1_o3:
payload = openai_o1_o3_handler(payload)
# Check if model is from "o" series
is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
if is_o_series:
payload = openai_o_series_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload for backward compatibility
if "max_completion_tokens" in payload:
@@ -723,6 +728,7 @@ async def generate_chat_completion(
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
# Check if response is SSE
@@ -802,6 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
else {}
),
},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
)
r.raise_for_status()

View File

@@ -66,7 +66,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
if "pipeline" in model:
sorted_filters.append(model)
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:
@@ -115,7 +115,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
for filter in sorted_filters:
urlIdx = filter.get("urlIdx")
if urlIdx is None:

View File

@@ -53,6 +53,7 @@ from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.serpapi import search_serpapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.yacy import search_yacy
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
@@ -61,6 +62,8 @@ from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity
from open_webui.retrieval.web.sougou import search_sougou
from open_webui.retrieval.web.firecrawl import search_firecrawl
from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import (
get_embedding_function,
@@ -90,7 +93,12 @@ from open_webui.env import (
SRC_LOG_LEVELS,
DEVICE_TYPE,
DOCKER,
SENTENCE_TRANSFORMERS_BACKEND,
SENTENCE_TRANSFORMERS_MODEL_KWARGS,
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
from open_webui.constants import ERROR_MESSAGES
log = logging.getLogger(__name__)
@@ -117,6 +125,8 @@ def get_ef(
get_model_path(embedding_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
)
except Exception as e:
log.debug(f"Error loading SentenceTransformer: {e}")
@@ -150,6 +160,8 @@ def get_rf(
get_model_path(reranking_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
except Exception as e:
log.error(f"CrossEncoder: {e}")
@@ -460,6 +472,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
"WEB_SEARCH_DOMAIN_FILTER_LIST": rag_config.get("web_search_domain_filter_list", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_web_search_embedding_and_retrieval", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
"SEARXNG_QUERY_URL": rag_config.get("searxng_query_url", request.app.state.config.SEARXNG_QUERY_URL),
"YACY_QUERY_URL": rag_config.get("yacy_query_url", request.app.state.config.YACY_QUERY_URL),
"YACY_USERNAME": rag_config.get("yacy_query_username",request.app.state.config.YACY_USERNAME),
"YACY_PASSWORD": rag_config.get("yacy_query_password",request.app.state.config.YACY_PASSWORD),
"GOOGLE_PSE_API_KEY": rag_config.get("google_pse_api_key", request.app.state.config.GOOGLE_PSE_API_KEY),
"GOOGLE_PSE_ENGINE_ID": rag_config.get("google_pse_engine_id", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
"BRAVE_SEARCH_API_KEY": rag_config.get("brave_search_api_key", request.app.state.config.BRAVE_SEARCH_API_KEY),
@@ -489,6 +504,10 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
"FIRECRAWL_API_KEY": rag_config.get("firecrawl_api_key", request.app.state.config.FIRECRAWL_API_KEY),
"FIRECRAWL_API_BASE_URL": rag_config.get("firecrawl_api_base_url", request.app.state.config.FIRECRAWL_API_BASE_URL),
"TAVILY_EXTRACT_DEPTH": rag_config.get("tavily_extract_depth", request.app.state.config.TAVILY_EXTRACT_DEPTH),
"EXTERNAL_WEB_SEARCH_URL": rag_config.get("web_search_url", request.app.state.config.EXTERNAL_WEB_SEARCH_URL),
"EXTERNAL_WEB_SEARCH_API_KEY": rag_config.get("web_search_key", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY),
"EXTERNAL_WEB_LOADER_URL": rag_config.get("web_loader_url", request.app.state.config.EXTERNAL_WEB_LOADER_URL),
"EXTERNAL_WEB_LOADER_API_KEY": rag_config.get("web_loader_key", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY),
"YOUTUBE_LOADER_LANGUAGE": rag_config.get("youtube_loader_language", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
"YOUTUBE_LOADER_PROXY_URL": rag_config.get("youtube_loader_proxy_url", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
"YOUTUBE_LOADER_TRANSLATION": rag_config.get("youtube_loader_translation", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
@@ -535,6 +554,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
"YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
"GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
"GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
"BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
@@ -564,6 +586,10 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
"FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
"FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
"TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
"EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
"EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
"EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
"EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
@@ -580,6 +606,9 @@ class WebConfig(BaseModel):
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
SEARXNG_QUERY_URL: Optional[str] = None
YACY_QUERY_URL: Optional[str] = None
YACY_USERNAME: Optional[str] = None
YACY_PASSWORD: Optional[str] = None
GOOGLE_PSE_API_KEY: Optional[str] = None
GOOGLE_PSE_ENGINE_ID: Optional[str] = None
BRAVE_SEARCH_API_KEY: Optional[str] = None
@@ -609,6 +638,10 @@ class WebConfig(BaseModel):
FIRECRAWL_API_KEY: Optional[str] = None
FIRECRAWL_API_BASE_URL: Optional[str] = None
TAVILY_EXTRACT_DEPTH: Optional[str] = None
EXTERNAL_WEB_SEARCH_URL: Optional[str] = None
EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None
EXTERNAL_WEB_LOADER_URL: Optional[str] = None
EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None
YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None
YOUTUBE_LOADER_PROXY_URL: Optional[str] = None
YOUTUBE_LOADER_TRANSLATION: Optional[str] = None
@@ -668,9 +701,9 @@ async def update_rag_config(
rag_config = knowledge_base.data.get("rag_config", {})
# Update only the provided fields in the rag_config
for field, value in form_data.dict(exclude_unset=True).items():
for field, value in form_data.model_dump(exclude_unset=True).items():
if field == "web" and value is not None:
rag_config["web"] = {**rag_config.get("web", {}), **value.dict(exclude_unset=True)}
rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
else:
rag_config[field] = value
@@ -709,6 +742,7 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
@@ -821,6 +855,9 @@ async def update_rag_config(
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD
request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY
request.app.state.config.GOOGLE_PSE_ENGINE_ID = (
form_data.web.GOOGLE_PSE_ENGINE_ID
@@ -867,6 +904,18 @@ async def update_rag_config(
request.app.state.config.FIRECRAWL_API_BASE_URL = (
form_data.web.FIRECRAWL_API_BASE_URL
)
request.app.state.config.EXTERNAL_WEB_SEARCH_URL = (
form_data.web.EXTERNAL_WEB_SEARCH_URL
)
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = (
form_data.web.EXTERNAL_WEB_SEARCH_API_KEY
)
request.app.state.config.EXTERNAL_WEB_LOADER_URL = (
form_data.web.EXTERNAL_WEB_LOADER_URL
)
request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = (
form_data.web.EXTERNAL_WEB_LOADER_API_KEY
)
request.app.state.config.TAVILY_EXTRACT_DEPTH = (
form_data.web.TAVILY_EXTRACT_DEPTH
)
@@ -919,7 +968,10 @@ async def update_rag_config(
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
"YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
"GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
"GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
"BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
"KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
@@ -948,7 +1000,11 @@ async def update_rag_config(
"FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
"FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
"TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
"EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
"EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
"EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
"EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
"YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
"YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
"YOUTUBE_LOADER_TRANSLATION": request.app.state.YOUTUBE_LOADER_TRANSLATION,
},
@@ -1491,6 +1547,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- YACY_QUERY_URL + YACY_USERNAME + YACY_PASSWORD
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- KAGI_SEARCH_API_KEY
@@ -1520,6 +1577,18 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
elif engine == "yacy":
if request.app.state.config.YACY_QUERY_URL:
return search_yacy(
request.app.state.config.YACY_QUERY_URL,
request.app.state.config.YACY_USERNAME,
request.app.state.config.YACY_PASSWORD,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No YACY_QUERY_URL found in environment variables")
elif engine == "google_pse":
if (
request.app.state.config.GOOGLE_PSE_API_KEY
@@ -1690,6 +1759,22 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
raise Exception(
"No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables"
)
elif engine == "firecrawl":
return search_firecrawl(
request.app.state.config.FIRECRAWL_API_BASE_URL,
request.app.state.config.FIRECRAWL_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "external":
return search_external(
request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No search engine API key found in environment variables")
@@ -1702,8 +1787,11 @@ async def process_web_search(
logging.info(
f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web(
request, request.app.state.config.WEB_SEARCH_ENGINE, form_data.query
web_results = await run_in_threadpool(
search_web,
request,
request.app.state.config.WEB_SEARCH_ENGINE,
form_data.query,
)
except Exception as e:
log.exception(e)
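
search_web is synchronous (it wraps blocking HTTP clients), so the call above is pushed onto Starlette's threadpool. A minimal sketch of that pattern with an illustrative blocking stub:

    import asyncio
    from starlette.concurrency import run_in_threadpool

    def blocking_search(query: str) -> list:
        return [f"result for {query}"]  # stand-in for a sync, requests-based engine call

    async def main() -> None:
        # the event loop stays free while the sync search runs in a worker thread
        results = await run_in_threadpool(blocking_search, "open webui")
        print(results)

    asyncio.run(main())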
@@ -1725,8 +1813,8 @@ async def process_web_search(
)
docs = await loader.aload()
urls = [
doc.metadata["source"] for doc in docs
] # only keep URLs which could be retrieved
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
] # only keep URLs
if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
return {
@@ -1746,19 +1834,22 @@ async def process_web_search(
collection_names = []
for doc_idx, doc in enumerate(docs):
if doc and doc.page_content:
collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
:63
]
try:
collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
:63
]
collection_names.append(collection_name)
await run_in_threadpool(
save_docs_to_vector_db,
request,
[doc],
collection_name,
overwrite=True,
user=user,
)
collection_names.append(collection_name)
await run_in_threadpool(
save_docs_to_vector_db,
request,
[doc],
collection_name,
overwrite=True,
user=user,
)
except Exception as e:
log.debug(f"error saving doc {doc_idx}: {e}")
return {
"status": True,

View File

@@ -6,6 +6,7 @@ from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
UserModel,
UserListResponse,
UserRoleUpdateForm,
Users,
UserSettings,
@@ -33,13 +34,38 @@ router = APIRouter()
############################
@router.get("/", response_model=list[UserModel])
PAGE_ITEM_COUNT = 10
@router.get("/", response_model=UserListResponse)
async def get_users(
skip: Optional[int] = None,
limit: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_admin_user),
):
return Users.get_users(skip, limit)
limit = PAGE_ITEM_COUNT
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit)
@router.get("/all", response_model=UserListResponse)
async def get_all_users(
user=Depends(get_admin_user),
):
return Users.get_users()
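
The page-to-offset arithmetic above, in isolation (PAGE_ITEM_COUNT mirrors the constant in the hunk):

    PAGE_ITEM_COUNT = 10

    def page_to_skip(page: int, limit: int = PAGE_ITEM_COUNT) -> int:
        page = max(1, page)        # clamp page 0 and negatives to the first page
        return (page - 1) * limit  # page 1 -> skip 0, page 2 -> skip 10, ...

    assert page_to_skip(1) == 0
    assert page_to_skip(3) == 20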
############################
@@ -88,6 +114,8 @@ class ChatPermissions(BaseModel):
file_upload: bool = True
delete: bool = True
edit: bool = True
share: bool = True
export: bool = True
stt: bool = True
tts: bool = True
call: bool = True
@@ -288,6 +316,21 @@ async def update_user_by_id(
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
):
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id and session_user.id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not verify primary admin status.",
)
user = Users.get_user_by_id(user_id)
if user:
@@ -335,6 +378,21 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
# Prevent deletion of the primary admin user
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not verify primary admin status.",
)
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
@@ -346,6 +404,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
# Prevent self-deletion
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,

View File

@@ -192,6 +192,9 @@ async def connect(sid, environ, auth):
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
await sio.emit("usage", {"models": get_models_in_use()})
return True
return False
@sio.on("user-join")
@@ -314,16 +317,18 @@ def get_event_emitter(request_info, update_db=True):
)
)
for session_id in session_ids:
await sio.emit(
"chat-events",
{
"chat_id": request_info.get("chat_id", None),
"message_id": request_info.get("message_id", None),
"data": event_data,
},
to=session_id,
)
emit_tasks = [
    sio.emit(
        "chat-events",
        {
            "chat_id": request_info.get("chat_id", None),
            "message_id": request_info.get("message_id", None),
            "data": event_data,
        },
        to=session_id,
    )
    for session_id in session_ids
]
await asyncio.gather(*emit_tasks)
if update_db:
if "type" in event_data and event_data["type"] == "status":

View File

@@ -3,7 +3,7 @@ import shutil
import json
import logging
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
from typing import BinaryIO, Tuple, Dict
import boto3
from botocore.config import Config
@@ -44,7 +44,9 @@ class StorageProvider(ABC):
pass
@abstractmethod
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
pass
@abstractmethod
@@ -58,7 +60,9 @@ class StorageProvider(ABC):
class LocalStorageProvider(StorageProvider):
@staticmethod
def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
contents = file.read()
if not contents:
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -131,12 +135,20 @@ class S3StorageProvider(StorageProvider):
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to S3 storage."""
_, file_path = LocalStorageProvider.upload_file(file, filename)
_, file_path = LocalStorageProvider.upload_file(file, filename, tags)
tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
try:
s3_key = os.path.join(self.key_prefix, filename)
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
self.s3_client.put_object_tagging(
Bucket=self.bucket_name,
Key=s3_key,
Tagging=tagging,
)
return (
open(file_path, "rb").read(),
"s3://" + self.bucket_name + "/" + s3_key,
@@ -207,9 +219,11 @@ class GCSStorageProvider(StorageProvider):
self.gcs_client = storage.Client()
self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to GCS storage."""
contents, file_path = LocalStorageProvider.upload_file(file, filename)
contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
try:
blob = self.bucket.blob(filename)
blob.upload_from_filename(file_path)
@@ -277,9 +291,11 @@ class AzureStorageProvider(StorageProvider):
self.container_name
)
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
"""Handles uploading of the file to Azure Blob Storage."""
contents, file_path = LocalStorageProvider.upload_file(file, filename)
contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
try:
blob_client = self.container_client.get_blob_client(filename)
blob_client.upload_blob(contents, overwrite=True)
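
For the S3 provider, the tag dict is translated into the TagSet shape boto3 expects. A minimal sketch, assuming AWS credentials are configured in the environment and the bucket exists (bucket name, key, and tag values are illustrative):

    import boto3

    s3_client = boto3.client("s3")
    tags = {"OpenWebUI-User-Id": "u-123", "OpenWebUI-File-Id": "f-456"}  # illustrative
    tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}

    s3_client.upload_file("/tmp/example.bin", "my-bucket", "uploads/example.bin")
    s3_client.put_object_tagging(
        Bucket="my-bucket",
        Key="uploads/example.bin",
        Tagging=tagging,
    )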

View File

@@ -37,7 +37,7 @@ if TYPE_CHECKING:
class AuditLogEntry:
# `Metadata` audit level properties
id: str
user: dict[str, Any]
user: Optional[dict[str, Any]]
audit_level: str
verb: str
request_uri: str
@@ -190,21 +190,40 @@ class AuditLoggingMiddleware:
finally:
await self._log_audit_entry(request, context)
async def _get_authenticated_user(self, request: Request) -> UserModel:
async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
auth_header = request.headers.get("Authorization")
assert auth_header
user = get_current_user(request, None, get_http_authorization_cred(auth_header))
return user
try:
user = get_current_user(
request, None, get_http_authorization_cred(auth_header)
)
return user
except Exception as e:
logger.debug(f"Failed to get authenticated user: {str(e)}")
return None
def _should_skip_auditing(self, request: Request) -> bool:
if (
request.method not in {"POST", "PUT", "PATCH", "DELETE"}
or AUDIT_LOG_LEVEL == "NONE"
or not request.headers.get("authorization")
):
return True
ALWAYS_LOG_ENDPOINTS = {
"/api/v1/auths/signin",
"/api/v1/auths/signout",
"/api/v1/auths/signup",
}
path = request.url.path.lower()
for endpoint in ALWAYS_LOG_ENDPOINTS:
if path.startswith(endpoint):
return False # Do NOT skip logging for auth endpoints
# Skip logging if the request is not authenticated
if not request.headers.get("authorization"):
return True
# match either /api/<resource>/... (for the endpoint /api/chat case) or /api/v1/<resource>/...
pattern = re.compile(
r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"
@@ -231,17 +250,32 @@ class AuditLoggingMiddleware:
try:
user = await self._get_authenticated_user(request)
user = (
user.model_dump(include={"id", "name", "email", "role"}) if user else {}
)
request_body = context.request_body.decode("utf-8", errors="replace")
response_body = context.response_body.decode("utf-8", errors="replace")
# Redact sensitive information
if "password" in request_body:
request_body = re.sub(
r'"password":\s*"(.*?)"',
'"password": "********"',
request_body,
)
entry = AuditLogEntry(
id=str(uuid.uuid4()),
user=user.model_dump(include={"id", "name", "email", "role"}),
user=user,
audit_level=self.audit_level.value,
verb=request.method,
request_uri=str(request.url),
response_status_code=context.metadata.get("response_status_code", None),
source_ip=request.client.host if request.client else None,
user_agent=request.headers.get("user-agent"),
request_object=context.request_body.decode("utf-8", errors="replace"),
response_object=context.response_body.decode("utf-8", errors="replace"),
request_object=request_body,
response_object=response_body,
)
self.audit_logger.write(entry)
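
A standalone run of the redaction step above, assuming the request body is a JSON string:

    import re

    request_body = '{"email": "a@b.c", "password": "hunter2"}'
    if "password" in request_body:
        request_body = re.sub(
            r'"password":\s*"(.*?)"',
            '"password": "********"',
            request_body,
        )
    # -> {"email": "a@b.c", "password": "********"}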

View File

@@ -50,7 +50,7 @@ class JupyterCodeExecuter:
self.password = password
self.timeout = timeout
self.kernel_id = ""
self.session = aiohttp.ClientSession(base_url=self.base_url)
self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
self.params = {}
self.result = ResultModel()

View File

@@ -888,16 +888,20 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
citated_file_idx = {}
for _, source in enumerate(sources, 1):
citation_idx = {}
for source in sources:
if "document" in source:
for doc_context, doc_meta in zip(
source["document"], source["metadata"]
):
file_id = doc_meta.get("file_id")
if file_id not in citated_file_idx:
citated_file_idx[file_id] = len(citated_file_idx) + 1
context_string += f'<source id="{citated_file_idx[file_id]}">{doc_context}</source>\n'
citation_id = (
doc_meta.get("source", None)
or source.get("source", {}).get("id", None)
or "N/A"
)
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
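
A standalone run of the deduplicating citation index built above, with the source dicts trimmed to the fields the loop reads:

    sources = [
        {
            "document": ["chunk A", "chunk B"],
            "metadata": [{"source": "doc-1"}, {"source": "doc-1"}],
        }
    ]

    citation_idx = {}
    context_string = ""
    for source in sources:
        for doc_context, doc_meta in zip(source["document"], source["metadata"]):
            citation_id = (
                doc_meta.get("source")
                or source.get("source", {}).get("id")
                or "N/A"
            )
            if citation_id not in citation_idx:
                citation_idx[citation_id] = len(citation_idx) + 1  # 1-based ids
            context_string += (
                f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
            )
    # both chunks share id 1 because they point at the same underlying source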
@@ -1133,7 +1137,7 @@ async def process_chat_response(
)
# Send a webhook notification if the user is not active
if get_active_status_by_user_id(user.id) is None:
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
@@ -1671,6 +1675,15 @@ async def process_chat_response(
if current_response_tool_call is None:
# Add the new tool call
delta_tool_call.setdefault(
"function", {}
)
delta_tool_call[
"function"
].setdefault("name", "")
delta_tool_call[
"function"
].setdefault("arguments", "")
response_tool_calls.append(
delta_tool_call
)
@@ -2215,7 +2228,7 @@ async def process_chat_response(
)
# Send a webhook notification if the user is not active
if get_active_status_by_user_id(user.id) is None:
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(

View File

@@ -15,7 +15,7 @@ from starlette.responses import RedirectResponse
from open_webui.models.auths import Auths
from open_webui.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
@@ -23,6 +23,7 @@ from open_webui.config import (
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
ENABLE_OAUTH_GROUP_MANAGEMENT,
ENABLE_OAUTH_GROUP_CREATION,
OAUTH_ROLES_CLAIM,
OAUTH_GROUPS_CLAIM,
OAUTH_EMAIL_CLAIM,
@@ -57,6 +58,7 @@ auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
@@ -152,6 +154,51 @@ class OAuthManager:
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
# Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
log.debug("Checking for missing groups to create...")
all_group_names = {g.name for g in all_available_groups}
groups_created = False
# Determine creator ID: Prefer admin, fallback to current user if no admin exists
admin_user = Users.get_admin_user()
creator_id = admin_user.id if admin_user else user.id
log.debug(f"Using creator ID {creator_id} for potential group creation.")
for group_name in user_oauth_groups:
if group_name not in all_group_names:
log.info(
f"Group '{group_name}' not found via OAuth claim. Creating group..."
)
try:
new_group_form = GroupForm(
name=group_name,
description=f"Group '{group_name}' created automatically via OAuth.",
permissions=default_permissions, # Use default permissions from function args
user_ids=[], # Start with no users, user will be added later by subsequent logic
)
# Use determined creator ID (admin or fallback to current user)
created_group = Groups.insert_new_group(
creator_id, new_group_form
)
if created_group:
log.info(
f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
)
groups_created = True
# Add to local set to prevent duplicate creation attempts in this run
all_group_names.add(group_name)
else:
log.error(
f"Failed to create group '{group_name}' via OAuth."
)
except Exception as e:
log.error(f"Error creating group '{group_name}' via OAuth: {e}")
# Refresh the list of all available groups if any were created
if groups_created:
all_available_groups = Groups.get_groups()
log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}")
log.debug(f"User oauth groups: {user_oauth_groups}")
log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")
@@ -257,7 +304,7 @@ class OAuthManager:
try:
access_token = token.get("access_token")
headers = {"Authorization": f"Bearer {access_token}"}
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
"https://api.github.com/user/emails", headers=headers
) as resp:
@@ -339,7 +386,7 @@ class OAuthManager:
get_kwargs["headers"] = {
"Authorization": f"Bearer {access_token}",
}
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
picture_url, **get_kwargs
) as resp:

View File

@@ -157,7 +157,8 @@ def load_function_module_by_id(function_id, content=None):
raise Exception("No Function class found in the module")
except Exception as e:
log.error(f"Error loading module: {function_id}: {e}")
del sys.modules[module_name] # Cleanup by removing the module in case of error
# Cleanup by removing the module in case of error
del sys.modules[module_name]
Functions.update_function_by_id(function_id, {"is_active": False})
raise e
@@ -182,3 +183,32 @@ def install_frontmatter_requirements(requirements: str):
else:
log.info("No requirements found in frontmatter.")
def install_tool_and_function_dependencies():
"""
Install all dependencies for all admin tools and active functions.
It first collects the dependencies declared in the frontmatter of each tool and function,
then installs them with a single pip invocation. Duplicate or overlapping version
specifications are left for pip to resolve as best it can.
"""
function_list = Functions.get_functions(active_only=True)
tool_list = Tools.get_tools()
all_dependencies = ""
try:
for function in function_list:
frontmatter = extract_frontmatter(replace_imports(function.content))
if dependencies := frontmatter.get("requirements"):
all_dependencies += f"{dependencies}, "
for tool in tool_list:
# Only install requirements for admin tools
if tool.user.role == "admin":
frontmatter = extract_frontmatter(replace_imports(tool.content))
if dependencies := frontmatter.get("requirements"):
all_dependencies += f"{dependencies}, "
install_frontmatter_requirements(all_dependencies.strip(", "))
except Exception as e:
log.error(f"Error installing requirements: {e}")

View File

@@ -36,7 +36,10 @@ from langchain_core.utils.function_calling import (
from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
)
import copy
@@ -276,8 +279,8 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
docstring = func.__doc__
description = parse_description(docstring)
function_descriptions = parse_docstring(docstring)
function_description = parse_description(docstring)
function_param_descriptions = parse_docstring(docstring)
field_defs = {}
for name, param in parameters.items():
@@ -285,15 +288,15 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
type_hint = type_hints.get(name, Any)
default_value = param.default if param.default is not param.empty else ...
description = function_descriptions.get(name, None)
param_description = function_param_descriptions.get(name, None)
if description:
field_defs[name] = type_hint, Field(default_value, description=description)
if param_description:
field_defs[name] = type_hint, Field(default_value, description=param_description)
else:
field_defs[name] = type_hint, default_value
model = create_model(func.__name__, **field_defs)
model.__doc__ = description
model.__doc__ = function_description
return model
@@ -371,51 +374,64 @@ def convert_openapi_to_tool_payload(openapi_spec):
for path, methods in openapi_spec.get("paths", {}).items():
for method, operation in methods.items():
tool = {
"type": "function",
"name": operation.get("operationId"),
"description": operation.get(
"description", operation.get("summary", "No description available.")
),
"parameters": {"type": "object", "properties": {}, "required": []},
}
# Extract path and query parameters
for param in operation.get("parameters", []):
param_name = param["name"]
param_schema = param.get("schema", {})
tool["parameters"]["properties"][param_name] = {
"type": param_schema.get("type"),
"description": param_schema.get("description", ""),
if operation.get("operationId"):
tool = {
"type": "function",
"name": operation.get("operationId"),
"description": operation.get(
"description",
operation.get("summary", "No description available."),
),
"parameters": {"type": "object", "properties": {}, "required": []},
}
if param.get("required"):
tool["parameters"]["required"].append(param_name)
# Extract and resolve requestBody if available
request_body = operation.get("requestBody")
if request_body:
content = request_body.get("content", {})
json_schema = content.get("application/json", {}).get("schema")
if json_schema:
resolved_schema = resolve_schema(
json_schema, openapi_spec.get("components", {})
)
if resolved_schema.get("properties"):
tool["parameters"]["properties"].update(
resolved_schema["properties"]
# Extract path and query parameters
for param in operation.get("parameters", []):
param_name = param["name"]
param_schema = param.get("schema", {})
description = param_schema.get("description", "")
if not description:
description = param.get("description") or ""
if param_schema.get("enum") and isinstance(
param_schema.get("enum"), list
):
description += (
f". Possible values: {', '.join(param_schema.get('enum'))}"
)
if "required" in resolved_schema:
tool["parameters"]["required"] = list(
set(
tool["parameters"]["required"]
+ resolved_schema["required"]
)
)
elif resolved_schema.get("type") == "array":
tool["parameters"] = resolved_schema # special case for array
tool["parameters"]["properties"][param_name] = {
"type": param_schema.get("type"),
"description": description,
}
if param.get("required"):
tool["parameters"]["required"].append(param_name)
tool_payload.append(tool)
# Extract and resolve requestBody if available
request_body = operation.get("requestBody")
if request_body:
content = request_body.get("content", {})
json_schema = content.get("application/json", {}).get("schema")
if json_schema:
resolved_schema = resolve_schema(
json_schema, openapi_spec.get("components", {})
)
if resolved_schema.get("properties"):
tool["parameters"]["properties"].update(
resolved_schema["properties"]
)
if "required" in resolved_schema:
tool["parameters"]["required"] = list(
set(
tool["parameters"]["required"]
+ resolved_schema["required"]
)
)
elif resolved_schema.get("type") == "array":
tool["parameters"] = (
resolved_schema # special case for array
)
tool_payload.append(tool)
return tool_payload
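
A worked input/output pair for the converter above, with the spec trimmed to one operation; the expected output is inferred from the hunk and should be read as an approximation:

    openapi_spec = {
        "paths": {
            "/weather": {
                "get": {
                    "operationId": "get_weather",
                    "summary": "Get current weather",
                    "parameters": [
                        {
                            "name": "city",
                            "required": True,
                            "schema": {"type": "string", "enum": ["Berlin", "Paris"]},
                        }
                    ],
                }
            }
        }
    }

    # convert_openapi_to_tool_payload(openapi_spec) should yield roughly:
    # [{"type": "function", "name": "get_weather",
    #   "description": "Get current weather",
    #   "parameters": {"type": "object",
    #                  "properties": {"city": {"type": "string",
    #                                          "description": ". Possible values: Berlin, Paris"}},
    #                  "required": ["city"]}}]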
@@ -431,8 +447,10 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
error = None
try:
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
) as response:
if response.status != 200:
error_body = await response.json()
raise Exception(error_body)
@@ -573,19 +591,26 @@ async def execute_tool_server(
if token:
headers["Authorization"] = f"Bearer {token}"
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
request_method = getattr(session, http_method.lower())
if http_method in ["post", "put", "patch"]:
async with request_method(
final_url, json=body_params, headers=headers
final_url,
json=body_params,
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
) as response:
if response.status >= 400:
text = await response.text()
raise Exception(f"HTTP error {response.status}: {text}")
return await response.json()
else:
async with request_method(final_url, headers=headers) as response:
async with request_method(
final_url,
headers=headers,
ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
) as response:
if response.status >= 400:
text = await response.text()
raise Exception(f"HTTP error {response.status}: {text}")

View File

@@ -31,7 +31,7 @@ APScheduler==3.10.4
RestrictedPython==8.0
loguru==0.7.2
loguru==0.7.3
asgiref==3.8.1
# AI libraries
@@ -40,8 +40,8 @@ anthropic
google-generativeai==0.8.4
tiktoken
langchain==0.3.19
langchain-community==0.3.18
langchain==0.3.24
langchain-community==0.3.23
fake-useragent==2.1.0
chromadb==0.6.3
@@ -49,11 +49,11 @@ pymilvus==2.5.0
qdrant-client~=1.12.0
opensearch-py==2.8.0
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
elasticsearch==8.17.1
elasticsearch==9.0.1
pinecone==6.0.2
transformers
sentence-transformers==3.3.1
sentence-transformers==4.1.0
accelerate
colbert-ai==0.2.21
einops==0.8.1
@@ -81,7 +81,7 @@ azure-ai-documentintelligence==1.0.0
pillow==11.1.0
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.3.24
rapidocr-onnxruntime==1.4.4
rank-bm25==0.2.2
onnxruntime==1.20.1
@@ -107,7 +107,7 @@ google-auth-oauthlib
## Tests
docker~=7.1.0
pytest~=8.3.2
pytest~=8.3.5
pytest-docker~=3.1.1
googleapis-common-protos==1.63.2