Mirror of https://github.com/open-webui/open-webui (synced 2025-06-26 18:26:48 +00:00)

Merge remote-tracking branch 'upstream/dev' into Individual-RAG-Config
backend/open_webui/__init__.py

@@ -76,7 +76,7 @@ def serve(
    from open_webui.env import UVICORN_WORKERS  # Import the workers setting

    uvicorn.run(
-        open_webui.main.app,
+        "open_webui.main:app",
        host=host,
        port=port,
        forwarded_allow_ips="*",

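A brief aside on why the app is now passed as an import string (a hedged sketch of the assumed wiring, not the upstream code verbatim): uvicorn can only fork multiple worker processes when it can re-import the application itself, so the "open_webui.main:app" form is what makes a UVICORN_WORKERS setting above 1 usable.

import uvicorn
from open_webui.env import UVICORN_WORKERS  # assumption: an int parsed from the environment

uvicorn.run(
    "open_webui.main:app",    # import string, re-importable by each worker process
    host="0.0.0.0",           # illustrative values
    port=8080,
    forwarded_allow_ips="*",
    workers=UVICORN_WORKERS,  # assumption: the imported setting is passed through here
)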
backend/open_webui/config.py

@@ -509,6 +512,12 @@ ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
    os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
)

ENABLE_OAUTH_GROUP_CREATION = PersistentConfig(
    "ENABLE_OAUTH_GROUP_CREATION",
    "oauth.enable_group_creation",
    os.environ.get("ENABLE_OAUTH_GROUP_CREATION", "False").lower() == "true",
)

OAUTH_ROLES_CLAIM = PersistentConfig(
    "OAUTH_ROLES_CLAIM",
    "oauth.roles_claim",

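For orientation: each PersistentConfig entry pairs an environment variable name with a persisted config path and a default value. A hedged sketch of how such a flag is typically consumed (the .value access mirrors usages later in this commit, e.g. ONEDRIVE_CLIENT_ID.value; the helper below is hypothetical):

if ENABLE_OAUTH_GROUP_CREATION.value:
    create_missing_groups_from_claims()  # hypothetical helper, not part of this commit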
@@ -952,10 +958,15 @@ DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

-DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
-    "DEFAULT_PROMPT_SUGGESTIONS",
-    "ui.prompt_suggestions",
-    [
+try:
+    default_prompt_suggestions = json.loads(
+        os.environ.get("DEFAULT_PROMPT_SUGGESTIONS", "[]")
+    )
+except Exception as e:
+    log.exception(f"Error loading DEFAULT_PROMPT_SUGGESTIONS: {e}")
+    default_prompt_suggestions = []
+if default_prompt_suggestions == []:
+    default_prompt_suggestions = [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
@@ -983,7 +994,11 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
-    ],
+    ]
+
+DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
+    "DEFAULT_PROMPT_SUGGESTIONS",
+    "ui.prompt_suggestions",
+    default_prompt_suggestions,
+)

MODEL_ORDER_LIST = PersistentConfig(

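A hedged usage sketch of the new override (values illustrative): DEFAULT_PROMPT_SUGGESTIONS is now read from the environment as JSON, and the hardcoded list above is only used when the variable is unset, empty, or invalid.

import json, os

os.environ["DEFAULT_PROMPT_SUGGESTIONS"] = json.dumps(
    [
        {
            "title": ["Draft a release note", "for this week's changes"],  # illustrative entry
            "content": "Draft a short release note summarizing this week's changes.",
        }
    ]
)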
@@ -1062,6 +1077,14 @@ USER_PERMISSIONS_CHAT_EDIT = (
    os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_SHARE = (
    os.environ.get("USER_PERMISSIONS_CHAT_SHARE", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_EXPORT = (
    os.environ.get("USER_PERMISSIONS_CHAT_EXPORT", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_STT = (
    os.environ.get("USER_PERMISSIONS_CHAT_STT", "True").lower() == "true"
)

@@ -1126,6 +1149,8 @@ DEFAULT_USER_PERMISSIONS = {
    "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD,
    "delete": USER_PERMISSIONS_CHAT_DELETE,
    "edit": USER_PERMISSIONS_CHAT_EDIT,
    "share": USER_PERMISSIONS_CHAT_SHARE,
    "export": USER_PERMISSIONS_CHAT_EXPORT,
    "stt": USER_PERMISSIONS_CHAT_STT,
    "tts": USER_PERMISSIONS_CHAT_TTS,
    "call": USER_PERMISSIONS_CHAT_CALL,

@@ -1153,6 +1178,11 @@ ENABLE_CHANNELS = PersistentConfig(
    os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
)

ENABLE_NOTES = PersistentConfig(
    "ENABLE_NOTES",
    "notes.enable",
    os.environ.get("ENABLE_NOTES", "True").lower() == "true",
)

ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
    "ENABLE_EVALUATION_ARENA_MODELS",

@@ -1203,6 +1233,9 @@ ENABLE_USER_WEBHOOKS = PersistentConfig(
    os.environ.get("ENABLE_USER_WEBHOOKS", "True").lower() == "true",
)

# FastAPI / AnyIO settings
THREAD_POOL_SIZE = int(os.getenv("THREAD_POOL_SIZE", "0"))


def validate_cors_origins(origins):
    for origin in origins:

@@ -1229,7 +1262,9 @@ def validate_cors_origin(origin):

# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
-CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
+CORS_ALLOW_ORIGIN = os.environ.get(
+    "CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080"
+).split(";")

if "*" in CORS_ALLOW_ORIGIN:
    log.warning(

@@ -1693,6 +1728,9 @@ MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None)

# Qdrant
QDRANT_URI = os.environ.get("QDRANT_URI", None)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))

# OpenSearch
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")

@@ -1724,6 +1762,14 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
    os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)

# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME", "open-webui-index")
PINECONE_DIMENSION = int(os.getenv("PINECONE_DIMENSION", 1536))  # or 3072, 1024, 768
PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine")
PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws")  # or "gcp" or "azure"

####################################
# Information Retrieval (RAG)
####################################

@@ -1760,6 +1806,13 @@ ONEDRIVE_CLIENT_ID = PersistentConfig(
    os.environ.get("ONEDRIVE_CLIENT_ID", ""),
)

ONEDRIVE_SHAREPOINT_URL = PersistentConfig(
    "ONEDRIVE_SHAREPOINT_URL",
    "onedrive.sharepoint_url",
    os.environ.get("ONEDRIVE_SHAREPOINT_URL", ""),
)


# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",

@@ -2092,6 +2145,24 @@ SEARXNG_QUERY_URL = PersistentConfig(
    os.getenv("SEARXNG_QUERY_URL", ""),
)

YACY_QUERY_URL = PersistentConfig(
    "YACY_QUERY_URL",
    "rag.web.search.yacy_query_url",
    os.getenv("YACY_QUERY_URL", ""),
)

YACY_USERNAME = PersistentConfig(
    "YACY_USERNAME",
    "rag.web.search.yacy_username",
    os.getenv("YACY_USERNAME", ""),
)

YACY_PASSWORD = PersistentConfig(
    "YACY_PASSWORD",
    "rag.web.search.yacy_password",
    os.getenv("YACY_PASSWORD", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",

@@ -2256,6 +2327,29 @@ FIRECRAWL_API_BASE_URL = PersistentConfig(
    os.environ.get("FIRECRAWL_API_BASE_URL", "https://api.firecrawl.dev"),
)

EXTERNAL_WEB_SEARCH_URL = PersistentConfig(
    "EXTERNAL_WEB_SEARCH_URL",
    "rag.web.search.external_web_search_url",
    os.environ.get("EXTERNAL_WEB_SEARCH_URL", ""),
)

EXTERNAL_WEB_SEARCH_API_KEY = PersistentConfig(
    "EXTERNAL_WEB_SEARCH_API_KEY",
    "rag.web.search.external_web_search_api_key",
    os.environ.get("EXTERNAL_WEB_SEARCH_API_KEY", ""),
)

EXTERNAL_WEB_LOADER_URL = PersistentConfig(
    "EXTERNAL_WEB_LOADER_URL",
    "rag.web.loader.external_web_loader_url",
    os.environ.get("EXTERNAL_WEB_LOADER_URL", ""),
)

EXTERNAL_WEB_LOADER_API_KEY = PersistentConfig(
    "EXTERNAL_WEB_LOADER_API_KEY",
    "rag.web.loader.external_web_loader_api_key",
    os.environ.get("EXTERNAL_WEB_LOADER_API_KEY", ""),
)

####################################
# Images

@@ -2566,6 +2660,18 @@ AUDIO_STT_AZURE_LOCALES = PersistentConfig(
    os.getenv("AUDIO_STT_AZURE_LOCALES", ""),
)

AUDIO_STT_AZURE_BASE_URL = PersistentConfig(
    "AUDIO_STT_AZURE_BASE_URL",
    "audio.stt.azure.base_url",
    os.getenv("AUDIO_STT_AZURE_BASE_URL", ""),
)

AUDIO_STT_AZURE_MAX_SPEAKERS = PersistentConfig(
    "AUDIO_STT_AZURE_MAX_SPEAKERS",
    "audio.stt.azure.max_speakers",
    os.getenv("AUDIO_STT_AZURE_MAX_SPEAKERS", "3"),
)

AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",

backend/open_webui/env.py

@@ -354,6 +354,10 @@ BYPASS_MODEL_ACCESS_CONTROL = (
    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)

WEBUI_AUTH_SIGNOUT_REDIRECT_URL = os.environ.get(
    "WEBUI_AUTH_SIGNOUT_REDIRECT_URL", None
)

####################
# WEBUI_SECRET_KEY
####################

@@ -409,6 +413,11 @@ else:
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300


AIOHTTP_CLIENT_SESSION_SSL = (
    os.environ.get("AIOHTTP_CLIENT_SESSION_SSL", "True").lower() == "true"
)

AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = os.environ.get(
    "AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST",
    os.environ.get("AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "10"),

@@ -437,6 +446,56 @@ else:
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10


AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
    os.environ.get("AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL", "True").lower() == "true"
)


####################
# SENTENCE TRANSFORMERS
####################


SENTENCE_TRANSFORMERS_BACKEND = os.environ.get("SENTENCE_TRANSFORMERS_BACKEND", "")
if SENTENCE_TRANSFORMERS_BACKEND == "":
    SENTENCE_TRANSFORMERS_BACKEND = "torch"


SENTENCE_TRANSFORMERS_MODEL_KWARGS = os.environ.get(
    "SENTENCE_TRANSFORMERS_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_MODEL_KWARGS == "":
    SENTENCE_TRANSFORMERS_MODEL_KWARGS = None
else:
    try:
        SENTENCE_TRANSFORMERS_MODEL_KWARGS = json.loads(
            SENTENCE_TRANSFORMERS_MODEL_KWARGS
        )
    except Exception:
        SENTENCE_TRANSFORMERS_MODEL_KWARGS = None


SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = os.environ.get(
    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == "":
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = "torch"


SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.environ.get(
    "SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS", ""
)
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == "":
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
    try:
        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = json.loads(
            SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS
        )
    except Exception:
        SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None

####################
# OFFLINE_MODE
####################

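A hedged configuration sketch for the new sentence-transformers knobs (keys and values illustrative): both *_MODEL_KWARGS variables must be valid JSON, otherwise they silently fall back to None.

import json, os

os.environ["SENTENCE_TRANSFORMERS_BACKEND"] = "onnx"  # assumption: any backend name sentence-transformers accepts
os.environ["SENTENCE_TRANSFORMERS_MODEL_KWARGS"] = json.dumps(
    {"torch_dtype": "float16"}  # illustrative kwarg, passed through to the model loader
)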
@@ -446,6 +505,7 @@ OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true"

if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"


####################
# AUDIT LOGGING
####################
@@ -467,6 +527,7 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(",")
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]


####################
# OPENTELEMETRY
####################

backend/open_webui/main.py

@@ -17,6 +17,7 @@ from sqlalchemy import text
from typing import Optional
from aiocache import cached
import aiohttp
import anyio.to_thread
import requests


@@ -100,11 +101,14 @@ from open_webui.config import (
    # OpenAI
    ENABLE_OPENAI_API,
    ONEDRIVE_CLIENT_ID,
    ONEDRIVE_SHAREPOINT_URL,
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    OPENAI_API_CONFIGS,
    # Direct Connections
    ENABLE_DIRECT_CONNECTIONS,
    # Thread pool size for FastAPI/AnyIO
    THREAD_POOL_SIZE,
    # Tool Server Configs
    TOOL_SERVER_CONNECTIONS,
    # Code Execution
@@ -151,6 +155,8 @@ from open_webui.config import (
    AUDIO_STT_AZURE_API_KEY,
    AUDIO_STT_AZURE_REGION,
    AUDIO_STT_AZURE_LOCALES,
    AUDIO_STT_AZURE_BASE_URL,
    AUDIO_STT_AZURE_MAX_SPEAKERS,
    AUDIO_TTS_API_KEY,
    AUDIO_TTS_ENGINE,
    AUDIO_TTS_MODEL,
@@ -219,6 +225,9 @@ from open_webui.config import (
    SERPAPI_API_KEY,
    SERPAPI_ENGINE,
    SEARXNG_QUERY_URL,
    YACY_QUERY_URL,
    YACY_USERNAME,
    YACY_PASSWORD,
    SERPER_API_KEY,
    SERPLY_API_KEY,
    SERPSTACK_API_KEY,
@@ -240,12 +249,17 @@ from open_webui.config import (
    GOOGLE_DRIVE_CLIENT_ID,
    GOOGLE_DRIVE_API_KEY,
    ONEDRIVE_CLIENT_ID,
    ONEDRIVE_SHAREPOINT_URL,
    ENABLE_RAG_HYBRID_SEARCH,
    ENABLE_RAG_LOCAL_WEB_FETCH,
    ENABLE_WEB_LOADER_SSL_VERIFICATION,
    ENABLE_GOOGLE_DRIVE_INTEGRATION,
    ENABLE_ONEDRIVE_INTEGRATION,
    UPLOAD_DIR,
    EXTERNAL_WEB_SEARCH_URL,
    EXTERNAL_WEB_SEARCH_API_KEY,
    EXTERNAL_WEB_LOADER_URL,
    EXTERNAL_WEB_LOADER_API_KEY,
    # WebUI
    WEBUI_AUTH,
    WEBUI_NAME,
@@ -260,6 +274,7 @@ from open_webui.config import (
    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
    API_KEY_ALLOWED_ENDPOINTS,
    ENABLE_CHANNELS,
    ENABLE_NOTES,
    ENABLE_COMMUNITY_SHARING,
    ENABLE_MESSAGE_RATING,
    ENABLE_USER_WEBHOOKS,
@@ -341,6 +356,7 @@ from open_webui.env import (
    WEBUI_SESSION_COOKIE_SECURE,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
    WEBUI_AUTH_TRUSTED_NAME_HEADER,
    WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
    ENABLE_WEBSOCKET_SUPPORT,
    BYPASS_MODEL_ACCESS_CONTROL,
    RESET_CONFIG_ON_START,
@@ -370,6 +386,7 @@ from open_webui.utils.auth import (
    get_admin_user,
    get_verified_user,
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware


@@ -432,7 +449,18 @@ async def lifespan(app: FastAPI):
    if LICENSE_KEY:
        get_license_data(app, LICENSE_KEY)

+    # This should be blocking (sync) so functions are not deactivated on first /get_models calls
+    # when the first user lands on the / route.
+    log.info("Installing external dependencies of functions and tools...")
+    install_tool_and_function_dependencies()
+
+    pool_size = THREAD_POOL_SIZE
+    if pool_size and pool_size > 0:
+        limiter = anyio.to_thread.current_default_thread_limiter()
+        limiter.total_tokens = pool_size
+
    asyncio.create_task(periodic_usage_pool_cleanup())

    yield

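The AnyIO limiter tweak in lifespan() above is the whole mechanism; a standalone sketch of the same idea (written as a coroutine on the assumption that the default limiter is created per event loop, as in lifespan):

import anyio.to_thread

async def apply_thread_pool_size(pool_size: int) -> None:
    # THREAD_POOL_SIZE <= 0 (the default "0") leaves AnyIO's limiter untouched
    if pool_size and pool_size > 0:
        limiter = anyio.to_thread.current_default_thread_limiter()
        limiter.total_tokens = pool_size  # raises the cap on concurrent to_thread() calls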
@@ -543,6 +571,7 @@ app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST


app.state.config.ENABLE_CHANNELS = ENABLE_CHANNELS
app.state.config.ENABLE_NOTES = ENABLE_NOTES
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_USER_WEBHOOKS = ENABLE_USER_WEBHOOKS
@@ -576,6 +605,7 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS

app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL
app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL

app.state.USER_COUNT = None

@@ -646,6 +676,9 @@ app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.YACY_QUERY_URL = YACY_QUERY_URL
app.state.config.YACY_USERNAME = YACY_USERNAME
app.state.config.YACY_PASSWORD = YACY_PASSWORD
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
@@ -668,6 +701,10 @@ app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = EXTERNAL_WEB_SEARCH_API_KEY
app.state.config.EXTERNAL_WEB_LOADER_URL = EXTERNAL_WEB_LOADER_URL
app.state.config.EXTERNAL_WEB_LOADER_API_KEY = EXTERNAL_WEB_LOADER_API_KEY


app.state.config.PLAYWRIGHT_WS_URL = PLAYWRIGHT_WS_URL

@@ -796,6 +833,8 @@ app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
app.state.config.AUDIO_STT_AZURE_REGION = AUDIO_STT_AZURE_REGION
app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL
app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY

@@ -869,7 +908,8 @@ class RedirectMiddleware(BaseHTTPMiddleware):

        # Check for the specific watch path and the presence of 'v' parameter
        if path.endswith("/watch") and "v" in query_params:
-            video_id = query_params["v"][0]  # Extract the first 'v' parameter
+            # Extract the first 'v' parameter
+            video_id = query_params["v"][0]
            encoded_video_id = urlencode({"youtube": video_id})
            redirect_url = f"/?{encoded_video_id}"
            return RedirectResponse(url=redirect_url)

@@ -1283,6 +1323,7 @@ async def get_app_config(request: Request):
        {
            "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS,
            "enable_channels": app.state.config.ENABLE_CHANNELS,
            "enable_notes": app.state.config.ENABLE_NOTES,
            "enable_web_search": app.state.config.ENABLE_WEB_SEARCH,
            "enable_code_execution": app.state.config.ENABLE_CODE_EXECUTION,
            "enable_code_interpreter": app.state.config.ENABLE_CODE_INTERPRETER,

@@ -1327,7 +1368,10 @@ async def get_app_config(request: Request):
                "client_id": GOOGLE_DRIVE_CLIENT_ID.value,
                "api_key": GOOGLE_DRIVE_API_KEY.value,
            },
-            "onedrive": {"client_id": ONEDRIVE_CLIENT_ID.value},
+            "onedrive": {
+                "client_id": ONEDRIVE_CLIENT_ID.value,
+                "sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
+            },
            "license_metadata": app.state.LICENSE_METADATA,
            **(
                {

@@ -1439,7 +1483,7 @@ async def get_manifest_json():
        "start_url": "/",
        "display": "standalone",
        "background_color": "#343541",
-        "orientation": "natural",
+        "orientation": "any",
        "icons": [
            {
                "src": "/static/logo.png",

backend/open_webui/models/users.py

@@ -10,6 +10,8 @@ from open_webui.models.groups import Groups

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import or_


####################
# User DB Schema
@@ -67,6 +69,11 @@ class UserModel(BaseModel):
####################


class UserListResponse(BaseModel):
    users: list[UserModel]
    total: int


class UserResponse(BaseModel):
    id: str
    name: str

@@ -160,11 +167,63 @@ class UsersTable:
            return None

    def get_users(
-        self, skip: Optional[int] = None, limit: Optional[int] = None
-    ) -> list[UserModel]:
+        self,
+        filter: Optional[dict] = None,
+        skip: Optional[int] = None,
+        limit: Optional[int] = None,
+    ) -> UserListResponse:
        with get_db() as db:
-            query = db.query(User)
+            query = db.query(User).order_by(User.created_at.desc())
+
+            if filter:
+                query_key = filter.get("query")
+                if query_key:
+                    query = query.filter(
+                        or_(
+                            User.name.ilike(f"%{query_key}%"),
+                            User.email.ilike(f"%{query_key}%"),
+                        )
+                    )
+
+                order_by = filter.get("order_by")
+                direction = filter.get("direction")
+
+                if order_by == "name":
+                    if direction == "asc":
+                        query = query.order_by(User.name.asc())
+                    else:
+                        query = query.order_by(User.name.desc())
+                elif order_by == "email":
+                    if direction == "asc":
+                        query = query.order_by(User.email.asc())
+                    else:
+                        query = query.order_by(User.email.desc())
+                elif order_by == "created_at":
+                    if direction == "asc":
+                        query = query.order_by(User.created_at.asc())
+                    else:
+                        query = query.order_by(User.created_at.desc())
+                elif order_by == "last_active_at":
+                    if direction == "asc":
+                        query = query.order_by(User.last_active_at.asc())
+                    else:
+                        query = query.order_by(User.last_active_at.desc())
+                elif order_by == "updated_at":
+                    if direction == "asc":
+                        query = query.order_by(User.updated_at.asc())
+                    else:
+                        query = query.order_by(User.updated_at.desc())
+                elif order_by == "role":
+                    if direction == "asc":
+                        query = query.order_by(User.role.asc())
+                    else:
+                        query = query.order_by(User.role.desc())
+                else:
+                    query = query.order_by(User.created_at.desc())

            if skip:
                query = query.offset(skip)
@@ -172,8 +231,10 @@ class UsersTable:
                query = query.limit(limit)

            users = query.all()

-            return [UserModel.model_validate(user) for user in users]
+            return {
+                "users": [UserModel.model_validate(user) for user in users],
+                "total": db.query(User).count(),
+            }

    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
        with get_db() as db:

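A hedged usage sketch of the extended get_users() API (filter keys are taken from the diff; calling it through a Users singleton follows the codebase's usual `Users = UsersTable()` pattern, which is an assumption here):

result = Users.get_users(
    filter={"query": "alice", "order_by": "last_active_at", "direction": "asc"},
    skip=0,
    limit=25,
)
print(result["total"], [u.name for u in result["users"]])  # dict shaped like UserListResponse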
backend/open_webui/retrieval/loaders/external.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import requests
import logging
from typing import Iterator, List, Union

from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class ExternalLoader(BaseLoader):
    def __init__(
        self,
        web_paths: Union[str, List[str]],
        external_url: str,
        external_api_key: str,
        continue_on_failure: bool = True,
        **kwargs,
    ) -> None:
        self.external_url = external_url
        self.external_api_key = external_api_key
        self.urls = web_paths if isinstance(web_paths, list) else [web_paths]
        self.continue_on_failure = continue_on_failure

    def lazy_load(self) -> Iterator[Document]:
        batch_size = 20
        for i in range(0, len(self.urls), batch_size):
            urls = self.urls[i : i + batch_size]
            try:
                response = requests.post(
                    self.external_url,
                    headers={
                        "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
                        "Authorization": f"Bearer {self.external_api_key}",
                    },
                    json={
                        "urls": urls,
                    },
                )
                response.raise_for_status()
                results = response.json()
                for result in results:
                    yield Document(
                        page_content=result.get("page_content", ""),
                        metadata=result.get("metadata", {}),
                    )
            except Exception as e:
                if self.continue_on_failure:
                    log.error(f"Error extracting content from batch {urls}: {e}")
                else:
                    raise e

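A hedged usage sketch of ExternalLoader (endpoint and key are illustrative; the external service is assumed to accept {"urls": [...]} and return a JSON list of {"page_content", "metadata"} objects, matching lazy_load() above):

loader = ExternalLoader(
    web_paths=["https://example.com/a", "https://example.com/b"],
    external_url="https://loader.internal/extract",  # hypothetical endpoint
    external_api_key="sk-example",                    # hypothetical key
)
docs = list(loader.lazy_load())  # requests are batched 20 URLs at a time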
backend/open_webui/retrieval/utils.py

@@ -207,7 +207,7 @@ def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:

    for distance, document, metadata in zip(distances, documents, metadatas):
        if isinstance(document, str):
-            doc_hash = hashlib.md5(
+            doc_hash = hashlib.sha256(
                document.encode()
            ).hexdigest()  # Compute a hash for uniqueness

@@ -260,23 +260,47 @@ def query_collection(
    k: int,
) -> dict:
    results = []
-    for query in queries:
-        log.debug(f"query_collection:query {query}")
-        query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
-        for collection_name in collection_names:
+    error = False
+
+    def process_query_collection(collection_name, query_embedding):
+        try:
            if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+                result = query_doc(
+                    collection_name=collection_name,
+                    k=k,
+                    query_embedding=query_embedding,
+                )
+                if result is not None:
+                    return result.model_dump(), None
+            return None, None
+        except Exception as e:
+            log.exception(f"Error when querying the collection: {e}")
+            return None, e
+
+    # Generate all query embeddings (in one call)
+    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    log.debug(
+        f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
+    )
+
+    with ThreadPoolExecutor() as executor:
+        future_results = []
+        for query_embedding in query_embeddings:
+            for collection_name in collection_names:
+                result = executor.submit(
+                    process_query_collection, collection_name, query_embedding
+                )
+                future_results.append(result)
+        task_results = [future.result() for future in future_results]
+
+    for result, err in task_results:
+        if err is not None:
+            error = True
+        elif result is not None:
+            results.append(result)
+
+    if error and not results:
+        log.warning("All collection queries failed. No results returned.")
+
    return merge_and_sort_query_results(results, k=k)

backend/open_webui/retrieval/vector/connector.py

@@ -20,6 +20,10 @@ elif VECTOR_DB == "elasticsearch":
    from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient

    VECTOR_DB_CLIENT = ElasticsearchClient()
elif VECTOR_DB == "pinecone":
    from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

    VECTOR_DB_CLIENT = PineconeClient()
else:
    from open_webui.retrieval.vector.dbs.chroma import ChromaClient

backend/open_webui/retrieval/vector/dbs/chroma.py

@@ -5,7 +5,12 @@ from chromadb.utils.batch_utils import create_batches

from typing import Optional

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
from open_webui.config import (
    CHROMA_DATA_PATH,
    CHROMA_HTTP_HOST,
@@ -23,7 +28,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


-class ChromaClient:
+class ChromaClient(VectorDBBase):
    def __init__(self):
        settings_dict = {
            "allow_reset": True,

backend/open_webui/retrieval/vector/dbs/elasticsearch.py

@@ -2,7 +2,12 @@ from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
@@ -15,7 +20,7 @@ from open_webui.config import (
)


-class ElasticsearchClient:
+class ElasticsearchClient(VectorDBBase):
    """
    Important:
    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating

backend/open_webui/retrieval/vector/dbs/milvus.py

@@ -4,7 +4,12 @@ import json
import logging
from typing import Optional

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
@@ -16,7 +21,7 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


-class MilvusClient:
+class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:

backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -2,7 +2,12 @@ from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
from open_webui.config import (
    OPENSEARCH_URI,
    OPENSEARCH_SSL,
@@ -12,7 +17,7 @@ from open_webui.config import (
)


-class OpenSearchClient:
+class OpenSearchClient(VectorDBBase):
    def __init__(self):
        self.index_prefix = "open_webui"
        self.client = OpenSearch(

backend/open_webui/retrieval/vector/dbs/pgvector.py

@@ -22,7 +22,12 @@ from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH

from open_webui.env import SRC_LOG_LEVELS
@@ -44,7 +49,7 @@ class DocumentChunk(Base):
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


-class PgvectorClient:
+class PgvectorClient(VectorDBBase):
    def __init__(self) -> None:
        # if no pgvector uri, use the existing database connection
@@ -136,9 +141,8 @@ class PgvectorClient:
            # Pad the vector with zeros
            vector += [0.0] * (VECTOR_LENGTH - current_length)
        elif current_length > VECTOR_LENGTH:
-            raise Exception(
-                f"Vector length {current_length} not supported. Max length must be <= {VECTOR_LENGTH}"
-            )
+            # Truncate the vector to VECTOR_LENGTH
+            vector = vector[:VECTOR_LENGTH]
        return vector

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:

backend/open_webui/retrieval/vector/dbs/pinecone.py (new file, 412 lines)
@@ -0,0 +1,412 @@
from typing import Optional, List, Dict, Any, Union
import logging
from pinecone import Pinecone, ServerlessSpec

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"

        # Validate required configuration
        self._validate_config()

        # Store configuration values
        self.api_key = PINECONE_API_KEY
        self.environment = PINECONE_ENVIRONMENT
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client
        self.client = Pinecone(api_key=self.api_key)

        # Create index if it doesn't exist
        self._initialize_index()

    def _validate_config(self) -> None:
        """Validate that all required configuration variables are set."""
        missing_vars = []
        if not PINECONE_API_KEY:
            missing_vars.append("PINECONE_API_KEY")
        if not PINECONE_ENVIRONMENT:
            missing_vars.append("PINECONE_ENVIRONMENT")
        if not PINECONE_INDEX_NAME:
            missing_vars.append("PINECONE_INDEX_NAME")
        if not PINECONE_DIMENSION:
            missing_vars.append("PINECONE_DIMENSION")
        if not PINECONE_CLOUD:
            missing_vars.append("PINECONE_CLOUD")

        if missing_vars:
            raise ValueError(
                f"Required configuration missing: {', '.join(missing_vars)}"
            )

    def _initialize_index(self) -> None:
        """Initialize the Pinecone index."""
        try:
            # Check if index exists
            if self.index_name not in self.client.list_indexes().names():
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)

        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Convert Pinecone matches to GetResult format."""
        ids = []
        documents = []
        metadatas = []

        for match in matches:
            metadata = match.get("metadata", {})
            ids.append(match["id"])
            documents.append(metadata.get("text", ""))
            metadatas.append(metadata)

        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Check if a collection exists by searching for at least one item."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Search for at least 1 item with this collection name in metadata
            response = self.index.query(
                vector=[0.0] * self.dimension,  # dummy vector
                top_k=1,
                filter={"collection_name": collection_name_with_prefix},
                include_metadata=False,
            )
            return len(response.matches) > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection by removing all vectors with the collection name in metadata."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        try:
            self.index.delete(filter={"collection_name": collection_name_with_prefix})
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert vectors into a collection."""
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Insert in batches for better performance and reliability
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Upsert (insert or update) vectors into a collection."""
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Upsert in batches
        for i in range(0, len(points), BATCH_SIZE):
            batch = points[i : i + BATCH_SIZE]
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Search using the first vector (assuming this is the intended behavior)
            query_vector = vectors[0]

            # Perform the search
            query_response = self.index.query(
                vector=query_vector,
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            if not query_response.matches:
                # Return empty result if no matches
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # Convert to GetResult format
            get_result = self._result_to_get_result(query_response.matches)

            # Calculate normalized distances based on metric
            distances = [
                [
                    self._normalize_distance(match.score)
                    for match in query_response.matches
                ]
            ]

            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                distances=distances,
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors by metadata filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        if limit is None or limit <= 0:
            limit = NO_LIMIT

        try:
            # Create a zero vector for the dimension as Pinecone requires a vector
            zero_vector = [0.0] * self.dimension

            # Combine user filter with collection_name
            pinecone_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                pinecone_filter.update(filter)

            # Perform metadata-only query
            query_response = self.index.query(
                vector=zero_vector,
                filter=pinecone_filter,
                top_k=limit,
                include_metadata=True,
            )

            return self._result_to_get_result(query_response.matches)

        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Get all vectors in a collection."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            # Use a zero vector for fetching all entries
            zero_vector = [0.0] * self.dimension

            # Add filter to only get vectors for this collection
            query_response = self.index.query(
                vector=zero_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            return self._result_to_get_result(query_response.matches)

        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by IDs or filter."""
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            if ids:
                # Delete by IDs (in batches for large deletions)
                for i in range(0, len(ids), BATCH_SIZE):
                    batch_ids = ids[i : i + BATCH_SIZE]
                    # Note: When deleting by ID, we can't filter by collection_name
                    # This is a limitation of Pinecone - be careful with ID uniqueness
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )

            elif filter:
                # Combine user filter with collection_name
                pinecone_filter = {"collection_name": collection_name_with_prefix}
                if filter:
                    pinecone_filter.update(filter)
                # Delete by metadata filter
                self.index.delete(filter=pinecone_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )

            else:
                log.warning("No ids or filter provided for delete operation")

        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Reset the database by deleting all collections."""
        try:
            self.index.delete(delete_all=True)
            log.info("All vectors successfully deleted from the index.")
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise

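Design note: rather than one Pinecone index per collection, the client keeps a single serverless index and scopes every vector with a collection_name metadata field, which is why has_collection(), query(), and get() all issue zero-vector queries with a metadata filter. A hedged usage sketch (environment configuration and item shape assumed):

client = PineconeClient()
client.upsert(
    "docs",
    [{"id": "doc-1", "vector": [0.1] * 1536, "text": "hello", "metadata": {"page": 1}}],
)
hits = client.search("docs", vectors=[[0.1] * 1536], limit=5)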
backend/open_webui/retrieval/vector/dbs/qdrant.py

@@ -1,12 +1,24 @@
from typing import Optional
import logging
from urllib.parse import urlparse

from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models

-from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
-from open_webui.config import QDRANT_URI, QDRANT_API_KEY
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
+from open_webui.config import (
+    QDRANT_URI,
+    QDRANT_API_KEY,
+    QDRANT_ON_DISK,
+    QDRANT_GRPC_PORT,
+    QDRANT_PREFER_GRPC,
+)
from open_webui.env import SRC_LOG_LEVELS

NO_LIMIT = 999999999
@@ -15,16 +27,34 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


-class QdrantClient:
+class QdrantClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
-        self.client = (
-            Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
-            if self.QDRANT_URI
-            else None
-        )
+        self.QDRANT_ON_DISK = QDRANT_ON_DISK
+        self.PREFER_GRPC = QDRANT_PREFER_GRPC
+        self.GRPC_PORT = QDRANT_GRPC_PORT
+
+        if not self.QDRANT_URI:
+            self.client = None
+            return
+
+        # Unified handling for either scheme
+        parsed = urlparse(self.QDRANT_URI)
+        host = parsed.hostname or self.QDRANT_URI
+        http_port = parsed.port or 6333  # default REST port
+
+        if self.PREFER_GRPC:
+            self.client = Qclient(
+                host=host,
+                port=http_port,
+                grpc_port=self.GRPC_PORT,
+                prefer_grpc=self.PREFER_GRPC,
+                api_key=self.QDRANT_API_KEY,
+            )
+        else:
+            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)

    def _result_to_get_result(self, points) -> GetResult:
        ids = []
@@ -50,7 +80,9 @@ class QdrantClient:
        self.client.create_collection(
            collection_name=collection_name_with_prefix,
            vectors_config=models.VectorParams(
-                size=dimension, distance=models.Distance.COSINE
+                size=dimension,
+                distance=models.Distance.COSINE,
+                on_disk=self.QDRANT_ON_DISK,
            ),
        )

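A hedged configuration sketch for the new Qdrant options (values illustrative): with QDRANT_PREFER_GRPC set, the URI is parsed only for host and REST port while traffic goes over QDRANT_GRPC_PORT, and QDRANT_ON_DISK stores new collections' vectors on disk.

# .env-style example (assumed values):
#   QDRANT_URI=http://qdrant.internal:6333
#   QDRANT_PREFER_GRPC=true
#   QDRANT_GRPC_PORT=6334
#   QDRANT_ON_DISK=true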
backend/open_webui/retrieval/vector/main.py

@@ -1,5 +1,6 @@
from pydantic import BaseModel
-from typing import Optional, List, Any
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union


class VectorItem(BaseModel):
@@ -17,3 +18,69 @@ class GetResult(BaseModel):

class SearchResult(GetResult):
    distances: Optional[List[List[float | int]]]


class VectorDBBase(ABC):
    """
    Abstract base class for all vector database backends.

    Implementations of this class provide methods for collection management,
    vector insertion, deletion, similarity search, and metadata filtering.

    Any custom vector database integration must inherit from this class and
    implement all abstract methods.
    """

    @abstractmethod
    def has_collection(self, collection_name: str) -> bool:
        """Check if the collection exists in the vector DB."""
        pass

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """Delete a collection from the vector DB."""
        pass

    @abstractmethod
    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert a list of vector items into a collection."""
        pass

    @abstractmethod
    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert or update vector items in a collection."""
        pass

    @abstractmethod
    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Search for similar vectors in a collection."""
        pass

    @abstractmethod
    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Query vectors from a collection using metadata filter."""
        pass

    @abstractmethod
    def get(self, collection_name: str) -> Optional[GetResult]:
        """Retrieve all vectors from a collection."""
        pass

    @abstractmethod
    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors by ID or filter from a collection."""
        pass

    @abstractmethod
    def reset(self) -> None:
        """Reset the vector database by removing all collections or those matching a condition."""
        pass

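A minimal sketch of a custom backend built on VectorDBBase (in-memory and illustrative only; items are treated as the dict-shaped VectorItem payloads the built-in clients use, and similarity search plus metadata filtering are deliberately left as no-ops):

class InMemoryClient(VectorDBBase):
    def __init__(self):
        self.collections: Dict[str, List[dict]] = {}

    def has_collection(self, collection_name: str) -> bool:
        return collection_name in self.collections

    def delete_collection(self, collection_name: str) -> None:
        self.collections.pop(collection_name, None)

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        self.collections.setdefault(collection_name, []).extend(items)

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        self.delete(collection_name, ids=[item["id"] for item in items])
        self.insert(collection_name, items)

    def search(self, collection_name, vectors, limit):
        return None  # similarity scoring omitted in this sketch

    def query(self, collection_name, filter, limit=None):
        return None  # metadata filtering omitted in this sketch

    def get(self, collection_name: str) -> Optional[GetResult]:
        items = self.collections.get(collection_name, [])
        return GetResult(
            ids=[[item["id"] for item in items]],
            documents=[[item.get("text", "") for item in items]],
            metadatas=[[item.get("metadata", {}) for item in items]],
        )

    def delete(self, collection_name, ids=None, filter=None):
        if ids:
            kept = [i for i in self.collections.get(collection_name, []) if i["id"] not in ids]
            self.collections[collection_name] = kept

    def reset(self) -> None:
        self.collections.clear()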
backend/open_webui/retrieval/web/external.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import logging
from typing import Optional, List

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_external(
    external_url: str,
    external_api_key: str,
    query: str,
    count: int,
    filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
    try:
        response = requests.post(
            external_url,
            headers={
                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
                "Authorization": f"Bearer {external_api_key}",
            },
            json={
                "query": query,
                "count": count,
            },
        )
        response.raise_for_status()
        results = response.json()
        if filter_list:
            results = get_filtered_results(results, filter_list)
        results = [
            SearchResult(
                link=result.get("link"),
                title=result.get("title"),
                snippet=result.get("snippet"),
            )
            for result in results[:count]
        ]
        log.info(f"External search results: {results}")
        return results
    except Exception as e:
        log.error(f"Error in External search: {e}")
        return []

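The request/response contract assumed by search_external() above, as a hedged sketch (the endpoint is operator-supplied, so this shape is a convention rather than a fixed API):

# expected reply: a JSON list like [{"link": ..., "title": ..., "snippet": ...}, ...]
results = search_external(
    external_url="https://search.internal/v1/search",  # hypothetical endpoint
    external_api_key="sk-example",                     # hypothetical key
    query="open webui rag",
    count=5,
)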
backend/open_webui/retrieval/web/firecrawl.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import logging
from typing import Optional, List
from urllib.parse import urljoin

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_firecrawl(
    firecrawl_url: str,
    firecrawl_api_key: str,
    query: str,
    count: int,
    filter_list: Optional[List[str]] = None,
) -> List[SearchResult]:
    try:
        firecrawl_search_url = urljoin(firecrawl_url, "/v1/search")
        response = requests.post(
            firecrawl_search_url,
            headers={
                "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
                "Authorization": f"Bearer {firecrawl_api_key}",
            },
            json={
                "query": query,
                "limit": count,
            },
        )
        response.raise_for_status()
        results = response.json().get("data", [])
        if filter_list:
            results = get_filtered_results(results, filter_list)
        results = [
            SearchResult(
                link=result.get("url"),
                title=result.get("title"),
                snippet=result.get("description"),
            )
            for result in results[:count]
        ]
        log.info(f"External search results: {results}")
        return results
    except Exception as e:
        log.error(f"Error in External search: {e}")
        return []

backend/open_webui/retrieval/web/tavily.py

@@ -2,7 +2,7 @@ import logging
from typing import Optional

import requests
-from open_webui.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
@@ -21,18 +21,25 @@ def search_tavily(
    Args:
        api_key (str): A Tavily Search API key
        query (str): The query to search for
        count (int): The maximum number of results to return

    Returns:
        list[SearchResult]: A list of search results
    """
    url = "https://api.tavily.com/search"
-    data = {"query": query, "api_key": api_key}
-    response = requests.post(url, json=data)
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+    }
+    data = {"query": query, "max_results": count}
+    response = requests.post(url, headers=headers, json=data)
    response.raise_for_status()

    json_response = response.json()

-    raw_search_results = json_response.get("results", [])
+    results = json_response.get("results", [])
+    if filter_list:
+        results = get_filtered_results(results, filter_list)

    return [
        SearchResult(
@@ -40,5 +47,5 @@ def search_tavily(
            title=result.get("title", ""),
            snippet=result.get("content"),
        )
-        for result in raw_search_results[:count]
+        for result in results
    ]

@@ -25,6 +25,7 @@ from langchain_community.document_loaders.firecrawl import FireCrawlLoader
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from open_webui.retrieval.loaders.tavily import TavilyLoader
from open_webui.retrieval.loaders.external import ExternalLoader
from open_webui.constants import ERROR_MESSAGES
from open_webui.config import (
    ENABLE_RAG_LOCAL_WEB_FETCH,
@@ -35,6 +36,8 @@ from open_webui.config import (
    FIRECRAWL_API_KEY,
    TAVILY_API_KEY,
    TAVILY_EXTRACT_DEPTH,
    EXTERNAL_WEB_LOADER_URL,
    EXTERNAL_WEB_LOADER_API_KEY,
)
from open_webui.env import SRC_LOG_LEVELS

@@ -167,7 +170,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
        continue_on_failure: bool = True,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Literal["crawl", "scrape", "map"] = "crawl",
        mode: Literal["crawl", "scrape", "map"] = "scrape",
        proxy: Optional[Dict[str, str]] = None,
        params: Optional[Dict] = None,
    ):
@@ -225,7 +228,10 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                    mode=self.mode,
                    params=self.params,
                )
                yield from loader.lazy_load()
                for document in loader.lazy_load():
                    if not document.metadata.get("source"):
                        document.metadata["source"] = document.metadata.get("sourceURL")
                    yield document
            except Exception as e:
                if self.continue_on_failure:
                    log.exception(f"Error loading {url}: {e}")
@@ -245,6 +251,8 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                    params=self.params,
                )
                async for document in loader.alazy_load():
                    if not document.metadata.get("source"):
                        document.metadata["source"] = document.metadata.get("sourceURL")
                    yield document
            except Exception as e:
                if self.continue_on_failure:
@@ -619,6 +627,11 @@ def get_web_loader(
        web_loader_args["api_key"] = TAVILY_API_KEY.value
        web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value

    if WEB_LOADER_ENGINE.value == "external":
        WebLoaderClass = ExternalLoader
        web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
        web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value

    if WebLoaderClass:
        web_loader = WebLoaderClass(**web_loader_args)
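The loader now backfills Document.metadata["source"] from FireCrawl's sourceURL so downstream citation code keeps working; the backfill in isolation:

# Minimal illustration of the per-document metadata backfill.
from langchain_core.documents import Document

doc = Document(page_content="...", metadata={"sourceURL": "https://example.com"})
if not doc.metadata.get("source"):
    doc.metadata["source"] = doc.metadata.get("sourceURL")
assert doc.metadata["source"] == "https://example.com"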
backend/open_webui/retrieval/web/yacy.py
Normal file
@@ -0,0 +1,85 @@
import logging
from typing import Optional

import requests
from requests.auth import HTTPDigestAuth
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_yacy(
    query_url: str,
    username: Optional[str],
    password: Optional[str],
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
    """
    Search a YaCy instance for a given query and return the results as a list of SearchResult objects.

    The function accepts a username and password for authenticating to YaCy.

    Args:
        query_url (str): The base URL of the YaCy server.
        username (str): Optional YaCy username.
        password (str): Optional YaCy password.
        query (str): The search term or question to find in the YaCy database.
        count (int): The maximum number of results to retrieve from the search.

    Returns:
        list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.

    Raises:
        requests.exceptions.RequestException: If a request error occurs during the search process.
    """

    # Use authentication if either username or password is set
    yacy_auth = None
    if username or password:
        yacy_auth = HTTPDigestAuth(username, password)

    params = {
        "query": query,
        "contentdom": "text",
        "resource": "global",
        "maximumRecords": count,
        "nav": "none",
    }

    # Check whether a JSON API URL was provided
    if not query_url.endswith("yacysearch.json"):
        # Normalize the base URL and append the JSON search endpoint
        query_url = query_url.rstrip("/") + "/yacysearch.json"

    log.debug(f"searching {query_url}")

    response = requests.get(
        query_url,
        auth=yacy_auth,
        headers={
            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
            "Accept": "text/html",
            "Accept-Encoding": "gzip, deflate",
            "Accept-Language": "en-US,en;q=0.5",
            "Connection": "keep-alive",
        },
        params=params,
    )

    response.raise_for_status()  # Raise an exception for HTTP errors.

    json_response = response.json()
    results = json_response.get("channels", [{}])[0].get("items", [])
    sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
    if filter_list:
        sorted_results = get_filtered_results(sorted_results, filter_list)
    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("description"),
        )
        for result in sorted_results[:count]
    ]
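A hypothetical call against a local YaCy peer; the URL and credentials are placeholders, and digest auth is only attached when credentials are set:

# Illustrative usage of search_yacy; "/yacysearch.json" is appended for us.
from open_webui.retrieval.web.yacy import search_yacy

results = search_yacy(
    "http://localhost:8090",  # placeholder peer URL
    username="admin",          # optional
    password="secret",         # optional
    query="federated search",
    count=10,
)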
@@ -150,7 +150,8 @@ class STTConfigForm(BaseModel):
    AZURE_API_KEY: str
    AZURE_REGION: str
    AZURE_LOCALES: str

    AZURE_BASE_URL: str
    AZURE_MAX_SPEAKERS: str

class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
@@ -181,6 +182,8 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }

@@ -210,6 +213,8 @@ async def update_audio_config(
    request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
    request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
    request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
    request.app.state.config.AUDIO_STT_AZURE_BASE_URL = form_data.stt.AZURE_BASE_URL
    request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = form_data.stt.AZURE_MAX_SPEAKERS

    if request.app.state.config.STT_ENGINE == "":
        request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -238,6 +243,8 @@ async def update_audio_config(
            "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
            "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
            "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
            "AZURE_BASE_URL": request.app.state.config.AUDIO_STT_AZURE_BASE_URL,
            "AZURE_MAX_SPEAKERS": request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS,
        },
    }

@@ -641,6 +648,8 @@ def transcribe(request: Request, file_path):
        api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
        region = request.app.state.config.AUDIO_STT_AZURE_REGION
        locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
        base_url = request.app.state.config.AUDIO_STT_AZURE_BASE_URL
        max_speakers = request.app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS

        # IF NO LOCALES, USE DEFAULTS
        if len(locales) < 2:
@@ -664,7 +673,13 @@ def transcribe(request: Request, file_path):
        if not api_key or not region:
            raise HTTPException(
                status_code=400,
                detail="Azure API key and region are required for Azure STT",
                detail="Azure API key is required for Azure STT",
            )

        if not base_url and not region:
            raise HTTPException(
                status_code=400,
                detail="Azure region or base url is required for Azure STT",
            )

        r = None
@@ -674,13 +689,14 @@ def transcribe(request: Request, file_path):
            "definition": json.dumps(
                {
                    "locales": locales.split(","),
                    "diarization": {"maxSpeakers": 3, "enabled": True},
                    "diarization": {"maxSpeakers": max_speakers, "enabled": True},
                }
                if locales
                else {}
            )
        }
        url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

        url = base_url or f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"

        # Use context manager to ensure file is properly closed
        with open(file_path, "rb") as audio_file:
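When AZURE_BASE_URL is set it now takes precedence over the region-derived endpoint; the selection reduces to a simple fallback (values here are placeholders):

# base_url wins when configured; otherwise the regional endpoint is built.
region = "westeurope"
base_url = None  # e.g. "https://my-private-endpoint.example.com"
url = (
    base_url
    or f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
)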
@@ -27,20 +27,24 @@ from open_webui.env import (
    WEBUI_AUTH_TRUSTED_NAME_HEADER,
    WEBUI_AUTH_COOKIE_SAME_SITE,
    WEBUI_AUTH_COOKIE_SECURE,
    WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
    SRC_LOG_LEVELS,
)
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse, Response
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
from pydantic import BaseModel

from open_webui.utils.misc import parse_duration, validate_email_format
from open_webui.utils.auth import (
    decode_token,
    create_api_key,
    create_token,
    get_admin_user,
    get_verified_user,
    get_current_user,
    get_password_hash,
    get_http_authorization_cred,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions
@@ -72,27 +76,29 @@ class SessionUserResponse(Token, UserResponse):
async def get_session_user(
    request: Request, response: Response, user=Depends(get_current_user)
):
    expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
    expires_at = None
    if expires_delta:
        expires_at = int(time.time()) + int(expires_delta.total_seconds())

    token = create_token(
        data={"id": user.id},
        expires_delta=expires_delta,
    )
    auth_header = request.headers.get("Authorization")
    auth_token = get_http_authorization_cred(auth_header)
    token = auth_token.credentials
    data = decode_token(token)

    datetime_expires_at = (
        datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
        if expires_at
        else None
    )
    expires_at = data.get("exp")

    if (expires_at is not None) and int(time.time()) > expires_at:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Set the cookie token
    response.set_cookie(
        key="token",
        value=token,
        expires=datetime_expires_at,
        expires=(
            datetime.datetime.fromtimestamp(expires_at, datetime.timezone.utc)
            if expires_at
            else None
        ),
        httponly=True,  # Ensures the cookie is not accessible via JavaScript
        samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
        secure=WEBUI_AUTH_COOKIE_SECURE,
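The route now trusts the expiry embedded in the presented JWT instead of minting a fresh token. A reduced sketch of that check, assuming the payload shape minted by create_token:

# The decoded payload carries "exp"; a stale token is rejected up front.
import time

data = {"id": "user-123", "exp": int(time.time()) + 3600}
expires_at = data.get("exp")
if (expires_at is not None) and int(time.time()) > expires_at:
    raise RuntimeError("token expired")  # the route raises HTTP 401 here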
@@ -288,18 +294,30 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
        user = Auths.authenticate_user_by_trusted_header(email)

        if user:
            expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
            expires_at = None
            if expires_delta:
                expires_at = int(time.time()) + int(expires_delta.total_seconds())

            token = create_token(
                data={"id": user.id},
                expires_delta=parse_duration(
                    request.app.state.config.JWT_EXPIRES_IN
                ),
                expires_delta=expires_delta,
            )

            # Set the cookie token
            response.set_cookie(
                key="token",
                value=token,
                expires=(
                    datetime.datetime.fromtimestamp(
                        expires_at, datetime.timezone.utc
                    )
                    if expires_at
                    else None
                ),
                httponly=True,  # Ensures the cookie is not accessible via JavaScript
                samesite=WEBUI_AUTH_COOKIE_SAME_SITE,
                secure=WEBUI_AUTH_COOKIE_SECURE,
            )

            user_permissions = get_permissions(
@@ -309,6 +327,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
            return {
                "token": token,
                "token_type": "Bearer",
                "expires_at": expires_at,
                "id": user.id,
                "email": user.email,
                "name": user.name,
@@ -566,6 +585,12 @@ async def signout(request: Request, response: Response):
                    detail="Failed to sign out from the OpenID provider.",
                )

    if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
        return RedirectResponse(
            headers=response.headers,
            url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
        )

    return {"status": True}


@@ -664,6 +689,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
        "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
        "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
        "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
        "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
    }

@@ -680,6 +706,7 @@ class AdminConfig(BaseModel):
    ENABLE_COMMUNITY_SHARING: bool
    ENABLE_MESSAGE_RATING: bool
    ENABLE_CHANNELS: bool
    ENABLE_NOTES: bool
    ENABLE_USER_WEBHOOKS: bool


@@ -700,6 +727,7 @@ async def update_admin_config(
    )

    request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS
    request.app.state.config.ENABLE_NOTES = form_data.ENABLE_NOTES

    if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
        request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -724,11 +752,12 @@ async def update_admin_config(
        "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
        "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
        "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
        "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
        "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
        "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
        "ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
        "ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
    }
@@ -638,8 +638,17 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):


@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
    if not has_permission(
        user.id, "chat.share", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    chat = Chats.get_chat_by_id_and_user_id(id, user.id)

    if chat:
        if chat.share_id:
            shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
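Sharing is now gated on the dotted permission key "chat.share". A sketch of how such a dotted lookup can resolve against nested permission dicts (illustrative helper, not the project's has_permission):

# Hypothetical resolver for dotted permission keys over nested dicts.
permissions = {"chat": {"share": True, "export": False}}

def has_nested_permission(key: str, perms: dict) -> bool:
    node = perms
    for part in key.split("."):
        node = node.get(part, {}) if isinstance(node, dict) else {}
    return node is True

assert has_nested_permission("chat.share", permissions)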
@@ -19,6 +19,8 @@ from fastapi import (
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS

from open_webui.models.users import Users
from open_webui.models.files import (
    FileForm,
    FileModel,
@@ -83,10 +85,12 @@ def upload_file(
    request: Request,
    file: UploadFile = File(...),
    user=Depends(get_verified_user),
    file_metadata: dict = {},
    file_metadata: dict = None,
    process: bool = Query(True),
):
    log.info(f"file.content_type: {file.content_type}")

    file_metadata = file_metadata if file_metadata else {}
    try:
        unsanitized_filename = file.filename
        filename = os.path.basename(unsanitized_filename)
@@ -95,7 +99,13 @@ def upload_file(
        id = str(uuid.uuid4())
        name = filename
        filename = f"{id}_{filename}"
        contents, file_path = Storage.upload_file(file.file, filename)
        tags = {
            "OpenWebUI-User-Email": user.email,
            "OpenWebUI-User-Id": user.id,
            "OpenWebUI-User-Name": user.name,
            "OpenWebUI-File-Id": id,
        }
        contents, file_path = Storage.upload_file(file.file, filename, tags)

        file_item = Files.insert_new_file(
            user.id,
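Every storage backend now receives per-file tags alongside the filename. A minimal conforming provider stub under that widened signature (hypothetical class, not part of the codebase); backends with no native tagging can simply ignore the argument:

# Illustrative provider honoring the new three-argument contract.
from typing import BinaryIO, Dict, Tuple

class NullStorageProvider:
    def upload_file(
        self, file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        contents = file.read()  # tags intentionally unused here
        return contents, f"/dev/null/{filename}"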
@@ -129,7 +139,15 @@ def upload_file(
                ProcessFileForm(file_id=id, content=result.get("text", "")),
                user=user,
            )
        elif file.content_type not in ["image/png", "image/jpeg", "image/gif"]:
        elif file.content_type not in [
            "image/png",
            "image/jpeg",
            "image/gif",
            "video/mp4",
            "video/ogg",
            "video/quicktime",
            "video/webm",
        ]:
            process_file(request, ProcessFileForm(file_id=id), user=user)

        file_item = Files.get_file_by_id(id=id)
@@ -173,7 +191,8 @@ async def list_files(user=Depends(get_verified_user), content: bool = Query(True

    if not content:
        for file in files:
            del file.data["content"]
            if "content" in file.data:
                del file.data["content"]

    return files

@@ -214,7 +233,8 @@ async def search_files(

    if not content:
        for file in matching_files:
            del file.data["content"]
            if "content" in file.data:
                del file.data["content"]

    return matching_files

@@ -431,6 +451,13 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    file_user = Users.get_user_by_id(file.user_id)
    if not file_user.role == "admin":
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if (
        file.user_id == user.id
        or user.role == "admin"
@@ -500,7 +500,11 @@ async def image_generations(
            if form_data.size
            else request.app.state.config.IMAGE_SIZE
        ),
        "response_format": "b64_json",
        **(
            {"response_format": "b64_json"}
            if "gpt-image-1" in request.app.state.config.IMAGE_GENERATION_MODEL
            else {}
        ),
    }

    # Use asyncio.to_thread for the requests.post call
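The conditional dict spread keeps "response_format" out of the payload for models that reject it; reduced to its essentials:

# The key only appears when the model name matches.
model = "gpt-image-1"
payload = {
    "prompt": "a lighthouse at dusk",
    **({"response_format": "b64_json"} if "gpt-image-1" in model else {}),
}
assert "response_format" in payload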
@@ -10,7 +10,7 @@ from open_webui.models.knowledge import (
    KnowledgeUserResponse,
    RAGConfigForm
)
from open_webui.models.files import Files, FileModel
from open_webui.models.files import Files, FileModel, FileMetadataResponse
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import (
    process_file,
@@ -179,10 +179,26 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us

    log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")

    for knowledge_base in knowledge_bases:
        try:
            files = Files.get_files_by_ids(knowledge_base.data.get("file_ids", []))
    deleted_knowledge_bases = []

    for knowledge_base in knowledge_bases:
        # -- Robust error handling for missing or invalid data
        if not knowledge_base.data or not isinstance(knowledge_base.data, dict):
            log.warning(
                f"Knowledge base {knowledge_base.id} has no data or invalid data ({knowledge_base.data!r}). Deleting."
            )
            try:
                Knowledges.delete_knowledge_by_id(id=knowledge_base.id)
                deleted_knowledge_bases.append(knowledge_base.id)
            except Exception as e:
                log.error(
                    f"Failed to delete invalid knowledge base {knowledge_base.id}: {e}"
                )
            continue

        try:
            file_ids = knowledge_base.data.get("file_ids", [])
            files = Files.get_files_by_ids(file_ids)
            try:
                if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
                    VECTOR_DB_CLIENT.delete_collection(
@@ -190,10 +206,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
                    )
            except Exception as e:
                log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Error deleting vector DB collection",
                )
                continue  # Skip, don't raise

            failed_files = []
            for file in files:
@@ -214,10 +227,8 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us

        except Exception as e:
            log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Error processing knowledge base",
            )
            # Don't raise, just continue
            continue

        if failed_files:
            log.warning(
@@ -226,7 +237,9 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us
        for failed in failed_files:
            log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")

    log.info("Reindexing completed successfully")
    log.info(
        f"Reindexing completed. Deleted {len(deleted_knowledge_bases)} invalid knowledge bases: {deleted_knowledge_bases}"
    )
    return True
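The reindex loop now validates each knowledge base and deletes unrecoverable entries rather than aborting the whole run; the guard pattern in miniature (plain dicts stand in for the ORM models):

# Invalid entries are collected and skipped instead of raising.
knowledge_bases = [{"id": "kb1", "data": None}, {"id": "kb2", "data": {"file_ids": []}}]
deleted = []
for kb in knowledge_bases:
    if not kb["data"] or not isinstance(kb["data"], dict):
        deleted.append(kb["id"])
        continue
    # ...reindex kb["data"]["file_ids"] here...
assert deleted == ["kb1"]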
@@ -236,7 +249,7 @@ async def reindex_knowledge_files(request: Request, user=Depends(get_verified_us


class KnowledgeFilesResponse(KnowledgeResponse):
    files: list[FileModel]
    files: list[FileMetadataResponse]


@router.get("/{id}", response_model=Optional[KnowledgeFilesResponse])
@@ -252,7 +265,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
    ):

        file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
        files = Files.get_files_by_ids(file_ids)
        files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -380,7 +393,7 @@ def add_file_to_knowledge_by_id(
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    if knowledge:
        files = Files.get_files_by_ids(file_ids)
        files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -457,7 +470,7 @@ def update_file_from_knowledge_by_id(
    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

    files = Files.get_files_by_ids(file_ids)
    files = Files.get_file_metadatas_by_ids(file_ids)

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
@@ -539,7 +552,7 @@ def remove_file_from_knowledge_by_id(
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    if knowledge:
        files = Files.get_files_by_ids(file_ids)
        files = Files.get_file_metadatas_by_ids(file_ids)

        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
@@ -735,7 +748,7 @@ def add_files_to_knowledge_batch(
        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
            files=Files.get_files_by_ids(existing_file_ids),
            files=Files.get_file_metadatas_by_ids(existing_file_ids),
            warnings={
                "message": "Some files failed to process",
                "errors": error_details,
@@ -743,5 +756,6 @@ def add_files_to_knowledge_batch(
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
        **knowledge.model_dump(),
        files=Files.get_file_metadatas_by_ids(existing_file_ids),
    )
@@ -54,6 +54,7 @@ from open_webui.config import (
from open_webui.env import (
    ENV,
    SRC_LOG_LEVELS,
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
@@ -91,6 +92,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
                    else {}
                ),
            },
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as response:
            return await response.json()
    except Exception as e:
@@ -141,6 +143,7 @@ async def send_post_request(
                else {}
            ),
        },
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    )
    r.raise_for_status()

@@ -216,7 +219,8 @@ async def verify_connection(
    key = form_data.key

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
@@ -234,6 +238,7 @@ async def verify_connection(
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    detail = f"HTTP Error: {r.status}"
@@ -1006,7 +1011,7 @@ class GenerateCompletionForm(BaseModel):
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[str] = None
    format: Optional[Union[dict, str]] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
@@ -1482,7 +1487,9 @@ async def download_file_stream(
    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
        async with session.get(
            file_url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
        ) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
@@ -1497,7 +1504,8 @@ async def download_file_stream(

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    chunk_size = 1024 * 1024 * 2
                    hashed = calculate_sha256(file, chunk_size)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
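calculate_sha256 is now fed a chunk size so large blobs are hashed incrementally rather than read into memory at once; a sketch of the assumed contract:

# Reading in 2 MiB chunks bounds memory regardless of file size.
import hashlib

def sha256_of(file, chunk_size=1024 * 1024 * 2):
    h = hashlib.sha256()
    while chunk := file.read(chunk_size):
        h.update(chunk)
    return h.hexdigest()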
@@ -21,6 +21,7 @@ from open_webui.config import (
    CACHE_DIR,
)
from open_webui.env import (
    AIOHTTP_CLIENT_SESSION_SSL,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
    ENABLE_FORWARD_USER_INFO_HEADERS,
@@ -74,6 +75,7 @@ async def send_get_request(url, key=None, user: UserModel = None):
                    else {}
                ),
            },
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as response:
            return await response.json()
    except Exception as e:
@@ -92,20 +94,19 @@ async def cleanup_response(
        await session.close()


def openai_o1_o3_handler(payload):
def openai_o_series_handler(payload):
    """
    Handle o1, o3 specific parameters
    Handle "o" series specific parameters
    """
    if "max_tokens" in payload:
        # Remove "max_tokens" from the payload
        # Convert "max_tokens" to "max_completion_tokens" for all o-series models
        payload["max_completion_tokens"] = payload["max_tokens"]
        del payload["max_tokens"]

    # Fix: o1 and o3 do not support the "system" role directly.
    # For older models like "o1-mini" or "o1-preview", use role "user".
    # For newer o1/o3 models, replace "system" with "developer".
    # Handle system role conversion based on model type
    if payload["messages"][0]["role"] == "system":
        model_lower = payload["model"].lower()
        # Legacy models use "user" role instead of "system"
        if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
            payload["messages"][0]["role"] = "user"
        else:
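Effect of the handler on a representative payload (the model id is an example, and the truncated else branch above is assumed to assign the "developer" role, as the replaced comment described):

# Illustrative before/after for openai_o_series_handler.
payload = {
    "model": "o3-mini",
    "max_tokens": 256,
    "messages": [{"role": "system", "content": "Be terse."}],
}
payload = openai_o_series_handler(payload)
assert payload["max_completion_tokens"] == 256 and "max_tokens" not in payload
assert payload["messages"][0]["role"] == "developer"  # assumed else-branch behavior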
@@ -462,7 +463,8 @@ async def get_models(

    r = None
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
@@ -481,6 +483,7 @@ async def get_models(
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    # Extract response error details if available
@@ -542,7 +545,8 @@ async def verify_connection(
    key = form_data.key

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
    ) as session:
        try:
            async with session.get(
@@ -561,6 +565,7 @@ async def verify_connection(
                        else {}
                    ),
                },
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                if r.status != 200:
                    # Extract response error details if available
@@ -666,10 +671,10 @@ async def generate_chat_completion(
    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    # Fix: o1, o3 do not support the "max_tokens" parameter; modify "max_tokens" to "max_completion_tokens"
    is_o1_o3 = payload["model"].lower().startswith(("o1", "o3-"))
    if is_o1_o3:
        payload = openai_o1_o3_handler(payload)
    # Check if the model is from the "o" series
    is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4"))
    if is_o_series:
        payload = openai_o_series_handler(payload)
    elif "api.openai.com" not in url:
        # Remove "max_completion_tokens" from the payload for backward compatibility
        if "max_completion_tokens" in payload:
@@ -723,6 +728,7 @@ async def generate_chat_completion(
                else {}
            ),
        },
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    )

    # Check if response is SSE
@@ -802,6 +808,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                else {}
            ),
        },
        ssl=AIOHTTP_CLIENT_SESSION_SSL,
    )
    r.raise_for_status()
@@ -66,7 +66,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
        if "pipeline" in model:
            sorted_filters.append(model)

    async with aiohttp.ClientSession() as session:
    async with aiohttp.ClientSession(trust_env=True) as session:
        for filter in sorted_filters:
            urlIdx = filter.get("urlIdx")
            if urlIdx is None:
@@ -115,7 +115,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
        if "pipeline" in model:
            sorted_filters = [model] + sorted_filters

    async with aiohttp.ClientSession() as session:
    async with aiohttp.ClientSession(trust_env=True) as session:
        for filter in sorted_filters:
            urlIdx = filter.get("urlIdx")
            if urlIdx is None:
@@ -53,6 +53,7 @@ from open_webui.retrieval.web.jina_search import search_jina
from open_webui.retrieval.web.searchapi import search_searchapi
from open_webui.retrieval.web.serpapi import search_serpapi
from open_webui.retrieval.web.searxng import search_searxng
from open_webui.retrieval.web.yacy import search_yacy
from open_webui.retrieval.web.serper import search_serper
from open_webui.retrieval.web.serply import search_serply
from open_webui.retrieval.web.serpstack import search_serpstack
@@ -61,6 +62,8 @@ from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity
from open_webui.retrieval.web.sougou import search_sougou
from open_webui.retrieval.web.firecrawl import search_firecrawl
from open_webui.retrieval.web.external import search_external

from open_webui.retrieval.utils import (
    get_embedding_function,
@@ -90,7 +93,12 @@ from open_webui.env import (
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    DOCKER,
    SENTENCE_TRANSFORMERS_BACKEND,
    SENTENCE_TRANSFORMERS_MODEL_KWARGS,
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
    SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)

from open_webui.constants import ERROR_MESSAGES

log = logging.getLogger(__name__)
@@ -117,6 +125,8 @@ def get_ef(
            get_model_path(embedding_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
            backend=SENTENCE_TRANSFORMERS_BACKEND,
            model_kwargs=SENTENCE_TRANSFORMERS_MODEL_KWARGS,
        )
    except Exception as e:
        log.debug(f"Error loading SentenceTransformer: {e}")
@@ -150,6 +160,8 @@ def get_rf(
            get_model_path(reranking_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
            backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
            model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
        )
    except Exception as e:
        log.error(f"CrossEncoder: {e}")
@@ -460,6 +472,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
            "WEB_SEARCH_DOMAIN_FILTER_LIST": rag_config.get("web_search_domain_filter_list", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": rag_config.get("bypass_web_search_embedding_and_retrieval", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
            "SEARXNG_QUERY_URL": rag_config.get("searxng_query_url", request.app.state.config.SEARXNG_QUERY_URL),
            "YACY_QUERY_URL": rag_config.get("yacy_query_url", request.app.state.config.YACY_QUERY_URL),
            "YACY_USERNAME": rag_config.get("yacy_query_username", request.app.state.config.YACY_USERNAME),
            "YACY_PASSWORD": rag_config.get("yacy_query_password", request.app.state.config.YACY_PASSWORD),
            "GOOGLE_PSE_API_KEY": rag_config.get("google_pse_api_key", request.app.state.config.GOOGLE_PSE_API_KEY),
            "GOOGLE_PSE_ENGINE_ID": rag_config.get("google_pse_engine_id", request.app.state.config.GOOGLE_PSE_ENGINE_ID),
            "BRAVE_SEARCH_API_KEY": rag_config.get("brave_search_api_key", request.app.state.config.BRAVE_SEARCH_API_KEY),
@@ -489,6 +504,10 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
            "FIRECRAWL_API_KEY": rag_config.get("firecrawl_api_key", request.app.state.config.FIRECRAWL_API_KEY),
            "FIRECRAWL_API_BASE_URL": rag_config.get("firecrawl_api_base_url", request.app.state.config.FIRECRAWL_API_BASE_URL),
            "TAVILY_EXTRACT_DEPTH": rag_config.get("tavily_extract_depth", request.app.state.config.TAVILY_EXTRACT_DEPTH),
            "EXTERNAL_WEB_SEARCH_URL": rag_config.get("web_search_url", request.app.state.config.EXTERNAL_WEB_SEARCH_URL),
            "EXTERNAL_WEB_SEARCH_API_KEY": rag_config.get("web_search_key", request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY),
            "EXTERNAL_WEB_LOADER_URL": rag_config.get("web_loader_url", request.app.state.config.EXTERNAL_WEB_LOADER_URL),
            "EXTERNAL_WEB_LOADER_API_KEY": rag_config.get("web_loader_key", request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY),
            "YOUTUBE_LOADER_LANGUAGE": rag_config.get("youtube_loader_language", request.app.state.config.YOUTUBE_LOADER_LANGUAGE),
            "YOUTUBE_LOADER_PROXY_URL": rag_config.get("youtube_loader_proxy_url", request.app.state.config.YOUTUBE_LOADER_PROXY_URL),
            "YOUTUBE_LOADER_TRANSLATION": rag_config.get("youtube_loader_translation", request.app.state.config.YOUTUBE_LOADER_TRANSLATION),
@@ -535,6 +554,9 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
        "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
        "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
        "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
        "YACY_USERNAME": request.app.state.config.YACY_USERNAME,
        "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
        "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
        "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
        "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
@@ -564,6 +586,10 @@ async def get_rag_config(request: Request, collectionForm: CollectionNameForm, u
        "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
        "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
        "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
        "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
        "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
        "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
        "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
        "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
        "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        "YOUTUBE_LOADER_TRANSLATION": request.app.state.config.YOUTUBE_LOADER_TRANSLATION,
@@ -580,6 +606,9 @@ class WebConfig(BaseModel):
    WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    SEARXNG_QUERY_URL: Optional[str] = None
    YACY_QUERY_URL: Optional[str] = None
    YACY_USERNAME: Optional[str] = None
    YACY_PASSWORD: Optional[str] = None
    GOOGLE_PSE_API_KEY: Optional[str] = None
    GOOGLE_PSE_ENGINE_ID: Optional[str] = None
    BRAVE_SEARCH_API_KEY: Optional[str] = None
@@ -609,6 +638,10 @@ class WebConfig(BaseModel):
    FIRECRAWL_API_KEY: Optional[str] = None
    FIRECRAWL_API_BASE_URL: Optional[str] = None
    TAVILY_EXTRACT_DEPTH: Optional[str] = None
    EXTERNAL_WEB_SEARCH_URL: Optional[str] = None
    EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None
    EXTERNAL_WEB_LOADER_URL: Optional[str] = None
    EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None
    YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None
    YOUTUBE_LOADER_PROXY_URL: Optional[str] = None
    YOUTUBE_LOADER_TRANSLATION: Optional[str] = None
@@ -668,9 +701,9 @@ async def update_rag_config(
        rag_config = knowledge_base.data.get("rag_config", {})

        # Update only the provided fields in the rag_config
        for field, value in form_data.dict(exclude_unset=True).items():
        for field, value in form_data.model_dump(exclude_unset=True).items():
            if field == "web" and value is not None:
                rag_config["web"] = {**rag_config.get("web", {}), **value.dict(exclude_unset=True)}
                rag_config["web"] = {**rag_config.get("web", {}), **value.model_dump(exclude_unset=True)}
            else:
                rag_config[field] = value

@@ -709,6 +742,7 @@ async def update_rag_config(
        if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
        else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
    )
    # Free up memory if hybrid search is disabled
    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
        request.app.state.rf = None

@@ -821,6 +855,9 @@ async def update_rag_config(
            form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
        )
        request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
        request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
        request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
        request.app.state.config.YACY_PASSWORD = form_data.web.YACY_PASSWORD
        request.app.state.config.GOOGLE_PSE_API_KEY = form_data.web.GOOGLE_PSE_API_KEY
        request.app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.GOOGLE_PSE_ENGINE_ID
@@ -867,6 +904,18 @@ async def update_rag_config(
        request.app.state.config.FIRECRAWL_API_BASE_URL = (
            form_data.web.FIRECRAWL_API_BASE_URL
        )
        request.app.state.config.EXTERNAL_WEB_SEARCH_URL = (
            form_data.web.EXTERNAL_WEB_SEARCH_URL
        )
        request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY = (
            form_data.web.EXTERNAL_WEB_SEARCH_API_KEY
        )
        request.app.state.config.EXTERNAL_WEB_LOADER_URL = (
            form_data.web.EXTERNAL_WEB_LOADER_URL
        )
        request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY = (
            form_data.web.EXTERNAL_WEB_LOADER_API_KEY
        )
        request.app.state.config.TAVILY_EXTRACT_DEPTH = (
            form_data.web.TAVILY_EXTRACT_DEPTH
        )
@@ -919,7 +968,10 @@ async def update_rag_config(
        "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
        "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
        "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
        "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
        "YACY_USERNAME": request.app.state.config.YACY_USERNAME,
        "YACY_PASSWORD": request.app.state.config.YACY_PASSWORD,
        "GOOGLE_PSE_API_KEY": request.app.state.config.GOOGLE_PSE_API_KEY,
        "GOOGLE_PSE_ENGINE_ID": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
        "BRAVE_SEARCH_API_KEY": request.app.state.config.BRAVE_SEARCH_API_KEY,
        "KAGI_SEARCH_API_KEY": request.app.state.config.KAGI_SEARCH_API_KEY,
@@ -948,7 +1000,11 @@ async def update_rag_config(
        "FIRECRAWL_API_KEY": request.app.state.config.FIRECRAWL_API_KEY,
        "FIRECRAWL_API_BASE_URL": request.app.state.config.FIRECRAWL_API_BASE_URL,
        "TAVILY_EXTRACT_DEPTH": request.app.state.config.TAVILY_EXTRACT_DEPTH,
        "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
        "EXTERNAL_WEB_SEARCH_URL": request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
        "EXTERNAL_WEB_SEARCH_API_KEY": request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
        "EXTERNAL_WEB_LOADER_URL": request.app.state.config.EXTERNAL_WEB_LOADER_URL,
        "EXTERNAL_WEB_LOADER_API_KEY": request.app.state.config.EXTERNAL_WEB_LOADER_API_KEY,
        "YOUTUBE_LOADER_LANGUAGE": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
        "YOUTUBE_LOADER_PROXY_URL": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        "YOUTUBE_LOADER_TRANSLATION": request.app.state.config.YOUTUBE_LOADER_TRANSLATION,
    },
@@ -1491,6 +1547,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    """Search the web using a search engine and return the results as a list of SearchResult objects.
    Will look for a search engine API key in environment variables in the following order:
    - SEARXNG_QUERY_URL
    - YACY_QUERY_URL + YACY_USERNAME + YACY_PASSWORD
    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - BRAVE_SEARCH_API_KEY
    - KAGI_SEARCH_API_KEY
@@ -1520,6 +1577,18 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
            )
        else:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "yacy":
        if request.app.state.config.YACY_QUERY_URL:
            return search_yacy(
                request.app.state.config.YACY_QUERY_URL,
                request.app.state.config.YACY_USERNAME,
                request.app.state.config.YACY_PASSWORD,
                query,
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No YACY_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if (
            request.app.state.config.GOOGLE_PSE_API_KEY
@@ -1690,6 +1759,22 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
        raise Exception(
            "No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables"
        )
    elif engine == "firecrawl":
        return search_firecrawl(
            request.app.state.config.FIRECRAWL_API_BASE_URL,
            request.app.state.config.FIRECRAWL_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "external":
        return search_external(
            request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
            request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    else:
        raise Exception("No search engine API key found in environment variables")
@@ -1702,8 +1787,11 @@ async def process_web_search(
        logging.info(
            f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.query}"
        )
        web_results = search_web(
            request, request.app.state.config.WEB_SEARCH_ENGINE, form_data.query
        web_results = await run_in_threadpool(
            search_web,
            request,
            request.app.state.config.WEB_SEARCH_ENGINE,
            form_data.query,
        )
    except Exception as e:
        log.exception(e)
@@ -1725,8 +1813,8 @@ async def process_web_search(
    )
    docs = await loader.aload()
    urls = [
        doc.metadata["source"] for doc in docs
    ]  # only keep URLs which could be retrieved
        doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
    ]  # only keep URLs

    if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
        return {
@@ -1746,19 +1834,22 @@ async def process_web_search(
    collection_names = []
    for doc_idx, doc in enumerate(docs):
        if doc and doc.page_content:
            collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
                :63
            ]
            try:
                collection_name = f"web-search-{calculate_sha256_string(form_data.query + '-' + urls[doc_idx])}"[
                    :63
                ]

            collection_names.append(collection_name)
            await run_in_threadpool(
                save_docs_to_vector_db,
                request,
                [doc],
                collection_name,
                overwrite=True,
                user=user,
            )
                collection_names.append(collection_name)
                await run_in_threadpool(
                    save_docs_to_vector_db,
                    request,
                    [doc],
                    collection_name,
                    overwrite=True,
                    user=user,
                )
            except Exception as e:
                log.debug(f"error saving doc {doc_idx}: {e}")

    return {
        "status": True,
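search_web is synchronous and built on blocking requests calls, so hopping through Starlette's run_in_threadpool keeps the event loop responsive while the search runs; a reduced sketch:

# Blocking work is pushed to a worker thread and awaited.
from starlette.concurrency import run_in_threadpool

def blocking_search(query: str):
    import time
    time.sleep(0.1)  # stand-in for a blocking HTTP call
    return [query]

async def handler():
    return await run_in_threadpool(blocking_search, "open webui")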
@@ -6,6 +6,7 @@ from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.users import (
    UserModel,
    UserListResponse,
    UserRoleUpdateForm,
    Users,
    UserSettings,
@@ -33,13 +34,38 @@ router = APIRouter()
############################


@router.get("/", response_model=list[UserModel])
PAGE_ITEM_COUNT = 10


@router.get("/", response_model=UserListResponse)
async def get_users(
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    query: Optional[str] = None,
    order_by: Optional[str] = None,
    direction: Optional[str] = None,
    page: Optional[int] = 1,
    user=Depends(get_admin_user),
):
    return Users.get_users(skip, limit)
    limit = PAGE_ITEM_COUNT

    page = max(1, page)
    skip = (page - 1) * limit

    filter = {}
    if query:
        filter["query"] = query
    if order_by:
        filter["order_by"] = order_by
    if direction:
        filter["direction"] = direction

    return Users.get_users(filter=filter, skip=skip, limit=limit)


@router.get("/all", response_model=UserListResponse)
async def get_all_users(
    user=Depends(get_admin_user),
):
    return Users.get_users()


############################
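The listing endpoint now derives skip/limit from a fixed server-side page size, ignoring any client-supplied limit; the arithmetic in isolation:

# Page 3 with a fixed page size of 10 skips rows 1-20.
PAGE_ITEM_COUNT = 10
page = max(1, 3)
skip = (page - 1) * PAGE_ITEM_COUNT  # 20
limit = PAGE_ITEM_COUNT             # rows 21-30 are returned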
@@ -88,6 +114,8 @@ class ChatPermissions(BaseModel):
    file_upload: bool = True
    delete: bool = True
    edit: bool = True
    share: bool = True
    export: bool = True
    stt: bool = True
    tts: bool = True
    call: bool = True
@@ -288,6 +316,21 @@ async def update_user_by_id(
    form_data: UserUpdateForm,
    session_user=Depends(get_admin_user),
):
    # Prevent modification of the primary admin user by other admins
    try:
        first_user = Users.get_first_user()
        if first_user and user_id == first_user.id and session_user.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
            )
    except Exception as e:
        log.error(f"Error checking primary admin status: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not verify primary admin status.",
        )

    user = Users.get_user_by_id(user_id)

    if user:
@@ -335,6 +378,21 @@ async def update_user_by_id(

@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
    # Prevent deletion of the primary admin user
    try:
        first_user = Users.get_first_user()
        if first_user and user_id == first_user.id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
            )
    except Exception as e:
        log.error(f"Error checking primary admin status: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not verify primary admin status.",
        )

    if user.id != user_id:
        result = Auths.delete_auth_by_id(user_id)

@@ -346,6 +404,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
            detail=ERROR_MESSAGES.DELETE_USER_ERROR,
        )

    # Prevent self-deletion
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
@@ -192,6 +192,9 @@ async def connect(sid, environ, auth):
    # print(f"user {user.name}({user.id}) connected with session ID {sid}")
    await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())})
    await sio.emit("usage", {"models": get_models_in_use()})
    return True

    return False


@sio.on("user-join")
@@ -314,16 +317,18 @@ def get_event_emitter(request_info, update_db=True):
            )
        )

        for session_id in session_ids:
            await sio.emit(
                "chat-events",
                {
                    "chat_id": request_info.get("chat_id", None),
                    "message_id": request_info.get("message_id", None),
                    "data": event_data,
                },
                to=session_id,
            )
        emit_tasks = [
            sio.emit(
                "chat-events",
                {
                    "chat_id": request_info.get("chat_id", None),
                    "message_id": request_info.get("message_id", None),
                    "data": event_data,
                },
                to=session_id,
            )
            for session_id in session_ids
        ]

        await asyncio.gather(*emit_tasks)

        if update_db:
            if "type" in event_data and event_data["type"] == "status":
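Building the emit coroutines up front and awaiting them together replaces the sequential per-session loop; the fan-out in miniature:

# Concurrent fan-out with asyncio.gather.
import asyncio

async def emit(to):
    await asyncio.sleep(0)  # stand-in for sio.emit(...)

async def broadcast(session_ids):
    await asyncio.gather(*[emit(sid) for sid in session_ids])

asyncio.run(broadcast(["s1", "s2", "s3"]))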
@@ -3,7 +3,7 @@ import shutil
import json
import logging
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple
from typing import BinaryIO, Tuple, Dict

import boto3
from botocore.config import Config
@@ -44,7 +44,9 @@ class StorageProvider(ABC):
        pass

    @abstractmethod
    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    def upload_file(
        self, file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        pass

    @abstractmethod
@@ -58,7 +60,9 @@ class StorageProvider(ABC):

class LocalStorageProvider(StorageProvider):
    @staticmethod
    def upload_file(file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    def upload_file(
        file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        contents = file.read()
        if not contents:
            raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -131,12 +135,20 @@ class S3StorageProvider(StorageProvider):
        self.bucket_name = S3_BUCKET_NAME
        self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    def upload_file(
        self, file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        """Handles uploading of the file to S3 storage."""
        _, file_path = LocalStorageProvider.upload_file(file, filename)
        _, file_path = LocalStorageProvider.upload_file(file, filename, tags)
        tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
        try:
            s3_key = os.path.join(self.key_prefix, filename)
            self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
            self.s3_client.put_object_tagging(
                Bucket=self.bucket_name,
                Key=s3_key,
                Tagging=tagging,
            )
            return (
                open(file_path, "rb").read(),
                "s3://" + self.bucket_name + "/" + s3_key,
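put_object_tagging expects the TagSet structure built above; the mapping in isolation:

# From a flat dict to the boto3 TagSet shape.
tags = {"OpenWebUI-User-Id": "u-123", "OpenWebUI-File-Id": "f-456"}
tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
# => {"TagSet": [{"Key": "OpenWebUI-User-Id", "Value": "u-123"}, ...]}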
@@ -207,9 +219,11 @@ class GCSStorageProvider(StorageProvider):
        self.gcs_client = storage.Client()
        self.bucket = self.gcs_client.bucket(GCS_BUCKET_NAME)

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    def upload_file(
        self, file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        """Handles uploading of the file to GCS storage."""
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
        try:
            blob = self.bucket.blob(filename)
            blob.upload_from_filename(file_path)
@@ -277,9 +291,11 @@ class AzureStorageProvider(StorageProvider):
            self.container_name
        )

    def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    def upload_file(
        self, file: BinaryIO, filename: str, tags: Dict[str, str]
    ) -> Tuple[bytes, str]:
        """Handles uploading of the file to Azure Blob Storage."""
        contents, file_path = LocalStorageProvider.upload_file(file, filename)
        contents, file_path = LocalStorageProvider.upload_file(file, filename, tags)
        try:
            blob_client = self.container_client.get_blob_client(filename)
            blob_client.upload_blob(contents, overwrite=True)
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
class AuditLogEntry:
    # `Metadata` audit level properties
    id: str
    user: dict[str, Any]
    user: Optional[dict[str, Any]]
    audit_level: str
    verb: str
    request_uri: str

@@ -190,21 +190,40 @@ class AuditLoggingMiddleware:
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
    async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
        auth_header = request.headers.get("Authorization")
        assert auth_header
        user = get_current_user(request, None, get_http_authorization_cred(auth_header))

        return user
        try:
            user = get_current_user(
                request, None, get_http_authorization_cred(auth_header)
            )
            return user
        except Exception as e:
            logger.debug(f"Failed to get authenticated user: {str(e)}")

        return None

    def _should_skip_auditing(self, request: Request) -> bool:
        if (
            request.method not in {"POST", "PUT", "PATCH", "DELETE"}
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        ALWAYS_LOG_ENDPOINTS = {
            "/api/v1/auths/signin",
            "/api/v1/auths/signout",
            "/api/v1/auths/signup",
        }
        path = request.url.path.lower()
        for endpoint in ALWAYS_LOG_ENDPOINTS:
            if path.startswith(endpoint):
                return False  # Do NOT skip logging for auth endpoints

        # Skip logging if the request is not authenticated
        if not request.headers.get("authorization"):
            return True

        # match either /api/<resource>/... (for the endpoint /api/chat case) or /api/v1/<resource>/...
        pattern = re.compile(
            r"^/api(?:/v1)?/(" + "|".join(self.excluded_paths) + r")\b"

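The exclusion pattern anchors at the start of the path and accepts an optional `/v1` segment, so both API shapes are matched with one regex. A quick demonstration with illustrative excluded paths (the real `excluded_paths` list comes from middleware configuration):

import re

excluded_paths = ["chats", "folders"]  # illustrative values
pattern = re.compile(r"^/api(?:/v1)?/(" + "|".join(excluded_paths) + r")\b")

assert pattern.match("/api/chat") is None  # "chat" != "chats": audited
assert pattern.match("/api/v1/chats/abc")  # matched: auditing skipped
assert pattern.match("/api/folders/new")   # matched without the /v1 prefix
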
@@ -231,17 +250,32 @@ class AuditLoggingMiddleware:
        try:
            user = await self._get_authenticated_user(request)

            user = (
                user.model_dump(include={"id", "name", "email", "role"}) if user else {}
            )

            request_body = context.request_body.decode("utf-8", errors="replace")
            response_body = context.response_body.decode("utf-8", errors="replace")

            # Redact sensitive information
            if "password" in request_body:
                request_body = re.sub(
                    r'"password":\s*"(.*?)"',
                    '"password": "********"',
                    request_body,
                )

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                user=user,
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
                request_object=request_body,
                response_object=response_body,
            )

            self.audit_logger.write(entry)

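Decoding each body once and redacting before the entry is constructed keeps plaintext passwords out of the audit log. The substitution in isolation, on an illustrative JSON body:

import re

body = '{"email": "a@b.co", "password": "hunter2"}'
redacted = re.sub(r'"password":\s*"(.*?)"', '"password": "********"', body)
print(redacted)  # {"email": "a@b.co", "password": "********"}
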
@@ -50,7 +50,7 @@ class JupyterCodeExecuter:
        self.password = password
        self.timeout = timeout
        self.kernel_id = ""
        self.session = aiohttp.ClientSession(base_url=self.base_url)
        self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
        self.params = {}
        self.result = ResultModel()

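`trust_env=True` makes aiohttp honor proxy settings from the environment (`HTTP_PROXY`, `HTTPS_PROXY`, `NO_PROXY`) and netrc credentials, which is why it is added to every session in this commit. A minimal sketch of the behavior, with a placeholder proxy URL:

import asyncio
import os

import aiohttp

os.environ["HTTPS_PROXY"] = "http://proxy.example:3128"  # placeholder proxy


async def main() -> None:
    # With trust_env=True the session routes requests through HTTPS_PROXY.
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get("https://example.com") as resp:
            print(resp.status)


asyncio.run(main())
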
@@ -888,16 +888,20 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    # If context is not empty, insert it into the messages
    if len(sources) > 0:
        context_string = ""
        citated_file_idx = {}
        for _, source in enumerate(sources, 1):
        citation_idx = {}
        for source in sources:
            if "document" in source:
                for doc_context, doc_meta in zip(
                    source["document"], source["metadata"]
                ):
                    file_id = doc_meta.get("file_id")
                    if file_id not in citated_file_idx:
                        citated_file_idx[file_id] = len(citated_file_idx) + 1
                    context_string += f'<source id="{citated_file_idx[file_id]}">{doc_context}</source>\n'
                    citation_id = (
                        doc_meta.get("source", None)
                        or source.get("source", {}).get("id", None)
                        or "N/A"
                    )
                    if citation_id not in citation_idx:
                        citation_idx[citation_id] = len(citation_idx) + 1
                    context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'

        context_string = context_string.strip()
        prompt = get_last_user_message(form_data["messages"])

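Citations are now keyed by the chunk's source id (falling back to the parent source object's id, then "N/A") instead of `file_id`, so chunks of the same document share one citation number even when `file_id` is absent. A self-contained sketch with illustrative source data:

# Two chunks from the same document share a single citation id.
sources = [
    {
        "document": ["chunk A", "chunk B"],
        "metadata": [{"source": "doc-1"}, {"source": "doc-1"}],
        "source": {"id": "doc-1"},
    }
]

context_string = ""
citation_idx = {}
for source in sources:
    for doc_context, doc_meta in zip(source["document"], source["metadata"]):
        citation_id = (
            doc_meta.get("source")
            or source.get("source", {}).get("id")
            or "N/A"
        )
        if citation_id not in citation_idx:
            citation_idx[citation_id] = len(citation_idx) + 1
        context_string += (
            f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
        )

print(context_string)
# <source id="1">chunk A</source>
# <source id="1">chunk B</source>
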
@@ -1133,7 +1137,7 @@ async def process_chat_response(
    )

    # Send a webhook notification if the user is not active
    if get_active_status_by_user_id(user.id) is None:
    if not get_active_status_by_user_id(user.id):
        webhook_url = Users.get_user_webhook_url_by_id(user.id)
        if webhook_url:
            post_webhook(

@@ -1671,6 +1675,15 @@ async def process_chat_response(

                            if current_response_tool_call is None:
                                # Add the new tool call
                                delta_tool_call.setdefault(
                                    "function", {}
                                )
                                delta_tool_call[
                                    "function"
                                ].setdefault("name", "")
                                delta_tool_call[
                                    "function"
                                ].setdefault("arguments", "")
                                response_tool_calls.append(
                                    delta_tool_call
                                )

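Some providers stream tool-call deltas without the nested `function` object, so the `setdefault` chain normalizes the shape before the delta is appended. The normalization in isolation, on an illustrative delta:

# Illustrative streamed delta: the provider omitted the "function" fields.
delta_tool_call = {"index": 0, "id": "call_0"}

# setdefault fills in missing structure without clobbering keys that exist.
delta_tool_call.setdefault("function", {})
delta_tool_call["function"].setdefault("name", "")
delta_tool_call["function"].setdefault("arguments", "")

print(delta_tool_call)
# {'index': 0, 'id': 'call_0', 'function': {'name': '', 'arguments': ''}}
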
@@ -2215,7 +2228,7 @@ async def process_chat_response(
    )

    # Send a webhook notification if the user is not active
    if get_active_status_by_user_id(user.id) is None:
    if not get_active_status_by_user_id(user.id):
        webhook_url = Users.get_user_webhook_url_by_id(user.id)
        if webhook_url:
            post_webhook(

@@ -15,7 +15,7 @@ from starlette.responses import RedirectResponse

from open_webui.models.auths import Auths
from open_webui.models.users import Users
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm
from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm
from open_webui.config import (
    DEFAULT_USER_ROLE,
    ENABLE_OAUTH_SIGNUP,

@@ -23,6 +23,7 @@ from open_webui.config import (
    OAUTH_PROVIDERS,
    ENABLE_OAUTH_ROLE_MANAGEMENT,
    ENABLE_OAUTH_GROUP_MANAGEMENT,
    ENABLE_OAUTH_GROUP_CREATION,
    OAUTH_ROLES_CLAIM,
    OAUTH_GROUPS_CLAIM,
    OAUTH_EMAIL_CLAIM,

@@ -57,6 +58,7 @@ auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT
auth_manager_config.ENABLE_OAUTH_GROUP_CREATION = ENABLE_OAUTH_GROUP_CREATION
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM

@@ -152,6 +154,51 @@ class OAuthManager:
        user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
        all_available_groups: list[GroupModel] = Groups.get_groups()

        # Create groups if they don't exist and creation is enabled
        if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
            log.debug("Checking for missing groups to create...")
            all_group_names = {g.name for g in all_available_groups}
            groups_created = False
            # Determine creator ID: Prefer admin, fallback to current user if no admin exists
            admin_user = Users.get_admin_user()
            creator_id = admin_user.id if admin_user else user.id
            log.debug(f"Using creator ID {creator_id} for potential group creation.")

            for group_name in user_oauth_groups:
                if group_name not in all_group_names:
                    log.info(
                        f"Group '{group_name}' not found via OAuth claim. Creating group..."
                    )
                    try:
                        new_group_form = GroupForm(
                            name=group_name,
                            description=f"Group '{group_name}' created automatically via OAuth.",
                            permissions=default_permissions,  # Use default permissions from function args
                            user_ids=[],  # Start with no users, user will be added later by subsequent logic
                        )
                        # Use determined creator ID (admin or fallback to current user)
                        created_group = Groups.insert_new_group(
                            creator_id, new_group_form
                        )
                        if created_group:
                            log.info(
                                f"Successfully created group '{group_name}' with ID {created_group.id} using creator ID {creator_id}"
                            )
                            groups_created = True
                            # Add to local set to prevent duplicate creation attempts in this run
                            all_group_names.add(group_name)
                        else:
                            log.error(
                                f"Failed to create group '{group_name}' via OAuth."
                            )
                    except Exception as e:
                        log.error(f"Error creating group '{group_name}' via OAuth: {e}")

            # Refresh the list of all available groups if any were created
            if groups_created:
                all_available_groups = Groups.get_groups()
                log.debug("Refreshed list of all available groups after creation.")

        log.debug(f"Oauth Groups claim: {oauth_claim}")
        log.debug(f"User oauth groups: {user_oauth_groups}")
        log.debug(f"User's current groups: {[g.name for g in user_current_groups]}")

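Stripped of logging and persistence, the creation pass reduces to a set-membership check between the claim's group names and the names of groups that already exist. A sketch with illustrative values:

# Claimed group names minus existing ones determines what gets created.
user_oauth_groups = ["engineering", "design"]  # from the OAuth groups claim
all_group_names = {"design", "sales"}          # names of existing groups

missing = [name for name in user_oauth_groups if name not in all_group_names]
print(missing)  # ['engineering'] -> only this group would be created
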
@@ -257,7 +304,7 @@ class OAuthManager:
        try:
            access_token = token.get("access_token")
            headers = {"Authorization": f"Bearer {access_token}"}
            async with aiohttp.ClientSession() as session:
            async with aiohttp.ClientSession(trust_env=True) as session:
                async with session.get(
                    "https://api.github.com/user/emails", headers=headers
                ) as resp:

@@ -339,7 +386,7 @@ class OAuthManager:
            get_kwargs["headers"] = {
                "Authorization": f"Bearer {access_token}",
            }
            async with aiohttp.ClientSession() as session:
            async with aiohttp.ClientSession(trust_env=True) as session:
                async with session.get(
                    picture_url, **get_kwargs
                ) as resp:

@@ -157,7 +157,8 @@ def load_function_module_by_id(function_id, content=None):
            raise Exception("No Function class found in the module")
    except Exception as e:
        log.error(f"Error loading module: {function_id}: {e}")
        del sys.modules[module_name]  # Cleanup by removing the module in case of error
        # Cleanup by removing the module in case of error
        del sys.modules[module_name]

        Functions.update_function_by_id(function_id, {"is_active": False})
        raise e

@@ -182,3 +183,32 @@ def install_frontmatter_requirements(requirements: str):

    else:
        log.info("No requirements found in frontmatter.")


def install_tool_and_function_dependencies():
    """
    Install all dependencies for all admin tools and active functions.

    By first collecting all dependencies from the frontmatter of each tool and function,
    and then installing them using pip. Duplicates or similar version specifications are
    handled by pip as much as possible.
    """
    function_list = Functions.get_functions(active_only=True)
    tool_list = Tools.get_tools()

    all_dependencies = ""
    try:
        for function in function_list:
            frontmatter = extract_frontmatter(replace_imports(function.content))
            if dependencies := frontmatter.get("requirements"):
                all_dependencies += f"{dependencies}, "
        for tool in tool_list:
            # Only install requirements for admin tools
            if tool.user.role == "admin":
                frontmatter = extract_frontmatter(replace_imports(tool.content))
                if dependencies := frontmatter.get("requirements"):
                    all_dependencies += f"{dependencies}, "

        install_frontmatter_requirements(all_dependencies.strip(", "))
    except Exception as e:
        log.error(f"Error installing requirements: {e}")

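The per-item requirement strings are concatenated comma-separated and handed to pip in one invocation; duplicate or overlapping version pins are left for pip to reconcile. The string assembly in isolation, with illustrative frontmatter values:

all_dependencies = ""
for dependencies in ["requests, pydantic>=2", "requests"]:  # illustrative "requirements" values
    all_dependencies += f"{dependencies}, "

print(all_dependencies.strip(", "))  # requests, pydantic>=2, requests
# pip deduplicates the repeated "requests" when the combined list is installed.
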
@@ -36,7 +36,10 @@ from langchain_core.utils.function_calling import (
from open_webui.models.tools import Tools
from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
from open_webui.env import (
    AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
    AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
)

import copy

@@ -276,8 +279,8 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:

    docstring = func.__doc__

    description = parse_description(docstring)
    function_descriptions = parse_docstring(docstring)
    function_description = parse_description(docstring)
    function_param_descriptions = parse_docstring(docstring)

    field_defs = {}
    for name, param in parameters.items():

@@ -285,15 +288,15 @@ def convert_function_to_pydantic_model(func: Callable) -> type[BaseModel]:
        type_hint = type_hints.get(name, Any)
        default_value = param.default if param.default is not param.empty else ...

        description = function_descriptions.get(name, None)
        param_description = function_param_descriptions.get(name, None)

        if description:
            field_defs[name] = type_hint, Field(default_value, description=description)
        if param_description:
            field_defs[name] = type_hint, Field(default_value, description=param_description)
        else:
            field_defs[name] = type_hint, default_value

    model = create_model(func.__name__, **field_defs)
    model.__doc__ = description
    model.__doc__ = function_description

    return model

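The rename fixes a shadowing bug: the loop reassigned `description` to each parameter's description, so `model.__doc__` ended up holding the last parameter's text (or None) instead of the function-level description. A sketch of the intended result using pydantic directly, with illustrative function and field names:

from pydantic import Field, create_model

field_defs = {
    "city": (str, Field(..., description="Name of the city to look up")),
    "units": (str, Field("metric", description="Unit system: metric or imperial")),
}
model = create_model("get_weather", **field_defs)  # illustrative function name
model.__doc__ = "Fetch the current weather for a city."  # function-level description

print(model.model_fields["city"].description)  # Name of the city to look up
print(model.__doc__)                           # Fetch the current weather for a city.
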
@@ -371,51 +374,64 @@ def convert_openapi_to_tool_payload(openapi_spec):

    for path, methods in openapi_spec.get("paths", {}).items():
        for method, operation in methods.items():
            tool = {
                "type": "function",
                "name": operation.get("operationId"),
                "description": operation.get(
                    "description", operation.get("summary", "No description available.")
                ),
                "parameters": {"type": "object", "properties": {}, "required": []},
            }

            # Extract path and query parameters
            for param in operation.get("parameters", []):
                param_name = param["name"]
                param_schema = param.get("schema", {})
                tool["parameters"]["properties"][param_name] = {
                    "type": param_schema.get("type"),
                    "description": param_schema.get("description", ""),
            if operation.get("operationId"):
                tool = {
                    "type": "function",
                    "name": operation.get("operationId"),
                    "description": operation.get(
                        "description",
                        operation.get("summary", "No description available."),
                    ),
                    "parameters": {"type": "object", "properties": {}, "required": []},
                }
                if param.get("required"):
                    tool["parameters"]["required"].append(param_name)

            # Extract and resolve requestBody if available
            request_body = operation.get("requestBody")
            if request_body:
                content = request_body.get("content", {})
                json_schema = content.get("application/json", {}).get("schema")
                if json_schema:
                    resolved_schema = resolve_schema(
                        json_schema, openapi_spec.get("components", {})
                    )

                    if resolved_schema.get("properties"):
                        tool["parameters"]["properties"].update(
                            resolved_schema["properties"]
                # Extract path and query parameters
                for param in operation.get("parameters", []):
                    param_name = param["name"]
                    param_schema = param.get("schema", {})
                    description = param_schema.get("description", "")
                    if not description:
                        description = param.get("description") or ""
                    if param_schema.get("enum") and isinstance(
                        param_schema.get("enum"), list
                    ):
                        description += (
                            f". Possible values: {', '.join(param_schema.get('enum'))}"
                        )
                    if "required" in resolved_schema:
                        tool["parameters"]["required"] = list(
                            set(
                                tool["parameters"]["required"]
                                + resolved_schema["required"]
                            )
                        )
                    elif resolved_schema.get("type") == "array":
                        tool["parameters"] = resolved_schema  # special case for array
                    tool["parameters"]["properties"][param_name] = {
                        "type": param_schema.get("type"),
                        "description": description,
                    }
                    if param.get("required"):
                        tool["parameters"]["required"].append(param_name)

            tool_payload.append(tool)
                # Extract and resolve requestBody if available
                request_body = operation.get("requestBody")
                if request_body:
                    content = request_body.get("content", {})
                    json_schema = content.get("application/json", {}).get("schema")
                    if json_schema:
                        resolved_schema = resolve_schema(
                            json_schema, openapi_spec.get("components", {})
                        )

                        if resolved_schema.get("properties"):
                            tool["parameters"]["properties"].update(
                                resolved_schema["properties"]
                            )
                            if "required" in resolved_schema:
                                tool["parameters"]["required"] = list(
                                    set(
                                        tool["parameters"]["required"]
                                        + resolved_schema["required"]
                                    )
                                )
                        elif resolved_schema.get("type") == "array":
                            tool["parameters"] = (
                                resolved_schema  # special case for array
                            )

                tool_payload.append(tool)

    return tool_payload

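Beyond skipping operations without an `operationId`, the rewrite also surfaces parameter-level descriptions and folds `enum` values into the description text so the model can see the allowed choices. The enrichment in isolation, on an illustrative parameter:

param = {
    "name": "units",
    "required": True,
    "schema": {"type": "string", "enum": ["metric", "imperial"]},
}

param_schema = param.get("schema", {})
description = param_schema.get("description", "") or param.get("description") or ""
if isinstance(param_schema.get("enum"), list):
    description += f". Possible values: {', '.join(param_schema['enum'])}"

print(description)  # . Possible values: metric, imperial
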
@@ -431,8 +447,10 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
    error = None
    try:
        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as response:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL
            ) as response:
                if response.status != 200:
                    error_body = await response.json()
                    raise Exception(error_body)

@@ -573,19 +591,26 @@ async def execute_tool_server(
    if token:
        headers["Authorization"] = f"Bearer {token}"

    async with aiohttp.ClientSession() as session:
    async with aiohttp.ClientSession(trust_env=True) as session:
        request_method = getattr(session, http_method.lower())

        if http_method in ["post", "put", "patch"]:
            async with request_method(
                final_url, json=body_params, headers=headers
                final_url,
                json=body_params,
                headers=headers,
                ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
            ) as response:
                if response.status >= 400:
                    text = await response.text()
                    raise Exception(f"HTTP error {response.status}: {text}")
                return await response.json()
        else:
            async with request_method(final_url, headers=headers) as response:
            async with request_method(
                final_url,
                headers=headers,
                ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
            ) as response:
                if response.status >= 400:
                    text = await response.text()
                    raise Exception(f"HTTP error {response.status}: {text}")

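aiohttp's per-request `ssl=` argument accepts `True`/`None` (verify certificates), `False` (skip verification), or an `ssl.SSLContext`; `AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL` is presumably resolved to one of these values from the environment. A minimal sketch with a stand-in value and a placeholder URL, assuming a tool server behind a self-signed certificate:

import asyncio

import aiohttp

# Assumed stand-in for AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL:
# False disables certificate verification for a self-signed tool server.
TOOL_SERVER_SSL = False


async def fetch_openapi(url: str) -> dict:
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get(url, ssl=TOOL_SERVER_SSL) as response:
            return await response.json()


spec = asyncio.run(fetch_openapi("https://tools.internal/openapi.json"))  # placeholder URL
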
@@ -31,7 +31,7 @@ APScheduler==3.10.4

RestrictedPython==8.0

loguru==0.7.2
loguru==0.7.3
asgiref==3.8.1

# AI libraries

@@ -40,8 +40,8 @@ anthropic
google-generativeai==0.8.4
tiktoken

langchain==0.3.19
langchain-community==0.3.18
langchain==0.3.24
langchain-community==0.3.23

fake-useragent==2.1.0
chromadb==0.6.3

@@ -49,11 +49,11 @@ pymilvus==2.5.0
qdrant-client~=1.12.0
opensearch-py==2.8.0
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
elasticsearch==8.17.1

elasticsearch==9.0.1
pinecone==6.0.2

transformers
sentence-transformers==3.3.1
sentence-transformers==4.1.0
accelerate
colbert-ai==0.2.21
einops==0.8.1

@@ -81,7 +81,7 @@ azure-ai-documentintelligence==1.0.0

pillow==11.1.0
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.3.24
rapidocr-onnxruntime==1.4.4
rank-bm25==0.2.2

onnxruntime==1.20.1

@@ -107,7 +107,7 @@ google-auth-oauthlib

## Tests
docker~=7.1.0
pytest~=8.3.2
pytest~=8.3.5
pytest-docker~=3.1.1

googleapis-common-protos==1.63.2