Merge branch 'open-webui:main' into next

This commit is contained in: karlorz, 2025-06-12 10:05:22 +08:00 (committed by GitHub)
156 changed files with 5307 additions and 1846 deletions

View File

@@ -347,6 +347,24 @@ MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)
MICROSOFT_CLIENT_LOGIN_BASE_URL = PersistentConfig(
"MICROSOFT_CLIENT_LOGIN_BASE_URL",
"oauth.microsoft.login_base_url",
os.environ.get(
"MICROSOFT_CLIENT_LOGIN_BASE_URL", "https://login.microsoftonline.com"
),
)
MICROSOFT_CLIENT_PICTURE_URL = PersistentConfig(
"MICROSOFT_CLIENT_PICTURE_URL",
"oauth.microsoft.picture_url",
os.environ.get(
"MICROSOFT_CLIENT_PICTURE_URL",
"https://graph.microsoft.com/v1.0/me/photo/$value",
),
)
MICROSOFT_OAUTH_SCOPE = PersistentConfig(
"MICROSOFT_OAUTH_SCOPE",
"oauth.microsoft.scope",
@@ -542,7 +560,7 @@ def load_oauth_providers():
name="microsoft",
client_id=MICROSOFT_CLIENT_ID.value,
client_secret=MICROSOFT_CLIENT_SECRET.value,
server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
},
@@ -551,7 +569,7 @@ def load_oauth_providers():
OAUTH_PROVIDERS["microsoft"] = {
"redirect_uri": MICROSOFT_REDIRECT_URI.value,
"picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value",
"picture_url": MICROSOFT_CLIENT_PICTURE_URL.value,
"register": microsoft_oauth_register,
}
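The two new settings above let deployments on national or sovereign Microsoft clouds point away from the global endpoints. A plausible .env sketch, assuming Azure China (the endpoint values below are assumptions, not part of this commit):
MICROSOFT_CLIENT_LOGIN_BASE_URL=https://login.partner.microsoftonline.cn
MICROSOFT_CLIENT_PICTURE_URL=https://microsoftgraph.chinacloudapi.cn/v1.0/me/photo/$value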
@@ -901,9 +919,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig(
####################################
WEBUI_URL = PersistentConfig(
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)
WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", ""))
ENABLE_SIGNUP = PersistentConfig(
@@ -1247,12 +1263,6 @@ if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str):
THREAD_POOL_SIZE = None
def validate_cors_origins(origins):
for origin in origins:
if origin != "*":
validate_cors_origin(origin)
def validate_cors_origin(origin):
parsed_url = urlparse(origin)
@@ -1272,16 +1282,17 @@ def validate_cors_origin(origin):
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
CORS_ALLOW_ORIGIN = os.environ.get(
"CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080"
).split(";")
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
if "*" in CORS_ALLOW_ORIGIN:
if CORS_ALLOW_ORIGIN == ["*"]:
log.warning(
"\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
)
validate_cors_origins(CORS_ALLOW_ORIGIN)
else:
# You must choose either a single wildcard or a list of origins.
# Doing both will result in CORS errors in the browser.
for origin in CORS_ALLOW_ORIGIN:
validate_cors_origin(origin)
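So the default is now a bare wildcard, and an explicit origin list may no longer include "*" (per-origin validation enforces what the comment above warns about). A .env sketch for a locked-down deployment (hostnames illustrative):
CORS_ALLOW_ORIGIN=http://localhost:5173;https://webui.example.com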
class BannerModel(BaseModel):
@@ -1413,6 +1424,35 @@ Strictly return in JSON format:
{{MESSAGES:END:6}}
</chat_history>"""
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
"task.follow_up.prompt_template",
os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
### Guidelines:
- Write all follow-up questions from the user's point of view, directed to the assistant.
- Make questions concise, clear, and directly related to the discussed topic(s).
- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
- Use the conversation's primary language; default to English if multilingual.
- Response must be a JSON array of strings, no extra text or formatting.
### Output:
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
"ENABLE_FOLLOW_UP_GENERATION",
"task.follow_up.enable",
os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
)
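Both knobs are plain environment variables; a sketch (that an empty template falls back to DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE is an assumption based on the paired constant above):
ENABLE_FOLLOW_UP_GENERATION=false
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE=   # empty -> default template presumably applies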
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
@@ -1786,6 +1826,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)
PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
)
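Enabling encryption for pgvector chunks therefore requires both variables, otherwise startup fails with the ValueError above. A .env sketch (the key is a placeholder; keep it secret and stable across restarts, since existing rows cannot be decrypted with a different key):
PGVECTOR_PGCRYPTO=true
PGVECTOR_PGCRYPTO_KEY=replace-with-a-long-random-secret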
# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
@@ -1946,6 +1993,40 @@ DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
)
DOCLING_PICTURE_DESCRIPTION_MODE = PersistentConfig(
"DOCLING_PICTURE_DESCRIPTION_MODE",
"rag.docling_picture_description_mode",
os.getenv("DOCLING_PICTURE_DESCRIPTION_MODE", ""),
)
docling_picture_description_local = os.getenv("DOCLING_PICTURE_DESCRIPTION_LOCAL", "")
try:
docling_picture_description_local = json.loads(docling_picture_description_local)
except json.JSONDecodeError:
docling_picture_description_local = {}
DOCLING_PICTURE_DESCRIPTION_LOCAL = PersistentConfig(
"DOCLING_PICTURE_DESCRIPTION_LOCAL",
"rag.docling_picture_description_local",
docling_picture_description_local,
)
docling_picture_description_api = os.getenv("DOCLING_PICTURE_DESCRIPTION_API", "")
try:
docling_picture_description_api = json.loads(docling_picture_description_api)
except json.JSONDecodeError:
docling_picture_description_api = {}
DOCLING_PICTURE_DESCRIPTION_API = PersistentConfig(
"DOCLING_PICTURE_DESCRIPTION_API",
"rag.docling_picture_description_api",
docling_picture_description_api,
)
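Both DOCLING_PICTURE_DESCRIPTION_LOCAL and DOCLING_PICTURE_DESCRIPTION_API are parsed as JSON and silently fall back to {} on invalid input. The exact schema is whatever docling-serve accepts; a plausible sketch (repo id, URL, and model name are assumptions):
DOCLING_PICTURE_DESCRIPTION_MODE=local
DOCLING_PICTURE_DESCRIPTION_LOCAL={"repo_id": "HuggingFaceTB/SmolVLM-256M-Instruct", "prompt": "Describe this image."}
DOCLING_PICTURE_DESCRIPTION_API={"url": "http://localhost:11434/v1/chat/completions", "params": {"model": "llava"}}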
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
"DOCUMENT_INTELLIGENCE_ENDPOINT",
"rag.document_intelligence_endpoint",
@@ -2445,6 +2526,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
os.getenv("PERPLEXITY_API_KEY", ""),
)
PERPLEXITY_MODEL = PersistentConfig(
"PERPLEXITY_MODEL",
"rag.web.search.perplexity_model",
os.getenv("PERPLEXITY_MODEL", "sonar"),
)
PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
"PERPLEXITY_SEARCH_CONTEXT_USAGE",
"rag.web.search.perplexity_search_context_usage",
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
)
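The accepted values mirror the Literal types added to the Perplexity search module later in this commit (sonar, sonar-pro, sonar-reasoning, sonar-reasoning-pro, sonar-deep-research; low/medium/high). For example:
PERPLEXITY_MODEL=sonar-pro
PERPLEXITY_SEARCH_CONTEXT_USAGE=high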
SOUGOU_API_SID = PersistentConfig(
"SOUGOU_API_SID",
"rag.web.search.sougou_api_sid",

View File

@@ -111,6 +111,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"

View File

@@ -5,6 +5,7 @@ import os
import pkgutil
import sys
import shutil
from uuid import uuid4
from pathlib import Path
import markdown
@@ -130,6 +131,7 @@ else:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))
# Function to parse each section

View File

@@ -25,6 +25,7 @@ from open_webui.socket.main import (
)
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
@@ -227,12 +228,7 @@ async def generate_function_chat_completion(
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
}

View File

@@ -8,6 +8,8 @@ import shutil
import sys
import time
import random
from uuid import uuid4
from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse
@@ -19,6 +21,7 @@ from aiocache import cached
import aiohttp
import anyio.to_thread
import requests
from redis import Redis
from fastapi import (
@@ -37,7 +40,7 @@ from fastapi import (
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from starlette_compress import CompressMiddleware
@@ -231,6 +234,9 @@ from open_webui.config import (
DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION,
DOCLING_PICTURE_DESCRIPTION_MODE,
DOCLING_PICTURE_DESCRIPTION_LOCAL,
DOCLING_PICTURE_DESCRIPTION_API,
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
@@ -268,6 +274,8 @@ from open_webui.config import (
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
PERPLEXITY_API_KEY,
PERPLEXITY_MODEL,
PERPLEXITY_SEARCH_CONTEXT_USAGE,
SOUGOU_API_SID,
SOUGOU_API_SK,
KAGI_SEARCH_API_KEY,
@@ -359,10 +367,12 @@ from open_webui.config import (
TASK_MODEL_EXTERNAL,
ENABLE_TAGS_GENERATION,
ENABLE_TITLE_GENERATION,
ENABLE_FOLLOW_UP_GENERATION,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -384,6 +394,7 @@ from open_webui.env import (
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
INSTANCE_ID,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
@@ -411,6 +422,7 @@ from open_webui.utils.chat import (
chat_completed as chat_completed_handler,
chat_action as chat_action_handler,
)
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
@@ -424,8 +436,10 @@ from open_webui.utils.auth import (
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
from open_webui.tasks import (
redis_task_command_listener,
list_task_ids_by_chat_id,
stop_task,
list_tasks,
@@ -477,7 +491,9 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.instance_id = INSTANCE_ID
start_logger()
if RESET_CONFIG_ON_START:
reset_config()
@@ -489,6 +505,19 @@ async def lifespan(app: FastAPI):
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
app.state.redis = get_redis_connection(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
async_mode=True,
)
if app.state.redis is not None:
app.state.redis_task_command_listener = asyncio.create_task(
redis_task_command_listener(app)
)
if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = THREAD_POOL_SIZE
@@ -497,6 +526,9 @@ async def lifespan(app: FastAPI):
yield
if hasattr(app.state, "redis_task_command_listener"):
app.state.redis_task_command_listener.cancel()
app = FastAPI(
title="Open WebUI",
@@ -508,10 +540,12 @@ app = FastAPI(
oauth_manager = OAuthManager(app)
app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
)
app.state.redis = None
app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
@@ -696,6 +730,9 @@ app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -771,6 +808,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
@@ -959,6 +998,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -966,6 +1006,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -1197,6 +1240,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
return {"data": models}
##################################
# Embeddings
##################################
@app.post("/api/embeddings")
async def embeddings(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
"""
OpenAI-compatible embeddings endpoint.
This handler:
- Performs user/model checks and dispatches to the correct backend.
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
Args:
request (Request): Request context.
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
user (UserModel): Authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
"""
# Make sure models are loaded in app state
if not request.app.state.MODELS:
await get_all_models(request, user=user)
# Use generic dispatcher in utils.embeddings
return await generate_embeddings(request, form_data, user)
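A minimal client-side sketch of the new endpoint, assuming a local instance and a valid Open WebUI API key (key and model name are placeholders):
import requests
resp = requests.post(
    "http://localhost:3000/api/embeddings",
    headers={"Authorization": "Bearer sk-..."},  # placeholder API key
    json={"model": "text-embedding-3-small", "input": ["hello", "world"]},
)
print(resp.json()["data"][0]["embedding"][:5])  # OpenAI-compatible response shape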
@app.post("/api/chat/completions")
async def chat_completion(
request: Request,
@@ -1338,26 +1412,30 @@ async def chat_action(
@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
async def stop_task_endpoint(
request: Request, task_id: str, user=Depends(get_verified_user)
):
try:
result = await stop_task(task_id)
result = await stop_task(request, task_id)
return result
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
return {"tasks": list_tasks()}
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
return {"tasks": await list_tasks(request)}
@app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
async def list_tasks_by_chat_id_endpoint(
request: Request, chat_id: str, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id(chat_id)
if chat is None or chat.user_id != user.id:
return {"task_ids": []}
task_ids = list_task_ids_by_chat_id(chat_id)
task_ids = await list_task_ids_by_chat_id(request, chat_id)
print(f"Task IDs for chat {chat_id}: {task_ids}")
return {"task_ids": task_ids}
@@ -1628,7 +1706,20 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
@app.get("/cache/{path:path}")
async def serve_cache_file(
path: str,
user=Depends(get_verified_user),
):
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
# prevent path traversal
if not file_path.startswith(os.path.abspath(CACHE_DIR)):
raise HTTPException(status_code=404, detail="File not found")
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(file_path)
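One caveat on the guard above: startswith() against the bare base path also accepts sibling directories whose names merely extend it (e.g. /app/cache2/x passes a startswith("/app/cache") test). A stricter containment sketch, not part of this commit:
base = os.path.abspath(CACHE_DIR)
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
# commonpath normalizes both sides, so /app/cache2/x is rejected
if os.path.commonpath([base, file_path]) != base:
    raise HTTPException(status_code=404, detail="File not found")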
def swagger_ui_html(*args, **kwargs):

View File

@@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel):
role: str
name: str
email: str
profile_image_url: str
@@ -369,7 +370,7 @@ class UsersTable:
except Exception:
return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
try:
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})

View File

@@ -2,6 +2,7 @@ import requests
import logging
import ftfy
import sys
import json
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
@@ -76,7 +77,6 @@ known_source_ext = [
"swift",
"vue",
"svelte",
"msg",
"ex",
"exs",
"erl",
@@ -147,17 +147,32 @@ class DoclingLoader:
)
}
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
params = {"image_export_mode": "placeholder", "table_mode": "accurate"}
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
if self.params.get("do_picture_description"):
params["do_picture_description"] = self.params.get(
"do_picture_description"
)
picture_description_mode = self.params.get(
"picture_description_mode", ""
).lower()
if picture_description_mode == "local" and self.params.get(
"picture_description_local", {}
):
params["picture_description_local"] = self.params.get(
"picture_description_local", {}
)
elif picture_description_mode == "api" and self.params.get(
"picture_description_api", {}
):
params["picture_description_api"] = self.params.get(
"picture_description_api", {}
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
params["ocr_engine"] = self.params.get("ocr_engine")
params["ocr_lang"] = [
@@ -285,17 +300,20 @@ class Loader:
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
params = {}
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
file_path=file_path,
mime_type=file_content_type,
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},
params=params,
)
elif (
self.engine == "document_intelligence"

View File

@@ -20,6 +20,14 @@ class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
Performance Optimizations:
- Differentiated timeouts for different operations
- Intelligent retry logic with exponential backoff
- Memory-efficient file streaming for large files
- Connection pooling and keepalive optimization
- Semaphore-based concurrency control for batch processing
- Enhanced error handling with retryable error classification
"""
BASE_API_URL = "https://api.mistral.ai/v1"
@@ -53,17 +61,40 @@ class MistralLoader:
self.max_retries = max_retries
self.debug = enable_debug_logging
# Pre-compute file info for performance
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
}
def _debug_log(self, message: str, *args) -> None:
"""Conditional debug logging for performance."""
"""
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
Only processes debug messages when debug mode is enabled, avoiding
string formatting overhead in production environments.
"""
if self.debug:
log.debug(message, *args)
@@ -115,53 +146,118 @@ class MistralLoader:
log.error(f"Unexpected error processing response: {e}")
raise
def _is_retryable_error(self, error: Exception) -> bool:
"""
ENHANCEMENT: Intelligent error classification for retry logic.
Determines if an error is retryable based on its type and status code.
This prevents wasting time retrying errors that will never succeed
(like authentication errors) while ensuring transient errors are retried.
Retryable errors:
- Network connection errors (temporary network issues)
- Timeouts (server might be temporarily overloaded)
- Server errors (5xx status codes - server-side issues)
- Rate limiting (429 status - temporary throttling)
Non-retryable errors:
- Authentication errors (401, 403 - won't fix with retry)
- Bad request errors (400 - malformed request)
- Not found errors (404 - resource doesn't exist)
"""
if isinstance(error, requests.exceptions.ConnectionError):
return True # Network issues are usually temporary
if isinstance(error, requests.exceptions.Timeout):
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
return False # All other errors are non-retryable
def _retry_request_sync(self, request_func, *args, **kwargs):
"""Synchronous retry logic with exponential backoff."""
"""
ENHANCEMENT: Synchronous retry logic with intelligent error classification.
Uses exponential backoff with jitter to avoid thundering herd problems.
The wait time increases exponentially but is capped at 30 seconds to
prevent excessive delays. Only retries errors that are likely to succeed
on subsequent attempts.
"""
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
except (requests.exceptions.RequestException, Exception) as e:
if attempt == self.max_retries - 1:
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
wait_time = (2**attempt) + 0.5
# PERFORMANCE OPTIMIZATION: Exponential backoff with cap
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
"""Async retry logic with exponential backoff."""
"""
ENHANCEMENT: Async retry logic with intelligent error classification.
Async version of retry logic that doesn't block the event loop during
wait periods. Uses the same exponential backoff strategy as sync version.
"""
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
if attempt == self.max_retries - 1:
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
wait_time = (2**attempt) + 0.5
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
await asyncio.sleep(wait_time) # Non-blocking wait
def _upload_file(self) -> str:
"""Uploads the file to Mistral for OCR processing (sync version)."""
"""
PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
Uploads the file to Mistral for OCR processing (sync version).
Uses context manager for file handling to ensure proper resource cleanup.
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
files = {"file": (file_name, f, "application/pdf")}
files = {"file": (self.file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
timeout=self.timeout,
timeout=self.upload_timeout, # Use specialized upload timeout
stream=False, # Keep as False for this endpoint
)
return self._handle_response(response)
@@ -209,7 +305,7 @@ class MistralLoader:
url,
data=writer,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
) as response:
return await self._handle_response_async(response)
@@ -231,7 +327,7 @@ class MistralLoader:
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.timeout
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
return self._handle_response(response)
@@ -261,7 +357,7 @@ class MistralLoader:
url,
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=self.timeout),
timeout=aiohttp.ClientTimeout(total=self.url_timeout),
) as response:
return await self._handle_response_async(response)
@@ -294,7 +390,7 @@ class MistralLoader:
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.timeout
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
)
return self._handle_response(response)
@@ -336,7 +432,7 @@ class MistralLoader:
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.timeout),
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
) as response:
ocr_response = await self._handle_response_async(response)
@@ -353,7 +449,9 @@ class MistralLoader:
url = f"{self.BASE_API_URL}/files/{file_id}"
try:
response = requests.delete(url, headers=self.headers, timeout=30)
response = requests.delete(
url, headers=self.headers, timeout=self.cleanup_timeout
)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
@@ -372,7 +470,7 @@ class MistralLoader:
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=30
total=self.cleanup_timeout
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
@@ -388,29 +486,39 @@ class MistralLoader:
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
limit=10, # Total connection limit
limit_per_host=5, # Per-host connection limit
ttl_dns_cache=300, # DNS cache TTL
limit=20, # Increased total connection limit for better throughput
limit_per_host=10, # Increased per-host limit for API endpoints
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
use_dns_cache=True,
keepalive_timeout=30,
keepalive_timeout=60, # Increased keepalive for connection reuse
enable_cleanup_closed=True,
force_close=False, # Allow connection reuse
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
)
timeout = aiohttp.ClientTimeout(
total=self.timeout,
connect=30, # Connection timeout
sock_read=60, # Socket read timeout
)
async with aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self.timeout),
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata."""
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
page_content="No text content found", metadata={"error": "no_pages"}
page_content="No text content found",
metadata={"error": "no_pages", "file_name": self.file_name},
)
]
@@ -418,41 +526,44 @@ class MistralLoader:
total_pages = len(pages_data)
skipped_pages = 0
# Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is not None and page_index is not None:
# Clean up content efficiently
cleaned_content = (
page_content.strip()
if isinstance(page_content, str)
else str(page_content)
)
if cleaned_content: # Only add non-empty pages
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
},
)
)
else:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
else:
if page_content is None or page_index is None:
skipped_pages += 1
self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
)
continue
# Clean up content efficiently with early exit for empty content
if isinstance(page_content, str):
cleaned_content = page_content.strip()
else:
cleaned_content = str(page_content).strip()
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
continue
# Create document with optimized metadata
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
},
)
)
if skipped_pages > 0:
log.info(
@@ -467,7 +578,11 @@ class MistralLoader:
return [
Document(
page_content="No valid text content found in document",
metadata={"error": "no_valid_pages", "total_pages": total_pages},
metadata={
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
},
)
]
@@ -585,12 +700,14 @@ class MistralLoader:
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
Process multiple files concurrently for maximum performance.
Process multiple files concurrently with controlled concurrency.
Args:
loaders: List of MistralLoader instances
max_concurrent: Maximum number of concurrent requests
Returns:
List of document lists, one for each loader
@@ -598,11 +715,20 @@ class MistralLoader:
if not loaders:
return []
log.info(f"Starting concurrent processing of {len(loaders)} files")
log.info(
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
start_time = time.time()
# Process all files concurrently
tasks = [loader.load_async() for loader in loaders]
# Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async with semaphore:
return await loader.load_async()
# Process all files with controlled concurrency
tasks = [process_with_semaphore(loader) for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
@@ -624,10 +750,18 @@ class MistralLoader:
else:
processed_results.append(result)
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
failure_count = len(results) - success_count
log.info(
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
f"Batch processing completed in {total_time:.2f}s: "
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
)
return processed_results
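A usage sketch for the semaphore-gated batch API, called from inside an async function (the constructor arguments are assumptions; the full __init__ signature is not shown in this diff):
loaders = [MistralLoader(api_key=key, file_path=p) for p in pdf_paths]  # hypothetical inputs
results = await MistralLoader.load_multiple_async(loaders, max_concurrent=3)
for docs in results:
    print(len(docs), "documents")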

View File

@@ -1,4 +1,5 @@
import logging
from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
@@ -93,7 +94,6 @@ class YoutubeLoader:
"http": self.proxy_url,
"https": self.proxy_url,
}
# Don't log complete URL because it might contain secrets
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else:
youtube_proxies = None
@@ -110,11 +110,37 @@ class YoutubeLoader:
for lang in self.language:
try:
transcript = transcript_list.find_transcript([lang])
if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'")
try:
transcript = transcript_list.find_manually_created_transcript(
[lang]
)
log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound:
log.debug(
f"No manual transcript found for language '{lang}', using generated"
)
pass
log.debug(f"Found transcript for language '{lang}'")
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
try:
transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
except ParseError:
log.debug(f"Empty or invalid transcript for language '{lang}'")
continue
if not transcript_pieces:
log.debug(f"Empty transcript for language '{lang}'")
continue
transcript_text = " ".join(
map(
lambda transcript_piece: transcript_piece.text.strip(" "),
lambda transcript_piece: (
transcript_piece.text.strip(" ")
if hasattr(transcript_piece, "text")
else ""
),
transcript_pieces,
)
)
@@ -131,6 +157,4 @@ class YoutubeLoader:
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
)
raise NoTranscriptFound(
f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
)
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

View File

@@ -1,12 +1,16 @@
from typing import Optional, List, Dict, Any
import logging
import json
from sqlalchemy import (
func,
literal,
cast,
column,
create_engine,
Column,
Integer,
MetaData,
LargeBinary,
select,
text,
Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
SearchResult,
GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.config import (
PGVECTOR_DB_URL,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
)
from open_webui.env import SRC_LOG_LEVELS
@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def pgcrypto_encrypt(val, key):
return func.pgp_sym_encrypt(val, literal(key))
def pgcrypto_decrypt(col, key, outtype="text"):
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
__tablename__ = "document_chunk"
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
collection_name = Column(Text, nullable=False)
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
if PGVECTOR_PGCRYPTO:
text = Column(LargeBinary, nullable=True)
vmetadata = Column(LargeBinary, nullable=True)
else:
text = Column(Text, nullable=True)
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
class PgvectorClient(VectorDBBase):
@@ -70,6 +92,15 @@ class PgvectorClient(VectorDBBase):
# Ensure the pgvector extension is available
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
if PGVECTOR_PGCRYPTO:
# Ensure the pgcrypto extension is available for encryption
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))
if not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
)
# Check vector length consistency
self.check_vector_length()
@@ -147,44 +178,39 @@ class PgvectorClient(VectorDBBase):
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
# Use raw SQL for BYTEA/pgcrypto
self.session.execute(
text(
"""
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata::text, :key)
)
ON CONFLICT (id) DO NOTHING
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata": json.dumps(item["metadata"]),
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
else:
self.session.commit()
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
else:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
@@ -192,11 +218,78 @@ class PgvectorClient(VectorDBBase):
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
self.session.execute(
text(
"""
INSERT INTO document_chunk
(id, vector, collection_name, text, vmetadata)
VALUES (
:id, :vector, :collection_name,
pgp_sym_encrypt(:text, :key),
pgp_sym_encrypt(:metadata::text, :key)
)
ON CONFLICT (id) DO UPDATE SET
vector = EXCLUDED.vector,
collection_name = EXCLUDED.collection_name,
text = EXCLUDED.text,
vmetadata = EXCLUDED.vmetadata
"""
),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata": json.dumps(item["metadata"]),
"key": PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
else:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = item["metadata"]
existing.collection_name = (
collection_name # Update collection_name if necessary
)
else:
new_chunk = DocumentChunk(
id=item["id"],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=item["metadata"],
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
@@ -230,16 +323,32 @@ class PgvectorClient(VectorDBBase):
.alias("query_vectors")
)
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
# Build the lateral subquery for each query vector
subq = (
select(
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
).label("distance"),
)
select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
@@ -299,17 +408,43 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if PGVECTOR_PGCRYPTO:
# Build where clause for vmetadata filter
where_clauses = [DocumentChunk.collection_name == collection_name]
for key, value in filter.items():
# decrypt then check key: JSON filter after decryption
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
== str(value)
)
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(*where_clauses)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
if limit is not None:
query = query.limit(limit)
if limit is not None:
query = query.limit(limit)
results = query.all()
results = query.all()
if not results:
return None
@@ -331,20 +466,38 @@ class PgvectorClient(VectorDBBase):
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
if PGVECTOR_PGCRYPTO:
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
).where(DocumentChunk.collection_name == collection_name)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
ids = [[row.id for row in results]]
documents = [[row.text for row in results]]
metadatas = [[row.vmetadata for row in results]]
else:
results = query.all()
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if limit is not None:
query = query.limit(limit)
if not results:
return None
results = query.all()
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
if not results:
return None
ids = [[result.id for result in results]]
documents = [[result.text for result in results]]
metadatas = [[result.vmetadata for result in results]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
@@ -358,17 +511,33 @@ class PgvectorClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
if PGVECTOR_PGCRYPTO:
wheres = [DocumentChunk.collection_name == collection_name]
if ids:
wheres.append(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
wheres.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
== str(value)
)
stmt = DocumentChunk.__table__.delete().where(*wheres)
result = self.session.execute(stmt)
deleted = result.rowcount
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:

View File

@@ -3,10 +3,19 @@ import logging
import time # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec
# Add gRPC support for better performance (Pinecone best practice)
try:
from pinecone.grpc import PineconeGRPC
GRPC_AVAILABLE = True
except ImportError:
GRPC_AVAILABLE = False
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import (
VectorDBBase,
@@ -47,7 +56,24 @@ class PineconeClient(VectorDBBase):
self.cloud = PINECONE_CLOUD
# Initialize Pinecone client for improved performance
self.client = Pinecone(api_key=self.api_key)
if GRPC_AVAILABLE:
# Use gRPC client for better performance (Pinecone recommendation)
self.client = PineconeGRPC(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = True
log.info("Using Pinecone gRPC client for optimal performance")
else:
# Fallback to HTTP client with enhanced connection pooling
self.client = Pinecone(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = False
log.info("Using Pinecone HTTP client (gRPC not available)")
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -91,12 +117,53 @@ class PineconeClient(VectorDBBase):
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(self.index_name)
self.index = self.client.Index(
self.index_name,
pool_threads=20, # Enhanced connection pool for index operations
)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _retry_pinecone_operation(self, operation_func, max_retries=3):
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
for attempt in range(max_retries):
try:
return operation_func()
except Exception as e:
error_str = str(e).lower()
# Check if it's a retryable error (rate limits, network issues, timeouts)
is_retryable = any(
keyword in error_str
for keyword in [
"rate limit",
"quota",
"timeout",
"network",
"connection",
"unavailable",
"internal error",
"429",
"500",
"502",
"503",
"504",
]
)
if not is_retryable or attempt == max_retries - 1:
# Don't retry for non-retryable errors or on final attempt
raise
# Exponential backoff with jitter
delay = (2**attempt) + random.uniform(0, 1)
log.warning(
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
@@ -223,7 +290,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"Successfully inserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -254,7 +322,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"Successfully upserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -285,7 +354,8 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"Successfully async inserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -316,7 +386,8 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"Successfully async upserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
def search(
@@ -457,10 +528,12 @@ class PineconeClient(VectorDBBase):
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
f"Deleted batch of {len(batch_ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
f"Successfully deleted {len(ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
elif filter:

View File

@@ -1,10 +1,20 @@
import logging
from typing import Optional, List
from typing import Optional, Literal
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
MODELS = Literal[
"sonar",
"sonar-pro",
"sonar-reasoning",
"sonar-reasoning-pro",
"sonar-deep-research",
]
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -14,6 +24,8 @@ def search_perplexity(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
model: MODELS = "sonar",
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
@@ -21,6 +33,9 @@ def search_perplexity(
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
filter_list (Optional[list[str]]): List of domains to filter results
model (str): The Perplexity model to use (sonar, sonar-pro, sonar-reasoning, sonar-reasoning-pro, or sonar-deep-research)
search_context_usage (str): Search context usage level (low, medium, high)
"""
@@ -33,7 +48,7 @@ def search_perplexity(
# Create payload for the API call
payload = {
"model": "sonar",
"model": model,
"messages": [
{
"role": "system",
@@ -43,6 +58,9 @@ def search_perplexity(
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
"web_search_options": {
"search_context_usage": search_context_usage,
},
}
headers = {

View File

@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
try:
if "," in b64_str:
header, encoded = b64_str.split(",", 1)
mime_type = header.split(";")[0]
mime_type = header.split(";")[0].lstrip("data:")
img_data = base64.b64decode(encoded)
else:
mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
return img_data, mime_type
except Exception as e:
log.exception(f"Error loading image data: {e}")
return None
return None, None
def load_url_image_data(url, headers=None):

View File

@@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" or (
user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
raise HTTPException(
@@ -159,9 +158,8 @@ async def update_note_by_id(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" or (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
@@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" or (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(

View File

@@ -1232,6 +1232,9 @@ class GenerateChatCompletionForm(BaseModel):
stream: Optional[bool] = True
keep_alive: Optional[Union[int, str]] = None
tools: Optional[list[dict]] = None
model_config = ConfigDict(
extra="allow",
)
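With extra="allow", fields the form does not declare survive validation and model_dump(), so newer Ollama options pass straight through to the backend. A sketch (the extra "think" field and the required model/messages fields are assumptions about the full form definition):
form = GenerateChatCompletionForm(model="llama3", messages=[], think=True)
form.model_dump(exclude_none=True)  # -> {..., "think": True}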
async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
@@ -1269,7 +1272,9 @@ async def generate_chat_completion(
detail=str(e),
)
payload = {**form_data.model_dump(exclude_none=True)}
if isinstance(form_data, BaseModel):
payload = {**form_data.model_dump(exclude_none=True)}
if "metadata" in payload:
del payload["metadata"]
@@ -1285,11 +1290,7 @@ async def generate_chat_completion(
if params:
system = params.pop("system", None)
# Unlike OpenAI, Ollama does not support params directly in the body
payload["options"] = apply_model_params_to_body_ollama(
params, (payload.get("options", {}) or {})
)
payload = apply_model_params_to_body_ollama(params, payload)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
@@ -1323,7 +1324,7 @@ async def generate_chat_completion(
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
# payload["keep_alive"] = -1 # keep alive forever
return await send_post_request(
url=f"{url}/api/chat",
payload=json.dumps(payload),
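
Illustrative only (the form's model/messages fields are assumed from the unshown part of the class): with extra="allow", unknown provider-specific fields survive validation and reach the payload through model_dump().

form = GenerateChatCompletionForm(
    model="llama3",
    messages=[{"role": "user", "content": "hi"}],
    mirostat=2,  # hypothetical extra field, not declared on the model
)
assert form.model_dump(exclude_none=True)["mirostat"] == 2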

View File

@@ -887,6 +887,88 @@ async def generate_chat_completion(
await session.close()
async def embeddings(request: Request, form_data: dict, user):
"""
Calls the embeddings endpoint for OpenAI-compatible providers.
Args:
request (Request): The FastAPI request context.
form_data (dict): OpenAI-compatible embeddings payload.
user (UserModel): The authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
"""
idx = 0
# Prepare payload/body
body = json.dumps(form_data)
# Find correct backend url/key based on model
await get_all_models(request, user=user)
model_id = form_data.get("model")
models = request.app.state.OPENAI_MODELS
if model_id in models:
idx = models[model_id]["urlIdx"]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method="POST",
url=f"{url}/embeddings",
data=body,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = await r.json()
if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
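
A hypothetical call shape for the helper above: an OpenAI-style embeddings payload in, an OpenAI-style response (or streamed passthrough) out.

form_data = {"model": "text-embedding-3-small", "input": ["hello", "world"]}
# response = await embeddings(request, form_data, user)
# response["data"][0]["embedding"] -> list[float]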
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"""

View File

@@ -414,6 +414,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
"DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
"DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -467,6 +470,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -520,6 +525,8 @@ class WebConfig(BaseModel):
BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
EXA_API_KEY: Optional[str] = None
PERPLEXITY_API_KEY: Optional[str] = None
PERPLEXITY_MODEL: Optional[str] = None
PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
SOUGOU_API_SID: Optional[str] = None
SOUGOU_API_SK: Optional[str] = None
WEB_LOADER_ENGINE: Optional[str] = None
@@ -571,6 +578,9 @@ class ConfigForm(BaseModel):
DOCLING_OCR_ENGINE: Optional[str] = None
DOCLING_OCR_LANG: Optional[str] = None
DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -744,6 +754,22 @@ async def update_rag_config(
else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
form_data.DOCLING_PICTURE_DESCRIPTION_MODE
if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
)
request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
form_data.DOCLING_PICTURE_DESCRIPTION_API
if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
)
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -907,6 +933,10 @@ async def update_rag_config(
)
request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL
request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
)
request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
@@ -977,6 +1007,9 @@ async def update_rag_config(
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
"DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
"DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -1030,6 +1063,8 @@ async def update_rag_config(
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -1321,9 +1356,14 @@ def process_file(
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
DOCLING_PARAMS={
"ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
"ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
"do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
"picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
"picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
"picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
},
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -1740,19 +1780,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "exa":
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
model=request.app.state.config.PERPLEXITY_MODEL,
search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
)
elif engine == "sougou":
if (
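
Sketch of the consolidated DOCLING_PARAMS dict built above (values are placeholders, not defaults from this diff):

DOCLING_PARAMS = {
    "ocr_engine": "tesseract",            # placeholder engine name
    "ocr_lang": "en",
    "do_picture_description": True,
    "picture_description_mode": "local",  # assumed mode name
    "picture_description_local": {},      # local model config dict
    "picture_description_api": {},        # remote API config dict
}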

View File

@@ -9,6 +9,7 @@ import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
follow_up_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
ENABLE_FOLLOW_UP_GENERATION: bool
ENABLE_TAGS_GENERATION: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@@ -94,6 +100,13 @@ async def update_task_config(
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
form_data.ENABLE_FOLLOW_UP_GENERATION
)
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
@@ -133,6 +146,8 @@ async def update_task_config(
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -231,6 +246,86 @@ async def generate_title(
)
@router.post("/follow_up/completions")
async def generate_follow_ups(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Follow-up generation is disabled"},
)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating chat title using model {task_model_id} for user {user.email} "
)
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
content = follow_up_generation_template(
template,
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.FOLLOW_UP_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "An internal error has occurred."},
)
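
A hypothetical client payload for the new endpoint (route prefix assumed; chat_id mirrors what the middleware sends below):

body = {
    "model": "llama3",
    "messages": [
        {"role": "user", "content": "How do I enable web search?"},
        {"role": "assistant", "content": "Open Admin Settings > Web Search..."},
    ],
    "chat_id": "chat-123",  # optional, threaded into task metadata
}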
@router.post("/tags/completions")
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)

View File

@@ -165,22 +165,6 @@ async def update_default_user_permissions(
return request.app.state.config.USER_PERMISSIONS
############################
# UpdateUserRole
############################
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
############################
# GetUserSettingsBySessionUser
############################
@@ -333,11 +317,22 @@ async def update_user_by_id(
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id and session_user.id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
if first_user:
if user_id == first_user.id:
if session_user.id != user_id:
# Another user is attempting to modify the primary admin account; prohibit it
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
if form_data.role != "admin":
# If the primary admin is trying to change their own role, prevent it
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
@@ -365,6 +360,7 @@ async def update_user_by_id(
updated_user = Users.update_user_by_id(
user_id,
{
"role": form_data.role,
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,

View File

@@ -34,7 +34,7 @@ class CodeForm(BaseModel):
@router.post("/code/format")
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
try:
formatted_code = black.format_str(form_data.code, mode=black.Mode())
return {"code": formatted_code}

View File

@@ -2,16 +2,87 @@
import asyncio
from typing import Dict
from uuid import uuid4
import json
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional
# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}
def cleanup_task(task_id: str, id=None):
REDIS_TASKS_KEY = "open-webui:tasks"
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"
def is_redis(request: Request) -> bool:
# Check whether a shared Redis connection is available on the app state
return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)
async def redis_task_command_listener(app):
redis: Redis = app.state.redis
pubsub = redis.pubsub()
await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)
async for message in pubsub.listen():
if message["type"] != "message":
continue
try:
command = json.loads(message["data"])
if command.get("action") == "stop":
task_id = command.get("task_id")
local_task = tasks.get(task_id)
if local_task:
local_task.cancel()
except Exception as e:
print(f"Error handling distributed task command: {e}")
### ------------------------------
### REDIS-ENABLED HANDLERS
### ------------------------------
async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline()
pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
if chat_id:
pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
await pipe.execute()
async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
pipe = redis.pipeline()
pipe.hdel(REDIS_TASKS_KEY, task_id)
if chat_id:
pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set
await pipe.execute()
async def redis_list_tasks(redis: Redis) -> List[str]:
return list(await redis.hkeys(REDIS_TASKS_KEY))
async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))
async def redis_send_command(redis: Redis, command: dict):
await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))
async def cleanup_task(request, task_id: str, id=None):
"""
Remove a completed or canceled task from the global `tasks` dictionary.
"""
if is_redis(request):
await redis_cleanup_task(request.app.state.redis, task_id, id)
tasks.pop(task_id, None) # Remove the task if it exists
# If an ID is provided, remove the task from the chat_tasks dictionary
@@ -21,7 +92,7 @@ def cleanup_task(task_id: str, id=None):
chat_tasks.pop(id, None)
def create_task(coroutine, id=None):
async def create_task(request, coroutine, id=None):
"""
Create a new asyncio task and add it to the global task dictionary.
"""
@@ -29,7 +100,9 @@ def create_task(coroutine, id=None):
task = asyncio.create_task(coroutine) # Create the task
# Add a done callback for cleanup
task.add_done_callback(lambda t: cleanup_task(task_id, id))
task.add_done_callback(
lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
)
tasks[task_id] = task
# If an ID is provided, associate the task with that ID
@@ -38,34 +111,46 @@ def create_task(coroutine, id=None):
else:
chat_tasks[id] = [task_id]
if is_redis(request):
await redis_save_task(request.app.state.redis, task_id, id)
return task_id, task
def get_task(task_id: str):
"""
Retrieve a task by its task ID.
"""
return tasks.get(task_id)
def list_tasks():
async def list_tasks(request):
"""
List all currently active task IDs.
"""
if is_redis(request):
return await redis_list_tasks(request.app.state.redis)
return list(tasks.keys())
def list_task_ids_by_chat_id(id):
async def list_task_ids_by_chat_id(request, id):
"""
List all tasks associated with a specific ID.
"""
if is_redis(request):
return await redis_list_chat_tasks(request.app.state.redis, id)
return chat_tasks.get(id, [])
async def stop_task(task_id: str):
async def stop_task(request, task_id: str):
"""
Cancel a running task and remove it from the global task list.
"""
if is_redis(request):
# PUBSUB: All instances check if they have this task, and stop if so.
await redis_send_command(
request.app.state.redis,
{
"action": "stop",
"task_id": task_id,
},
)
# Optionally, re-check Redis shortly afterwards to confirm the task was stopped
return {"status": True, "message": f"Stop signal sent for {task_id}"}
task = tasks.get(task_id)
if not task:
raise ValueError(f"Task with ID {task_id} not found.")
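
Minimal sketch of the distributed stop flow, using the helpers above: any replica publishes the command, and only the replica whose in-process `tasks` dict owns the asyncio.Task cancels it (in redis_task_command_listener).

# Inside an async handler on any replica:
await redis_send_command(
    request.app.state.redis,
    {"action": "stop", "task_id": task_id},
)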

View File

@@ -23,6 +23,7 @@ from open_webui.env import (
TRUSTED_SIGNATURE_KEY,
STATIC_DIR,
SRC_LOG_LEVELS,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
)
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
@@ -157,6 +158,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
def get_current_user(
request: Request,
response: Response,
background_tasks: BackgroundTasks,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
@@ -225,6 +227,19 @@ def get_current_user(
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER)
if trusted_email and user.email != trusted_email:
# Delete the token cookie
response.delete_cookie("token")
# Delete OAuth token if present
if request.cookies.get("oauth_id_token"):
response.delete_cookie("oauth_id_token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User mismatch. Please sign in again.",
)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:

View File

@@ -320,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -424,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
__user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):

View File

@@ -0,0 +1,90 @@
import random
import logging
import sys
from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import (
embeddings as ollama_embeddings,
GenerateEmbeddingsForm,
)
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_embeddings(
request: Request,
form_data: dict,
user: UserModel,
bypass_filter: bool = False,
):
"""
Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
Args:
request (Request): The FastAPI request context.
form_data (dict): The input data sent to the endpoint.
user (UserModel): The authenticated user.
bypass_filter (bool): If True, disables access filtering (default False).
Returns:
dict: The embeddings response, following OpenAI API compatibility.
"""
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
# Attach extra metadata from request.state if present
if hasattr(request.state, "metadata"):
if "metadata" not in form_data:
form_data["metadata"] = request.state.metadata
else:
form_data["metadata"] = {
**form_data["metadata"],
**request.state.metadata,
}
# If "direct" flag present, use only that model
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data.get("model")
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
# Access filtering
if not getattr(request.state, "direct", False):
if not bypass_filter and user.role == "user":
check_model_access(user, model)
# Ollama backend
if model.get("owned_by") == "ollama":
ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
response = await ollama_embeddings(
request=request,
form_data=GenerateEmbeddingsForm(**ollama_payload),
user=user,
)
return convert_embedding_response_ollama_to_openai(response)
# Default: OpenAI or compatible backend
return await openai_embeddings(
request=request,
form_data=form_data,
user=user,
)
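
Illustrative call (inside an async handler; the model name is a placeholder): callers keep one OpenAI-shaped contract regardless of backend.

result = await generate_embeddings(
    request,
    {"model": "nomic-embed-text", "input": ["hello world"]},
    user,
)
vector = result["data"][0]["embedding"]  # same shape for Ollama and OpenAI backends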

View File

@@ -32,11 +32,17 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_follow_ups,
generate_image_prompt,
generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.images import (
load_b64_image_data,
image_generations,
GenerateImageForm,
upload_image,
)
from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
@@ -692,13 +698,8 @@ def apply_params_to_form_data(form_data, model):
params = deep_update(params, custom_params)
if model.get("ollama"):
# Ollama specific parameters
form_data["options"] = params
if "format" in params:
form_data["format"] = params["format"]
if "keep_alive" in params:
form_data["keep_alive"] = params["keep_alive"]
else:
if isinstance(params, dict):
for key, value in params.items():
@@ -726,12 +727,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_call,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -1061,6 +1057,59 @@ async def process_chat_response(
)
if tasks and messages:
if (
TASKS.FOLLOW_UP_GENERATION in tasks
and tasks[TASKS.FOLLOW_UP_GENERATION]
):
res = await generate_follow_ups(
request,
{
"model": message["model"],
"messages": messages,
"message_id": metadata["message_id"],
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
if len(res.get("choices", [])) == 1:
follow_ups_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
else:
follow_ups_string = ""
follow_ups_string = follow_ups_string[
follow_ups_string.find("{") : follow_ups_string.rfind("}")
+ 1
]
try:
follow_ups = json.loads(follow_ups_string).get(
"follow_ups", []
)
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"followUps": follow_ups,
},
)
await event_emitter(
{
"type": "chat:message:follow_ups",
"data": {
"follow_ups": follow_ups,
},
}
)
except Exception as e:
pass
if TASKS.TITLE_GENERATION in tasks:
if tasks[TASKS.TITLE_GENERATION]:
res = await generate_title(
@@ -1286,12 +1335,7 @@ async def process_chat_response(
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -1835,9 +1879,11 @@ async def process_chat_response(
value = delta.get("content")
reasoning_content = delta.get(
"reasoning_content"
) or delta.get("reasoning")
reasoning_content = (
delta.get("reasoning_content")
or delta.get("reasoning")
or delta.get("thinking")
)
if reasoning_content:
if (
not content_blocks
@@ -2230,28 +2276,21 @@ async def process_chat_response(
stdoutLines = stdout.split("\n")
for idx, line in enumerate(stdoutLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
image_url = ""
# Extract base64 image data from the line
image_data, content_type = (
load_b64_image_data(line)
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
if image_data is not None:
image_url = upload_image(
request,
image_data,
content_type,
metadata,
user,
)
stdoutLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
f"![Output Image]({image_url})"
)
output["stdout"] = "\n".join(stdoutLines)
@@ -2262,30 +2301,22 @@ async def process_chat_response(
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
image_url = ""
# Extract base64 image data from the line
image_data, content_type = (
load_b64_image_data(line)
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
if image_data is not None:
image_url = upload_image(
request,
image_data,
content_type,
metadata,
user,
)
resultLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
f"![Output Image]({image_url})"
)
output["result"] = "\n".join(resultLines)
except Exception as e:
output = str(e)
@@ -2394,8 +2425,8 @@ async def process_chat_response(
await response.background()
# background_tasks.add_task(post_response_handler, response, events)
task_id, _ = create_task(
post_response_handler(response, events), id=metadata["chat_id"]
task_id, _ = await create_task(
request, post_response_handler(response, events), id=metadata["chat_id"]
)
return {"status": True, "task_id": task_id}

View File

@@ -208,6 +208,7 @@ def openai_chat_message_template(model: str):
def openai_chat_chunk_message_template(
model: str,
content: Optional[str] = None,
reasoning_content: Optional[str] = None,
tool_calls: Optional[list[dict]] = None,
usage: Optional[dict] = None,
) -> dict:
@@ -220,6 +221,9 @@ def openai_chat_chunk_message_template(
if content:
template["choices"][0]["delta"]["content"] = content
if reasoning_content:
template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
if tool_calls:
template["choices"][0]["delta"]["tool_calls"] = tool_calls
@@ -234,6 +238,7 @@ def openai_chat_chunk_message_template(
def openai_chat_completion_message_template(
model: str,
message: Optional[str] = None,
reasoning_content: Optional[str] = None,
tool_calls: Optional[list[dict]] = None,
usage: Optional[dict] = None,
) -> dict:
@@ -241,8 +246,9 @@ def openai_chat_completion_message_template(
template["object"] = "chat.completion"
if message is not None:
template["choices"][0]["message"] = {
"content": message,
"role": "assistant",
"content": message,
**({"reasoning_content": reasoning_content} if reasoning_content else {}),
**({"tool_calls": tool_calls} if tool_calls else {}),
}

View File

@@ -538,7 +538,7 @@ class OAuthManager:
# Redirect back to the frontend with the JWT token
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
if redirect_base_url.endswith("/"):
if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
redirect_base_url = redirect_base_url[:-1]
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"

View File

@@ -175,16 +175,32 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
"num_thread": int,
}
# Extract keep_alive from options if it exists
if "options" in form_data and "keep_alive" in form_data["options"]:
form_data["keep_alive"] = form_data["options"]["keep_alive"]
del form_data["options"]["keep_alive"]
def parse_json(value: str) -> dict:
"""
Parse a JSON string, returning the original value unchanged if parsing fails.
"""
try:
return json.loads(value)
except Exception as e:
return value
if "options" in form_data and "format" in form_data["options"]:
form_data["format"] = form_data["options"]["format"]
del form_data["options"]["format"]
ollama_root_params = {
"format": lambda x: parse_json(x),
"keep_alive": lambda x: parse_json(x),
"think": bool,
}
return apply_model_params_to_body(params, form_data, mappings)
for key, value in ollama_root_params.items():
if (param := params.get(key, None)) is not None:
# Copy the parameter to the root level, then remove the original to avoid Ollama's invalid-option warning
form_data[key] = value(param)
del params[key]
# Unlike OpenAI, Ollama does not support params directly in the body
form_data["options"] = apply_model_params_to_body(
params, (form_data.get("options", {}) or {}), mappings
)
return form_data
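
Illustrative effect of the hoisting above (format/keep_alive are JSON-parsed or passed through; the rest lands in options, assuming temperature is among the mapped option types):

params = {"format": '{"type": "object"}', "keep_alive": "5m", "temperature": 0.7}
body = apply_model_params_to_body_ollama(params, {"model": "llama3"})
# body == {"model": "llama3", "format": {"type": "object"},
#          "keep_alive": "5m", "options": {"temperature": 0.7}}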
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
@@ -279,36 +295,48 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
openai_payload.get("messages")
)
ollama_payload["stream"] = openai_payload.get("stream", False)
if "tools" in openai_payload:
ollama_payload["tools"] = openai_payload["tools"]
if "format" in openai_payload:
ollama_payload["format"] = openai_payload["format"]
# If there are advanced parameters in the payload, format them in Ollama's options field
if openai_payload.get("options"):
ollama_payload["options"] = openai_payload["options"]
ollama_options = openai_payload["options"]
def parse_json(value: str) -> dict:
"""
Parse a JSON string, returning the original value unchanged if parsing fails.
"""
try:
return json.loads(value)
except Exception as e:
return value
ollama_root_params = {
"format": lambda x: parse_json(x),
"keep_alive": lambda x: parse_json(x),
"think": bool,
}
# Ollama's options field can contain parameters that should be at the root level.
for key, value in ollama_root_params.items():
if (param := ollama_options.get(key, None)) is not None:
# Copy the parameter to the root level, then remove the original to avoid Ollama's invalid-option warning
ollama_payload[key] = value(param)
del ollama_options[key]
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
if "max_tokens" in ollama_options:
ollama_options["num_predict"] = ollama_options["max_tokens"]
del ollama_options[
"max_tokens"
] # To prevent Ollama warning of invalid option provided
del ollama_options["max_tokens"]
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
# Comment: Not sure why this is needed, but we'll keep it for compatibility.
if "system" in ollama_options:
ollama_payload["system"] = ollama_options["system"]
del ollama_options[
"system"
] # To prevent Ollama warning of invalid option provided
del ollama_options["system"]
# Extract keep_alive from options if it exists
if "keep_alive" in ollama_options:
ollama_payload["keep_alive"] = ollama_options["keep_alive"]
del ollama_options["keep_alive"]
ollama_payload["options"] = ollama_options
# If the "stop" parameter is present in the openai_payload, remap it into ollama_payload["options"]
if "stop" in openai_payload:
@@ -329,3 +357,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
ollama_payload["format"] = format
return ollama_payload
def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
"""
Convert an embeddings request payload from OpenAI format to Ollama format.
Args:
openai_payload (dict): The original payload designed for OpenAI API usage.
Returns:
dict: A payload compatible with the Ollama API embeddings endpoint.
"""
ollama_payload = {"model": openai_payload.get("model")}
input_value = openai_payload.get("input")
# Ollama expects 'input' as a list, and 'prompt' as a single string.
if isinstance(input_value, list):
ollama_payload["input"] = input_value
ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
else:
ollama_payload["input"] = [input_value]
ollama_payload["prompt"] = str(input_value)
# Optionally forward other fields if present
for optional_key in ("options", "truncate", "keep_alive"):
if optional_key in openai_payload:
ollama_payload[optional_key] = openai_payload[optional_key]
return ollama_payload
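
Illustrative input/output for the converter above:

convert_embedding_payload_openai_to_ollama(
    {"model": "nomic-embed-text", "input": ["a", "b"]}
)
# -> {"model": "nomic-embed-text", "input": ["a", "b"], "prompt": "a\nb"}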

View File

@@ -1,7 +1,6 @@
import socketio
import redis
from redis import asyncio as aioredis
from urllib.parse import urlparse
from typing import Optional
def parse_redis_service_url(redis_url):
@@ -18,23 +17,46 @@ def parse_redis_service_url(redis_url):
}
def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
if redis_sentinels:
redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel(
redis_sentinels,
port=redis_config["port"],
db=redis_config["db"],
username=redis_config["username"],
password=redis_config["password"],
decode_responses=decode_responses,
)
def get_redis_connection(
redis_url, redis_sentinels, async_mode=False, decode_responses=True
):
if async_mode:
import redis.asyncio as redis
# Get a master connection from Sentinel
return sentinel.master_for(redis_config["service"])
# If using sentinel in async mode
if redis_sentinels:
redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel(
redis_sentinels,
port=redis_config["port"],
db=redis_config["db"],
username=redis_config["username"],
password=redis_config["password"],
decode_responses=decode_responses,
)
return sentinel.master_for(redis_config["service"])
elif redis_url:
return redis.from_url(redis_url, decode_responses=decode_responses)
else:
return None
else:
# Standard Redis connection
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
import redis
if redis_sentinels:
redis_config = parse_redis_service_url(redis_url)
sentinel = redis.sentinel.Sentinel(
redis_sentinels,
port=redis_config["port"],
db=redis_config["db"],
username=redis_config["username"],
password=redis_config["password"],
decode_responses=decode_responses,
)
return sentinel.master_for(redis_config["service"])
elif redis_url:
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
else:
return None
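
Illustrative usage: the same helper now covers both worlds, and only the async client's methods are awaited.

REDIS_URL = "redis://localhost:6379/0"  # placeholder URL
sync_client = get_redis_connection(REDIS_URL, redis_sentinels=[])
async_client = get_redis_connection(REDIS_URL, redis_sentinels=[], async_mode=True)
# await async_client.ping()   vs.   sync_client.ping()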
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):

View File

@@ -83,6 +83,7 @@ def convert_ollama_usage_to_openai(data: dict) -> dict:
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
model = ollama_response.get("model", "ollama")
message_content = ollama_response.get("message", {}).get("content", "")
reasoning_content = ollama_response.get("message", {}).get("thinking", None)
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
@@ -94,7 +95,7 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
usage = convert_ollama_usage_to_openai(data)
response = openai_chat_completion_message_template(
model, message_content, openai_tool_calls, usage
model, message_content, reasoning_content, openai_tool_calls, usage
)
return response
@@ -105,6 +106,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
model = data.get("model", "ollama")
message_content = data.get("message", {}).get("content", None)
reasoning_content = data.get("message", {}).get("thinking", None)
tool_calls = data.get("message", {}).get("tool_calls", None)
openai_tool_calls = None
@@ -118,10 +120,71 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
usage = convert_ollama_usage_to_openai(data)
data = openai_chat_chunk_message_template(
model, message_content, openai_tool_calls, usage
model, message_content, reasoning_content, openai_tool_calls, usage
)
line = f"data: {json.dumps(data)}\n\n"
yield line
yield "data: [DONE]\n\n"
def convert_embedding_response_ollama_to_openai(response) -> dict:
"""
Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
Args:
response (dict): The response from the Ollama API,
e.g. {"embedding": [...], "model": "..."}
or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
Returns:
dict: Response adapted to OpenAI's embeddings API format.
e.g. {
"object": "list",
"data": [
{"object": "embedding", "embedding": [...], "index": 0},
...
],
"model": "...",
}
"""
# Ollama batch-style output
if isinstance(response, dict) and "embeddings" in response:
openai_data = []
for i, emb in enumerate(response["embeddings"]):
openai_data.append(
{
"object": "embedding",
"embedding": emb.get("embedding"),
"index": emb.get("index", i),
}
)
return {
"object": "list",
"data": openai_data,
"model": response.get("model"),
}
# Ollama single output
elif isinstance(response, dict) and "embedding" in response:
return {
"object": "list",
"data": [
{
"object": "embedding",
"embedding": response["embedding"],
"index": 0,
}
],
"model": response.get("model"),
}
# Already OpenAI-compatible?
elif (
isinstance(response, dict)
and "data" in response
and isinstance(response["data"], list)
):
return response
# Fallback: return as is if unrecognized
return response
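
Illustrative conversion of an Ollama batch response (per the docstring's dict-shaped embeddings):

convert_embedding_response_ollama_to_openai(
    {
        "model": "nomic-embed-text",
        "embeddings": [{"embedding": [0.1, 0.2], "index": 0}],
    }
)
# -> {"object": "list",
#     "data": [{"object": "embedding", "embedding": [0.1, 0.2], "index": 0}],
#     "model": "nomic-embed-text"}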

View File

@@ -207,6 +207,24 @@ def title_generation_template(
return template
def follow_up_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
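
Illustrative call, assuming a template that uses the {{MESSAGES:END:N}} placeholder handled by replace_messages_variable:

content = follow_up_generation_template(
    "Suggest follow-ups based on:\n{{MESSAGES:END:2}}",
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
    {"name": "Ada", "location": None},
)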
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:

View File

@@ -1,5 +1,5 @@
fastapi==0.115.7
uvicorn[standard]==0.34.0
uvicorn[standard]==0.34.2
pydantic==2.10.6
python-multipart==0.0.20
@@ -7,14 +7,13 @@ python-socketio==5.13.0
python-jose==3.4.0
passlib[bcrypt]==1.7.4
requests==2.32.3
requests==2.32.4
aiohttp==3.11.11
async-timeout
aiocache
aiofiles
starlette-compress==1.6.0
sqlalchemy==2.0.38
alembic==1.14.0
peewee==3.18.1
@@ -76,13 +75,13 @@ pandas==2.2.3
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.34.0
validators==0.35.0
psutil
sentencepiece
soundfile==0.13.1
azure-ai-documentintelligence==1.0.0
azure-ai-documentintelligence==1.0.2
pillow==11.1.0
pillow==11.2.1
opencv-python-headless==4.11.0.86
rapidocr-onnxruntime==1.4.4
rank-bm25==0.2.2

View File

@@ -14,7 +14,11 @@ if [[ "${WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
python -c "import nltk; nltk.download('punkt_tab')"
fi
KEY_FILE=.webui_secret_key
if [ -n "${WEBUI_SECRET_KEY_FILE}" ]; then
KEY_FILE="${WEBUI_SECRET_KEY_FILE}"
else
KEY_FILE=".webui_secret_key"
fi
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"

View File

@@ -18,6 +18,10 @@ IF /I "%WEB_LOADER_ENGINE%" == "playwright" (
)
SET "KEY_FILE=.webui_secret_key"
IF NOT "%WEBUI_SECRET_KEY_FILE%" == "" (
SET "KEY_FILE=%WEBUI_SECRET_KEY_FILE%"
)
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"