Merge branch 'dev' of https://github.com/open-webui/open-webui into Dev-Individual-RAG-Config

weberm1
2025-06-06 12:02:33 +02:00
205 changed files with 11206 additions and 4089 deletions

View File

@@ -901,9 +901,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig(
####################################
WEBUI_URL = PersistentConfig(
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)
WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", ""))
ENABLE_SIGNUP = PersistentConfig(
@@ -1413,6 +1411,35 @@ Strictly return in JSON format:
{{MESSAGES:END:6}}
</chat_history>"""
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
"task.follow_up.prompt_template",
os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
)
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
### Guidelines:
- Write all follow-up questions from the user's point of view, directed to the assistant.
- Make questions concise, clear, and directly related to the discussed topic(s).
- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
- Use the conversation's primary language; default to English if multilingual.
- Response must be a JSON array of strings, no extra text or formatting.
### Output:
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""
ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
"ENABLE_FOLLOW_UP_GENERATION",
"task.follow_up.enable",
os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
)
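
For reference, a minimal standalone sketch of the boolean environment-variable pattern these entries rely on (not Open WebUI code; values are normalized to lowercase and compared against "true"):

import os

def env_bool(name: str, default: str = "false") -> bool:
    # "True", "TRUE" and "true" all parse to True; any other value is False.
    return os.environ.get(name, default).lower() == "true"

# Mirrors ENABLE_FOLLOW_UP_GENERATION above, which defaults to enabled.
follow_up_enabled = env_bool("ENABLE_FOLLOW_UP_GENERATION", "True")
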
ENABLE_TAGS_GENERATION = PersistentConfig(
"ENABLE_TAGS_GENERATION",
"task.tags.enable",
@@ -1848,6 +1875,61 @@ CONTENT_EXTRACTION_ENGINE = PersistentConfig(
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
DATALAB_MARKER_API_KEY = PersistentConfig(
"DATALAB_MARKER_API_KEY",
"rag.datalab_marker_api_key",
os.environ.get("DATALAB_MARKER_API_KEY", ""),
)
DATALAB_MARKER_LANGS = PersistentConfig(
"DATALAB_MARKER_LANGS",
"rag.datalab_marker_langs",
os.environ.get("DATALAB_MARKER_LANGS", ""),
)
DATALAB_MARKER_USE_LLM = PersistentConfig(
"DATALAB_MARKER_USE_LLM",
"rag.DATALAB_MARKER_USE_LLM",
os.environ.get("DATALAB_MARKER_USE_LLM", "false").lower() == "true",
)
DATALAB_MARKER_SKIP_CACHE = PersistentConfig(
"DATALAB_MARKER_SKIP_CACHE",
"rag.datalab_marker_skip_cache",
os.environ.get("DATALAB_MARKER_SKIP_CACHE", "false").lower() == "true",
)
DATALAB_MARKER_FORCE_OCR = PersistentConfig(
"DATALAB_MARKER_FORCE_OCR",
"rag.datalab_marker_force_ocr",
os.environ.get("DATALAB_MARKER_FORCE_OCR", "false").lower() == "true",
)
DATALAB_MARKER_PAGINATE = PersistentConfig(
"DATALAB_MARKER_PAGINATE",
"rag.datalab_marker_paginate",
os.environ.get("DATALAB_MARKER_PAGINATE", "false").lower() == "true",
)
DATALAB_MARKER_STRIP_EXISTING_OCR = PersistentConfig(
"DATALAB_MARKER_STRIP_EXISTING_OCR",
"rag.datalab_marker_strip_existing_ocr",
os.environ.get("DATALAB_MARKER_STRIP_EXISTING_OCR", "false").lower() == "true",
)
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION",
"rag.datalab_marker_disable_image_extraction",
os.environ.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", "false").lower()
== "true",
)
DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig(
"DATALAB_MARKER_OUTPUT_FORMAT",
"rag.datalab_marker_output_format",
os.environ.get("DATALAB_MARKER_OUTPUT_FORMAT", "markdown"),
)
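
Taken together, these settings allow the Marker engine to be configured entirely through the environment. A hypothetical pre-launch setup (all values are placeholders):

import os

os.environ["CONTENT_EXTRACTION_ENGINE"] = "datalab_marker"
os.environ["DATALAB_MARKER_API_KEY"] = "dm-..."          # placeholder key
os.environ["DATALAB_MARKER_USE_LLM"] = "true"            # parsed via .lower() == "true"
os.environ["DATALAB_MARKER_OUTPUT_FORMAT"] = "markdown"  # or "json" / "html"
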
EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig(
"EXTERNAL_DOCUMENT_LOADER_URL",
"rag.external_document_loader_url",
@@ -1928,6 +2010,11 @@ RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"rag.relevance_threshold",
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)
RAG_HYBRID_BM25_WEIGHT = PersistentConfig(
"RAG_HYBRID_BM25_WEIGHT",
"rag.hybrid_bm25_weight",
float(os.environ.get("RAG_HYBRID_BM25_WEIGHT", "0.5")),
)
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
"ENABLE_RAG_HYBRID_SEARCH",
@@ -2124,6 +2211,22 @@ RAG_OPENAI_API_KEY = PersistentConfig(
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
RAG_AZURE_OPENAI_BASE_URL = PersistentConfig(
"RAG_AZURE_OPENAI_BASE_URL",
"rag.azure_openai.base_url",
os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""),
)
RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
"RAG_AZURE_OPENAI_API_KEY",
"rag.azure_openai.api_key",
os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
)
RAG_AZURE_OPENAI_API_VERSION = PersistentConfig(
"RAG_AZURE_OPENAI_API_VERSION",
"rag.azure_openai.api_version",
os.getenv("RAG_AZURE_OPENAI_API_VERSION", ""),
)
RAG_OLLAMA_BASE_URL = PersistentConfig(
"RAG_OLLAMA_BASE_URL",
"rag.ollama.url",
@@ -2213,6 +2316,12 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
)
BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig(
"BYPASS_WEB_SEARCH_WEB_LOADER",
"rag.web.search.bypass_web_loader",
os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true",
)
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
@@ -2238,6 +2347,7 @@ WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
WEB_LOADER_ENGINE = PersistentConfig(
"WEB_LOADER_ENGINE",
"rag.web.loader.engine",
@@ -2397,6 +2507,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
os.getenv("PERPLEXITY_API_KEY", ""),
)
PERPLEXITY_MODEL = PersistentConfig(
"PERPLEXITY_MODEL",
"rag.web.search.perplexity_model",
os.getenv("PERPLEXITY_MODEL", "sonar"),
)
PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
"PERPLEXITY_SEARCH_CONTEXT_USAGE",
"rag.web.search.perplexity_search_context_usage",
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
)
SOUGOU_API_SID = PersistentConfig(
"SOUGOU_API_SID",
"rag.web.search.sougou_api_sid",

View File

@@ -111,6 +111,7 @@ class TASKS(str, Enum):
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"

View File

@@ -349,6 +349,10 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
)
BYPASS_MODEL_ACCESS_CONTROL = (
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"

View File

@@ -25,10 +25,14 @@ from open_webui.socket.main import (
)
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access
@@ -53,9 +57,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
# Check if function is already loaded
function_module, _, _ = load_function_module_by_id(pipe_id)
request.app.state.FUNCTIONS[pipe_id] = function_module
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
@@ -226,12 +228,7 @@ async def generate_function_chat_completion(
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
}
@@ -252,8 +249,13 @@ async def generate_function_chat_completion(
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
if params:
system = params.pop("system", None)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(
system, form_data, metadata, user
)
pipe_id = get_pipe_id(form_data)
function_module = get_function_module_by_id(request, pipe_id)
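
The refactor above pops "system" out of the model params once and passes it to the prompt helper explicitly. A minimal sketch of what the new call shape implies — an illustrative assumption, not the actual Open WebUI helper:

def apply_model_system_prompt_to_body(system, form_data, metadata, user):
    # Inject the model's system prompt as the leading message, if one is set.
    if not system:
        return form_data
    messages = form_data.get("messages") or []
    if not messages or messages[0].get("role") != "system":
        messages.insert(0, {"role": "system", "content": system})
    form_data["messages"] = messages
    return form_data
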

View File

@@ -37,9 +37,11 @@ from fastapi import (
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from starlette_compress import CompressMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
@@ -196,17 +198,32 @@ from open_webui.config import (
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_TOP_K,
RAG_TOP_K_RERANKER,
RAG_RELEVANCE_THRESHOLD,
RAG_HYBRID_BM25_WEIGHT,
RAG_ALLOWED_FILE_EXTENSIONS,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_AZURE_OPENAI_BASE_URL,
RAG_AZURE_OPENAI_API_KEY,
RAG_AZURE_OPENAI_API_VERSION,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
DATALAB_MARKER_API_KEY,
DATALAB_MARKER_LANGS,
DATALAB_MARKER_SKIP_CACHE,
DATALAB_MARKER_FORCE_OCR,
DATALAB_MARKER_PAGINATE,
DATALAB_MARKER_STRIP_EXISTING_OCR,
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
DATALAB_MARKER_OUTPUT_FORMAT,
DATALAB_MARKER_USE_LLM,
EXTERNAL_DOCUMENT_LOADER_URL,
EXTERNAL_DOCUMENT_LOADER_API_KEY,
TIKA_SERVER_URL,
@@ -217,8 +234,6 @@ from open_webui.config import (
DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY,
MISTRAL_OCR_API_KEY,
RAG_TOP_K,
RAG_TOP_K_RERANKER,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES,
@@ -233,6 +248,7 @@ from open_webui.config import (
ENABLE_WEB_SEARCH,
WEB_SEARCH_ENGINE,
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
BYPASS_WEB_SEARCH_WEB_LOADER,
WEB_SEARCH_RESULT_COUNT,
WEB_SEARCH_CONCURRENT_REQUESTS,
WEB_SEARCH_TRUST_ENV,
@@ -257,6 +273,8 @@ from open_webui.config import (
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
PERPLEXITY_API_KEY,
PERPLEXITY_MODEL,
PERPLEXITY_SEARCH_CONTEXT_USAGE,
SOUGOU_API_SID,
SOUGOU_API_SK,
KAGI_SEARCH_API_KEY,
@@ -348,10 +366,12 @@ from open_webui.config import (
TASK_MODEL_EXTERNAL,
ENABLE_TAGS_GENERATION,
ENABLE_TITLE_GENERATION,
ENABLE_FOLLOW_UP_GENERATION,
ENABLE_SEARCH_QUERY_GENERATION,
ENABLE_RETRIEVAL_QUERY_GENERATION,
ENABLE_AUTOCOMPLETE_GENERATION,
TITLE_GENERATION_PROMPT_TEMPLATE,
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -400,6 +420,7 @@ from open_webui.utils.chat import (
chat_completed as chat_completed_handler,
chat_action as chat_action_handler,
)
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access
@@ -638,8 +659,12 @@ app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL
app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL
app.state.USER_COUNT = None
app.state.TOOLS = {}
app.state.TOOL_CONTENTS = {}
app.state.FUNCTIONS = {}
app.state.FUNCTION_CONTENTS = {}
########################################
#
@@ -651,6 +676,7 @@ app.state.FUNCTIONS = {}
app.state.config.TOP_K = RAG_TOP_K
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.HYBRID_BM25_WEIGHT = RAG_HYBRID_BM25_WEIGHT
app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
@@ -662,6 +688,17 @@ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY
app.state.config.DATALAB_MARKER_LANGS = DATALAB_MARKER_LANGS
app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE
app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR
app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
)
app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
@@ -693,6 +730,10 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
app.state.config.RAG_AZURE_OPENAI_API_VERSION = RAG_AZURE_OPENAI_API_VERSION
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
@@ -712,6 +753,7 @@ app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
@@ -739,6 +781,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
@@ -789,9 +833,13 @@ try:
else app.state.config.RAG_OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
app.state.config.RAG_AZURE_OPENAI_API_VERSION
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
# Load all reranking models that are currently in use
for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items():
for model in model_list:
app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf(
@@ -800,11 +848,10 @@ try:
model["RAG_EXTERNAL_RERANKER_URL"],
model["RAG_EXTERNAL_RERANKER_API_KEY"],
)
except Exception as e:
log.error(f"Error updating models: {e}")
pass
########################################
#
# CODE EXECUTION
@@ -923,6 +970,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -930,6 +978,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -973,6 +1024,7 @@ class RedirectMiddleware(BaseHTTPMiddleware):
# Add the middleware to the app
app.add_middleware(CompressMiddleware)
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@@ -1160,6 +1212,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
return {"data": models}
##################################
# Embeddings
##################################
@app.post("/api/embeddings")
async def embeddings(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
"""
OpenAI-compatible embeddings endpoint.
This handler:
- Performs user/model checks and dispatches to the correct backend.
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
Args:
request (Request): Request context.
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
user (UserModel): Authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
"""
# Make sure models are loaded in app state
if not request.app.state.MODELS:
await get_all_models(request, user=user)
# Use generic dispatcher in utils.embeddings
return await generate_embeddings(request, form_data, user)
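
A hypothetical client call against the new endpoint (host, model name, and token are placeholders):

import requests

resp = requests.post(
    "http://localhost:3000/api/embeddings",
    json={"model": "text-embedding-3-small", "input": ["hello world"]},
    headers={"Authorization": "Bearer <api-key>"},  # placeholder token
)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))
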
@app.post("/api/chat/completions")
async def chat_completion(
request: Request,
@@ -1591,7 +1674,20 @@ async def healthcheck_with_db():
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
@app.get("/cache/{path:path}")
async def serve_cache_file(
path: str,
user=Depends(get_verified_user),
):
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
# prevent path traversal
if not file_path.startswith(os.path.abspath(CACHE_DIR)):
raise HTTPException(status_code=404, detail="File not found")
if not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(file_path)
def swagger_ui_html(*args, **kwargs):

View File

@@ -129,12 +129,16 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
user = Users.get_user_by_email(email)
if not user:
return None
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
return user
else:
return None
@@ -155,8 +159,8 @@ class AuthsTable:
except Exception:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}")
try:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()

View File

@@ -377,22 +377,47 @@ class ChatTable:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self,
user_id: str,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
all_chats = query.all()
return [ChatModel.model_validate(chat) for chat in all_chats]
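
An assumed call site for the new filter parameter — title substring search plus explicit ordering:

chats = Chats.get_archived_chat_list_by_user_id(
    user_id,
    filter={"query": "budget", "order_by": "updated_at", "direction": "desc"},
    skip=0,
    limit=20,
)
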
def get_chat_list_by_user_id(
self,
user_id: str,
include_archived: bool = False,
filter: Optional[dict] = None,
skip: int = 0,
limit: int = 50,
) -> list[ChatModel]:
@@ -401,7 +426,23 @@ class ChatTable:
if not include_archived:
query = query.filter_by(archived=False)
query = query.order_by(Chat.updated_at.desc())
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by and direction and getattr(Chat, order_by):
if direction.lower() == "asc":
query = query.order_by(getattr(Chat, order_by).asc())
elif direction.lower() == "desc":
query = query.order_by(getattr(Chat, order_by).desc())
else:
raise ValueError("Invalid direction for ordering")
else:
query = query.order_by(Chat.updated_at.desc())
if skip:
query = query.offset(skip)
@@ -542,7 +583,9 @@ class ChatTable:
search_text = search_text.lower().strip()
if not search_text:
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
return self.get_chat_list_by_user_id(
user_id, include_archived, filter={}, skip=skip, limit=limit
)
search_text_words = search_text.split(" ")

View File

@@ -108,6 +108,54 @@ class FunctionsTable:
log.exception(f"Error creating a new function: {e}")
return None
def sync_functions(
self, user_id: str, functions: list[FunctionModel]
) -> list[FunctionModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try:
with get_db() as db:
# Get existing functions
existing_functions = db.query(Function).all()
existing_ids = {func.id for func in existing_functions}
# Prepare a set of new function IDs
new_function_ids = {func.id for func in functions}
# Update or insert functions
for func in functions:
if func.id in existing_ids:
db.query(Function).filter_by(id=func.id).update(
{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
else:
new_func = Function(
**{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
}
)
db.add(new_func)
# Remove functions that are no longer present
for func in existing_functions:
if func.id not in new_function_ids:
db.delete(func)
db.commit()
return [
FunctionModel.model_validate(func)
for func in db.query(Function).all()
]
except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}")
return []
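
Note that sync_functions treats the passed list as the complete desired state: any row whose ID is absent from it is deleted. A minimal usage sketch (assumed call site; imported_function_models is a hypothetical list of FunctionModel):

# Replaces the entire functions table with the imported set.
synced = Functions.sync_functions(user.id, imported_function_models)
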
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
with get_db() as db:

View File

@@ -207,5 +207,43 @@ class GroupTable:
except Exception:
return False
def sync_user_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> bool:
with get_db() as db:
try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
# Add user to new groups
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except Exception as e:
log.exception(e)
return False
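
A hypothetical wiring of the new WEBUI_AUTH_TRUSTED_GROUPS_HEADER to this helper, assuming the trusted header carries a comma-separated list of group names (the parsing shown is illustrative):

header_value = request.headers.get(WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "")
group_names = [name.strip() for name in header_value.split(",") if name.strip()]
Groups.sync_user_groups_by_group_names(user.id, group_names)
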
Groups = GroupTable()

View File

@@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel):
role: str
name: str
email: str
profile_image_url: str

View File

@@ -0,0 +1,251 @@
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
from fastapi import HTTPException, status
log = logging.getLogger(__name__)
class DatalabMarkerLoader:
def __init__(
self,
file_path: str,
api_key: str,
langs: Optional[str] = None,
use_llm: bool = False,
skip_cache: bool = False,
force_ocr: bool = False,
paginate: bool = False,
strip_existing_ocr: bool = False,
disable_image_extraction: bool = False,
output_format: str = None,
):
self.file_path = file_path
self.api_key = api_key
self.langs = langs
self.use_llm = use_llm
self.skip_cache = skip_cache
self.force_ocr = force_ocr
self.paginate = paginate
self.strip_existing_ocr = strip_existing_ocr
self.disable_image_extraction = disable_image_extraction
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower()
mime_map = {
"pdf": "application/pdf",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"odt": "application/vnd.oasis.opendocument.text",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"odp": "application/vnd.oasis.opendocument.presentation",
"html": "text/html",
"epub": "application/epub+zip",
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
"gif": "image/gif",
"tiff": "image/tiff",
}
return mime_map.get(ext, "application/octet-stream")
def check_marker_request_status(self, request_id: str) -> dict:
url = f"https://www.datalab.to/api/v1/marker/{request_id}"
headers = {"X-Api-Key": self.api_key}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
result = response.json()
log.info(f"Marker API status check for request {request_id}: {result}")
return result
except requests.HTTPError as e:
log.error(f"Error checking Marker request status: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to check Marker request: {e}",
)
except ValueError as e:
log.error(f"Invalid JSON checking Marker request: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
)
def load(self) -> List[Document]:
url = "https://www.datalab.to/api/v1/marker"
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
form_data = {
"langs": self.langs,
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"output_format": self.output_format,
}
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
try:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
response = requests.post(
url, data=form_data, files=files, headers=headers
)
response.raise_for_status()
result = response.json()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}",
)
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"):
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
)
check_url = result.get("request_check_url")
request_id = result.get("request_id")
if not check_url:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
)
for _ in range(300): # Up to 10 minutes
time.sleep(2)
try:
poll_response = requests.get(check_url, headers=headers)
poll_response.raise_for_status()
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status")
success_val = poll_result.get("success")
if status_val == "complete":
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
)
else:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
)
content_key = self.output_format.lower()
raw_content = poll_result.get(content_key)
if content_key == "json":
full_text = json.dumps(raw_content, indent=2)
elif content_key in {"markdown", "html"}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}",
)
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Datalab Marker returned empty content",
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(content_key, "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
output_path = os.path.join(marker_output_dir, output_filename)
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(full_text)
log.info(f"Saved Marker output to: {output_path}")
except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}")
metadata = {
"source": filename,
"output_format": poll_result.get("output_format", self.output_format),
"page_count": poll_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
}
images = poll_result.get("images", {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))
for k, v in metadata.items():
if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v)
elif v is None:
metadata[k] = ""
return [Document(page_content=full_text, metadata=metadata)]
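
Example usage of the loader, assuming a local PDF and a valid API key (both placeholders); output_format should be set explicitly, since the response content key is derived from it:

loader = DatalabMarkerLoader(
    file_path="report.pdf",   # placeholder path
    api_key="dm-...",         # placeholder key
    output_format="markdown",
)
docs = loader.load()
print(docs[0].metadata["page_count"], docs[0].page_content[:200])
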

View File

@@ -21,9 +21,11 @@ from langchain_community.document_loaders import (
)
from langchain_core.documents import Document
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
@@ -74,7 +76,6 @@ known_source_ext = [
"swift",
"vue",
"svelte",
"msg",
"ex",
"exs",
"erl",
@@ -145,15 +146,12 @@ class DoclingLoader:
)
}
params = {
"image_export_mode": "placeholder",
"table_mode": "accurate",
}
params = {"image_export_mode": "placeholder", "table_mode": "accurate"}
if self.params:
if self.params.get("do_picture_classification"):
params["do_picture_classification"] = self.params.get(
"do_picture_classification"
if self.params.get("do_picture_description"):
params["do_picture_description"] = self.params.get(
"do_picture_description"
)
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
@@ -236,6 +234,49 @@ class Loader:
mime_type=file_content_type,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
)
elif (
self.engine == "datalab_marker"
and self.kwargs.get("DATALAB_MARKER_API_KEY")
and file_ext
in [
"pdf",
"xls",
"xlsx",
"ods",
"doc",
"docx",
"odt",
"ppt",
"pptx",
"odp",
"html",
"epub",
"png",
"jpeg",
"jpg",
"webp",
"gif",
"tiff",
]
):
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
strip_existing_ocr=self.kwargs.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
),
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
@@ -247,7 +288,7 @@ class Loader:
params={
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
"do_picture_classification": self.kwargs.get(
"do_picture_description": self.kwargs.get(
"DOCLING_DO_PICTURE_DESCRIPTION"
),
},

View File

@@ -1,8 +1,12 @@
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager
from langchain_core.documents import Document
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
@@ -14,18 +18,37 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
class MistralLoader:
"""
Enhanced Mistral OCR loader with both sync and async support.
Loads documents by processing them through the Mistral OCR API.
Performance Optimizations:
- Differentiated timeouts for different operations
- Intelligent retry logic with exponential backoff
- Memory-efficient file streaming for large files
- Connection pooling and keepalive optimization
- Semaphore-based concurrency control for batch processing
- Enhanced error handling with retryable error classification
"""
BASE_API_URL = "https://api.mistral.ai/v1"
def __init__(self, api_key: str, file_path: str):
def __init__(
self,
api_key: str,
file_path: str,
timeout: int = 300, # 5 minutes default
max_retries: int = 3,
enable_debug_logging: bool = False,
):
"""
Initializes the loader.
Initializes the loader with enhanced features.
Args:
api_key: Your Mistral API key.
file_path: The local path to the PDF file to process.
timeout: Request timeout in seconds.
max_retries: Maximum number of retry attempts.
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
@@ -34,7 +57,46 @@ class MistralLoader:
self.api_key = api_key
self.file_path = file_path
self.headers = {"Authorization": f"Bearer {self.api_key}"}
self.timeout = timeout
self.max_retries = max_retries
self.debug = enable_debug_logging
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
self.file_name = os.path.basename(file_path)
self.file_size = os.path.getsize(file_path)
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
}
def _debug_log(self, message: str, *args) -> None:
"""
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
Only processes debug messages when debug mode is enabled, avoiding
string formatting overhead in production environments.
"""
if self.debug:
log.debug(message, *args)
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
"""Checks response status and returns JSON content."""
@@ -54,24 +116,154 @@ class MistralLoader:
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
raise # Re-raise after logging
def _upload_file(self) -> str:
"""Uploads the file to Mistral for OCR processing."""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
file_name = os.path.basename(self.file_path)
async def _handle_response_async(
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
"""Async version of response handling with better error info."""
try:
with open(self.file_path, "rb") as f:
files = {"file": (file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
response.raise_for_status()
upload_headers = self.headers.copy() # Avoid modifying self.headers
response = requests.post(
url, headers=upload_headers, files=files, data=data
# Check content type
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
if response.status == 204:
return {}
text = await response.text()
raise ValueError(
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
response_data = self._handle_response(response)
return await response.json()
except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response"
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
raise
except aiohttp.ClientError as e:
log.error(f"Client error: {e}")
raise
except Exception as e:
log.error(f"Unexpected error processing response: {e}")
raise
def _is_retryable_error(self, error: Exception) -> bool:
"""
ENHANCEMENT: Intelligent error classification for retry logic.
Determines if an error is retryable based on its type and status code.
This prevents wasting time retrying errors that will never succeed
(like authentication errors) while ensuring transient errors are retried.
Retryable errors:
- Network connection errors (temporary network issues)
- Timeouts (server might be temporarily overloaded)
- Server errors (5xx status codes - server-side issues)
- Rate limiting (429 status - temporary throttling)
Non-retryable errors:
- Authentication errors (401, 403 - won't fix with retry)
- Bad request errors (400 - malformed request)
- Not found errors (404 - resource doesn't exist)
"""
if isinstance(error, requests.exceptions.ConnectionError):
return True # Network issues are usually temporary
if isinstance(error, requests.exceptions.Timeout):
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
return False # All other errors are non-retryable
def _retry_request_sync(self, request_func, *args, **kwargs):
"""
ENHANCEMENT: Synchronous retry logic with intelligent error classification.
Uses exponential backoff with jitter to avoid thundering herd problems.
The wait time increases exponentially but is capped at 30 seconds to
prevent excessive delays. Only retries errors that are likely to succeed
on subsequent attempts.
"""
for attempt in range(self.max_retries):
try:
return request_func(*args, **kwargs)
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
# PERFORMANCE OPTIMIZATION: Exponential backoff with cap
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
time.sleep(wait_time)
async def _retry_request_async(self, request_func, *args, **kwargs):
"""
ENHANCEMENT: Async retry logic with intelligent error classification.
Async version of retry logic that doesn't block the event loop during
wait periods. Uses the same exponential backoff strategy as sync version.
"""
for attempt in range(self.max_retries):
try:
return await request_func(*args, **kwargs)
except Exception as e:
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
raise
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time) # Non-blocking wait
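
For illustration, the schedule produced by min((2 ** attempt) + 0.5, 30) in both retry helpers:

for attempt in range(6):
    print(attempt, min((2 ** attempt) + 0.5, 30))
# -> 1.5, 2.5, 4.5, 8.5, 16.5, then capped at 30 seconds
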
def _upload_file(self) -> str:
"""
PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
Uploads the file to Mistral for OCR processing (sync version).
Uses context manager for file handling to ensure proper resource cleanup.
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
log.info("Uploading file to Mistral API")
url = f"{self.BASE_API_URL}/files"
def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
files = {"file": (self.file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
response = requests.post(
url,
headers=self.headers,
files=files,
data=data,
timeout=self.upload_timeout, # Use specialized upload timeout
stream=False, # Keep as False for this endpoint
)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
@@ -81,16 +273,66 @@ class MistralLoader:
log.error(f"Failed to upload file: {e}")
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.BASE_API_URL}/files"
async def upload_request():
# Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data")
# Add purpose field
purpose_part = writer.append("ocr")
purpose_part.set_content_disposition("form-data", name="purpose")
# Add file part with streaming
file_part = writer.append_payload(
aiohttp.streams.FilePayload(
self.file_path,
filename=self.file_name,
content_type="application/pdf",
)
)
file_part.set_content_disposition(
"form-data", name="file", filename=self.file_name
)
self._debug_log(
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
async with session.post(
url,
data=writer,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id")
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file."""
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
return self._handle_response(response)
try:
response = requests.get(url, headers=signed_url_headers, params=params)
response_data = self._handle_response(response)
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
@@ -100,8 +342,36 @@ class MistralLoader:
log.error(f"Failed to get signed URL: {e}")
raise
async def _get_signed_url_async(
self, session: aiohttp.ClientSession, file_id: str
) -> str:
"""Async signed URL retrieval."""
url = f"{self.BASE_API_URL}/files/{file_id}/url"
params = {"expiry": 1}
headers = {**self.headers, "Accept": "application/json"}
async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}")
async with session.get(
url,
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=self.url_timeout),
) as response:
return await self._handle_response_async(response)
response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url")
if not signed_url:
raise ValueError("Signed URL not found in response.")
self._debug_log("Signed URL received successfully")
return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing."""
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.BASE_API_URL}/ocr"
ocr_headers = {
@@ -118,43 +388,217 @@ class MistralLoader:
"include_image_base64": False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
)
return self._handle_response(response)
try:
response = requests.post(url, headers=ocr_headers, json=payload)
ocr_response = self._handle_response(response)
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
log.debug("OCR response: %s", ocr_response)
self._debug_log("OCR response: %s", ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
raise
async def _process_ocr_async(
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.BASE_API_URL}/ocr"
headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
},
"include_image_base64": False,
}
async def ocr_request():
log.info("Starting OCR processing via Mistral API")
start_time = time.time()
async with session.post(
url,
json=payload,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
) as response:
ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s")
return ocr_response
return await self._retry_request_async(ocr_request)
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage."""
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.BASE_API_URL}/files/{file_id}"
# No specific Accept header needed, default or Authorization is usually sufficient
try:
response = requests.delete(url, headers=self.headers)
delete_response = self._handle_response(
response
) # Check status, ignore response body unless needed
log.info(
f"File deleted successfully: {delete_response}"
) # Log the response if available
response = requests.delete(
url, headers=self.headers, timeout=self.cleanup_timeout
)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
# Depending on requirements, you might choose to raise the error here
async def _delete_file_async(
self, session: aiohttp.ClientSession, file_id: str
) -> None:
"""Async file deletion with error tolerance."""
try:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
async with session.delete(
url=f"{self.BASE_API_URL}/files/{file_id}",
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=self.cleanup_timeout
), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
await self._retry_request_async(delete_request)
self._debug_log(f"File {file_id} deleted successfully")
except Exception as e:
# Don't fail the entire process if cleanup fails
log.warning(f"Failed to delete file ID {file_id}: {e}")
@asynccontextmanager
async def _get_session(self):
"""Context manager for HTTP session with optimized settings."""
connector = aiohttp.TCPConnector(
limit=20, # Increased total connection limit for better throughput
limit_per_host=10, # Increased per-host limit for API endpoints
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
use_dns_cache=True,
keepalive_timeout=60, # Increased keepalive for connection reuse
enable_cleanup_closed=True,
force_close=False, # Allow connection reuse
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
)
timeout = aiohttp.ClientTimeout(
total=self.timeout,
connect=30, # Connection timeout
sock_read=60, # Socket read timeout
)
async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
raise_for_status=False, # We handle status codes manually
) as session:
yield session
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [
Document(
page_content="No text content found",
metadata={"error": "no_pages", "file_name": self.file_name},
)
]
documents = []
total_pages = len(pages_data)
skipped_pages = 0
# Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is None or page_index is None:
skipped_pages += 1
self._debug_log(
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
)
continue
# Clean up content efficiently with early exit for empty content
if isinstance(page_content, str):
cleaned_content = page_content.strip()
else:
cleaned_content = str(page_content).strip()
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
continue
# Create document with optimized metadata
documents.append(
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
},
)
)
if skipped_pages > 0:
log.info(
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
return [
Document(
page_content="No valid text content found in document",
metadata={
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
},
)
]
return documents
def load(self) -> List[Document]:
"""
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
Synchronous version for backward compatibility.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
# 1. Upload file
file_id = self._upload_file()
@@ -166,53 +610,30 @@ class MistralLoader:
ocr_response = self._process_ocr(signed_url)
# 4. Process results
pages_data = ocr_response.get("pages")
if not pages_data:
log.warning("No pages found in OCR response.")
return [Document(page_content="No text content found", metadata={})]
documents = self._process_results(ocr_response)
documents = []
total_pages = len(pages_data)
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
if page_content is not None and page_index is not None:
documents.append(
Document(
page_content=page_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index
+ 1, # 1-based label for convenience
"total_pages": total_pages,
# Add other relevant metadata from page_data if available/needed
# e.g., page_data.get('width'), page_data.get('height')
},
)
)
else:
log.warning(
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
)
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
return [
Document(
page_content="No text content found in valid pages", metadata={}
)
]
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
log.error(f"An error occurred during the loading process: {e}")
# Return an empty list or a specific error document on failure
return [Document(page_content=f"Error during processing: {e}", metadata={})]
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Delete file (attempt even if prior steps failed after upload)
if file_id:
@@ -223,3 +644,124 @@ class MistralLoader:
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
async def load_async(self) -> List[Document]:
"""
Asynchronous OCR workflow execution with optimized performance.
Returns:
A list of Document objects, one for each page processed.
"""
file_id = None
start_time = time.time()
try:
async with self._get_session() as session:
# 1. Upload file with streaming
file_id = await self._upload_file_async(session)
# 2. Get signed URL
signed_url = await self._get_signed_url_async(session, file_id)
# 3. Process OCR
ocr_response = await self._process_ocr_async(session, signed_url)
# 4. Process results
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
return [
Document(
page_content=f"Error during OCR processing: {e}",
metadata={
"error": "processing_failed",
"file_name": self.file_name,
},
)
]
finally:
# 5. Cleanup - always attempt file deletion
if file_id:
try:
async with self._get_session() as session:
await self._delete_file_async(session, file_id)
except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
Process multiple files concurrently with controlled concurrency.
Args:
loaders: List of MistralLoader instances
max_concurrent: Maximum number of concurrent requests
Returns:
List of document lists, one for each loader
"""
if not loaders:
return []
log.info(
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
start_time = time.time()
# Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async with semaphore:
return await loader.load_async()
# Process all files with controlled concurrency
tasks = [process_with_semaphore(loader) for loader in loaders]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Handle any exceptions in results
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
metadata={
"error": "batch_processing_failed",
"file_index": i,
},
)
]
)
else:
processed_results.append(result)
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
failure_count = len(results) - success_count
log.info(
f"Batch processing completed in {total_time:.2f}s: "
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
)
return processed_results
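A minimal usage sketch for the batch helper above. The constructor arguments are assumptions (check MistralLoader's actual signature); the semaphore caps in-flight OCR requests:

import asyncio

async def main():
    # Hypothetical constructor kwargs; adjust to the real MistralLoader signature.
    loaders = [
        MistralLoader(api_key="<mistral-key>", file_path="a.pdf"),
        MistralLoader(api_key="<mistral-key>", file_path="b.pdf"),
    ]
    results = await MistralLoader.load_multiple_async(loaders, max_concurrent=3)
    for i, docs in enumerate(results):
        print(f"file {i}: {len(docs)} documents")

asyncio.run(main())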

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple
class BaseReranker(ABC):
@abstractmethod
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
pass
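A toy subclass (not part of the codebase) showing the contract: score each (query, passage) pair and return one float per pair, in order:

class OverlapReranker(BaseReranker):
    """Illustrative reranker: scores pairs by shared-token count."""

    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
        scores = []
        for query, passage in sentences:
            q, p = set(query.lower().split()), set(passage.lower().split())
            scores.append(float(len(q & p)))
        return scores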

View File

@@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ColBERT:
class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -3,12 +3,14 @@ import requests
from typing import Optional, List, Tuple
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
class ExternalReranker:
class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,

View File

@@ -5,6 +5,7 @@ from typing import Optional, Union
import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
from huggingface_hub import snapshot_download
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
@@ -116,6 +117,7 @@ def query_doc_with_hybrid_search(
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
) -> dict:
try:
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
@@ -131,9 +133,20 @@ def query_doc_with_hybrid_search(
top_k=k,
)
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
)
if hybrid_bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_search_retriever], weights=[1.0]
)
elif hybrid_bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.0]
)
else:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever],
weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
)
compressor = RerankCompressor(
embedding_function=embedding_function,
top_n=k_reranker,
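For instance, hybrid_bm25_weight = 0.7 produces ensemble weights of [0.7, 0.3] in favor of BM25, while values at or beyond the [0, 1] bounds collapse the ensemble to a single retriever:

w = 0.7
weights = [w, 1.0 - w]  # -> [0.7, 0.3]: BM25-leaning hybrid
# w <= 0 -> vector search only; w >= 1 -> BM25 only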
@@ -313,6 +326,7 @@ def query_collection_with_hybrid_search(
reranking_function,
k_reranker: int,
r: float,
hybrid_bm25_weight: float,
) -> dict:
results = []
error = False
@@ -346,6 +360,7 @@ def query_collection_with_hybrid_search(
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
)
return result, None
except Exception as e:
@@ -386,12 +401,13 @@ def get_embedding_function(
url,
key,
embedding_batch_size,
azure_api_version=None,
):
if embedding_engine == "":
return lambda query, prefix=None, user=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
elif embedding_engine in ["ollama", "openai"]:
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
@@ -400,6 +416,7 @@ def get_embedding_function(
url=url,
key=key,
user=user,
azure_api_version=azure_api_version,
)
def generate_multiple(query, prefix, user, func):
@@ -433,6 +450,7 @@ def get_sources_from_files(
reranking_function,
k_reranker,
r,
hybrid_bm25_weight,
hybrid_search,
full_context=False,
):
@@ -550,6 +568,7 @@ def get_sources_from_files(
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=hybrid_bm25_weight,
)
except Exception as e:
log.debug(
@@ -681,6 +700,60 @@ def generate_openai_batch_embeddings(
return None
def generate_azure_openai_batch_embeddings(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: Optional[str] = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
)
json_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
for _ in range(5):
r = requests.post(
url,
headers={
"Content-Type": "application/json",
"api-key": key,
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
json=json_data,
)
if r.status_code == 429:
retry = float(r.headers.get("Retry-After", "1"))
time.sleep(retry)
continue
r.raise_for_status()
data = r.json()
if "data" in data:
return [elem["embedding"] for elem in data["data"]]
else:
raise Exception("Something went wrong :/")
return None
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
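A hedged call sketch for the helper above; the deployment name, resource URL, and api-version are placeholders, and the function returns None on failure instead of raising:

vectors = generate_azure_openai_batch_embeddings(
    model="my-embedding-deployment",  # Azure *deployment* name, not the base model id
    texts=["first chunk", "second chunk"],
    url="https://my-resource.openai.azure.com",  # placeholder resource URL
    key="<api-key>",
    version="2024-02-01",  # example api-version; verify against your resource
)
if vectors is not None:
    print(len(vectors), "embeddings, dim", len(vectors[0]))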
def generate_ollama_batch_embeddings(
model: str,
texts: list[str],
@@ -745,38 +818,33 @@ def generate_embeddings(
text = f"{prefix}{text}"
if engine == "ollama":
if isinstance(text, list):
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": text,
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
else:
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(
model, text, url, key, prefix, user
)
else:
embeddings = generate_openai_batch_embeddings(
model, [text], url, key, prefix, user
)
embeddings = generate_openai_batch_embeddings(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = generate_azure_openai_batch_embeddings(
model,
text if isinstance(text, list) else [text],
url,
key,
azure_api_version,
prefix,
user,
)
return embeddings[0] if isinstance(text, str) else embeddings

View File

@@ -1,13 +1,21 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from pinecone import ServerlessSpec
from pinecone import Pinecone, ServerlessSpec
# Add gRPC support for better performance (Pinecone best practice)
try:
from pinecone.grpc import PineconeGRPC
GRPC_AVAILABLE = True
except ImportError:
GRPC_AVAILABLE = False
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts
import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import (
VectorDBBase,
@@ -47,10 +55,25 @@ class PineconeClient(VectorDBBase):
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone gRPC client for improved performance
self.client = PineconeGRPC(
api_key=self.api_key, environment=self.environment, cloud=self.cloud
)
# Initialize Pinecone client for improved performance
if GRPC_AVAILABLE:
# Use gRPC client for better performance (Pinecone recommendation)
self.client = PineconeGRPC(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = True
log.info("Using Pinecone gRPC client for optimal performance")
else:
# Fallback to HTTP client with enhanced connection pooling
self.client = Pinecone(
api_key=self.api_key,
pool_threads=20, # Improved connection pool size
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = False
log.info("Using Pinecone HTTP client (gRPC not available)")
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -94,12 +117,53 @@ class PineconeClient(VectorDBBase):
log.info(f"Using existing Pinecone index '{self.index_name}'")
# Connect to the index
self.index = self.client.Index(self.index_name)
self.index = self.client.Index(
self.index_name,
pool_threads=20, # Enhanced connection pool for index operations
)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
def _retry_pinecone_operation(self, operation_func, max_retries=3):
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
for attempt in range(max_retries):
try:
return operation_func()
except Exception as e:
error_str = str(e).lower()
# Check if it's a retryable error (rate limits, network issues, timeouts)
is_retryable = any(
keyword in error_str
for keyword in [
"rate limit",
"quota",
"timeout",
"network",
"connection",
"unavailable",
"internal error",
"429",
"500",
"502",
"503",
"504",
]
)
if not is_retryable or attempt == max_retries - 1:
# Don't retry for non-retryable errors or on final attempt
raise
# Exponential backoff with jitter
delay = (2**attempt) + random.uniform(0, 1)
log.warning(
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {delay:.2f}s: {e}"
)
time.sleep(delay)
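A usage sketch: the operation is wrapped in a zero-argument callable so the helper can re-invoke it after backoff (batch and namespace here are illustrative):

# Inside PineconeClient, e.g. for a single upsert batch:
result = self._retry_pinecone_operation(
    lambda: self.index.upsert(vectors=batch, namespace=""),
    max_retries=3,
)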
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
@@ -147,8 +211,8 @@ class PineconeClient(VectorDBBase):
metadatas = []
for match in matches:
metadata = match.get("metadata", {})
ids.append(match["id"])
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
@@ -174,7 +238,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
return len(response.matches) > 0
matches = getattr(response, "matches", []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
@@ -225,7 +290,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"Successfully inserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -256,7 +322,8 @@ class PineconeClient(VectorDBBase):
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
f"Successfully upserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -287,7 +354,8 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async insert batch: {result}")
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"Successfully async inserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -318,35 +386,10 @@ class PineconeClient(VectorDBBase):
log.error(f"Error in async upsert batch: {result}")
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
f"Successfully async upserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Perform a streaming upsert over gRPC for performance testing."""
if not items:
log.warning("No items to upsert via streaming")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Open a streaming upsert channel
stream = self.index.streaming_upsert()
try:
for point in points:
# send each point over the stream
stream.send(point)
# close the stream to finalize
stream.close()
log.info(
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
)
except Exception as e:
log.error(f"Error during streaming upsert: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
@@ -374,7 +417,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
if not query_response.matches:
matches = getattr(query_response, "matches", []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
@@ -384,13 +428,13 @@ class PineconeClient(VectorDBBase):
)
# Convert to GetResult format
get_result = self._result_to_get_result(query_response.matches)
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(match.score)
for match in query_response.matches
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
@@ -432,7 +476,8 @@ class PineconeClient(VectorDBBase):
include_metadata=True,
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
@@ -456,7 +501,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
@@ -482,10 +528,12 @@ class PineconeClient(VectorDBBase):
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
f"Deleted batch of {len(batch_ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
f"Successfully deleted {len(ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
elif filter:
@@ -516,12 +564,12 @@ class PineconeClient(VectorDBBase):
raise
def close(self):
"""Shut down the gRPC channel and thread pool."""
"""Shut down resources."""
try:
self.client.close()
log.info("Pinecone gRPC channel closed.")
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to close Pinecone gRPC channel: {e}")
log.warning(f"Failed to clean up Pinecone resources: {e}")
self._executor.shutdown(wait=True)
def __enter__(self):

View File

@@ -1,10 +1,20 @@
import logging
from typing import Optional, List
from typing import Optional, Literal
import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS
MODELS = Literal[
"sonar",
"sonar-pro",
"sonar-reasoning",
"sonar-reasoning-pro",
"sonar-deep-research",
]
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -14,6 +24,8 @@ def search_perplexity(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
model: MODELS = "sonar",
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
@@ -21,6 +33,9 @@ def search_perplexity(
api_key (str): A Perplexity API key
query (str): The query to search for
count (int): Maximum number of results to return
filter_list (Optional[list[str]]): List of domains to filter results
model (str): The Perplexity model to use (sonar, sonar-pro)
search_context_usage (str): Search context usage level (low, medium, high)
"""
@@ -33,7 +48,7 @@ def search_perplexity(
# Create payload for the API call
payload = {
"model": "sonar",
"model": model,
"messages": [
{
"role": "system",
@@ -43,6 +58,9 @@ def search_perplexity(
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
"web_search_options": {
"search_context_usage": search_context_usage,
},
}
headers = {

View File

@@ -42,7 +42,9 @@ def search_searchapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -42,7 +42,9 @@ def search_serpapi(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results[:count]
]

View File

@@ -8,6 +8,8 @@ from pathlib import Path
from pydub import AudioSegment
from pydub.silence import split_on_silence
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import aiohttp
import aiofiles
@@ -18,6 +20,7 @@ from fastapi import (
Depends,
FastAPI,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -93,12 +96,9 @@ def is_audio_conversion_required(file_path):
# File is AAC/mp4a audio, recommend mp3 conversion
return True
# If the codec name or file extension is in the supported formats
if (
codec_name in SUPPORTED_FORMATS
or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS
):
return False # Already supported
# If the codec name is in the supported formats
if codec_name in SUPPORTED_FORMATS:
return False
return True
except Exception as e:
@@ -527,11 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
def transcription_handler(request, file_path):
def transcription_handler(request, file_path, metadata):
filename = os.path.basename(file_path)
file_dir = os.path.dirname(file_path)
id = filename.split(".")[0]
metadata = metadata or {}
if request.app.state.config.STT_ENGINE == "":
if request.app.state.faster_whisper_model is None:
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -543,7 +545,7 @@ def transcription_handler(request, file_path):
file_path,
beam_size=5,
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
language=WHISPER_LANGUAGE,
language=metadata.get("language") or WHISPER_LANGUAGE,
)
log.info(
"Detected language '%s' with probability %f"
@@ -569,7 +571,14 @@ def transcription_handler(request, file_path):
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
},
files={"file": (filename, open(file_path, "rb"))},
data={"model": request.app.state.config.STT_MODEL},
data={
"model": request.app.state.config.STT_MODEL,
**(
{"language": metadata.get("language")}
if metadata.get("language")
else {}
),
},
)
r.raise_for_status()
@@ -777,8 +786,8 @@ def transcription_handler(request, file_path):
)
def transcribe(request: Request, file_path):
log.info(f"transcribe: {file_path}")
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
log.info(f"transcribe: {file_path} {metadata}")
if is_audio_conversion_required(file_path):
file_path = convert_audio_to_mp3(file_path)
@@ -804,7 +813,7 @@ def transcribe(request: Request, file_path):
with ThreadPoolExecutor() as executor:
# Submit tasks for each chunk_path
futures = [
executor.submit(transcription_handler, request, chunk_path)
executor.submit(transcription_handler, request, chunk_path, metadata)
for chunk_path in chunk_paths
]
# Gather results as they complete
@@ -812,10 +821,9 @@ def transcribe(request: Request, file_path):
try:
results.append(future.result())
except Exception as transcribe_exc:
log.exception(f"Error transcribing chunk: {transcribe_exc}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error during transcription.",
detail=f"Error transcribing chunk: {transcribe_exc}",
)
finally:
# Clean up only the temporary chunks, never the original file
@@ -897,6 +905,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
def transcription(
request: Request,
file: UploadFile = File(...),
language: Optional[str] = Form(None),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
@@ -926,7 +935,12 @@ def transcription(
f.write(contents)
try:
result = transcribe(request, file_path)
metadata = None
if language:
metadata = {"language": language}
result = transcribe(request, file_path, metadata)
return {
**result,
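Client-side, the new language hint travels as multipart form data next to the file; a hedged sketch with requests (route path and auth header are assumptions based on the router above):

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/audio/transcriptions",  # assumed mount point
    headers={"Authorization": "Bearer <token>"},
    files={"file": open("meeting.mp3", "rb")},
    data={"language": "de"},  # forwarded to faster-whisper / OpenAI STT
)
print(resp.json())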

View File

@@ -19,12 +19,14 @@ from open_webui.models.auths import (
UserResponse,
)
from open_webui.models.users import Users
from open_webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.env import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
WEBUI_AUTH_TRUSTED_NAME_HEADER,
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
WEBUI_AUTH_COOKIE_SAME_SITE,
WEBUI_AUTH_COOKIE_SECURE,
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
@@ -299,7 +301,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
500, detail="Internal error occurred during LDAP user creation."
)
user = Auths.authenticate_user_by_trusted_header(email)
user = Auths.authenticate_user_by_email(email)
if user:
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
@@ -363,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
trusted_name = trusted_email
email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
name = email
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
trusted_name = request.headers.get(
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
)
if not Users.get_user_by_email(trusted_email.lower()):
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
if not Users.get_user_by_email(email.lower()):
await signup(
request,
response,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
),
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
user = Auths.authenticate_user_by_email(email)
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
group_names = request.headers.get(
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
).split(",")
group_names = [name.strip() for name in group_names if name.strip()]
if group_names:
Groups.sync_user_groups_by_group_names(user.id, group_names)
elif WEBUI_AUTH == False:
admin_email = "admin@localhost"
admin_password = "admin"

View File

@@ -76,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str,
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
):
if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
)
@@ -194,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
############################
@router.get("/pinned", response_model=list[ChatResponse])
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
@@ -267,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=list[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
page: Optional[int] = None,
query: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
user=Depends(get_verified_user),
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
if page is None:
page = 1
limit = 60
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
chat_list = [
ChatTitleIdResponse(**chat.model_dump())
for chat in Chats.get_archived_chat_list_by_user_id(
user.id,
filter=filter,
skip=skip,
limit=limit,
)
]
return chat_list
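The page parameter maps onto the older skip/limit contract; for example, page=3 with the fixed limit of 60 fetches rows 120 through 179:

page, limit = 3, 60
skip = (page - 1) * limit  # 120 -> rows 120..179 are returned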
############################

View File

@@ -1,6 +1,7 @@
import logging
import os
import uuid
import json
from fnmatch import fnmatch
from pathlib import Path
from typing import Optional
@@ -10,6 +11,7 @@ from fastapi import (
APIRouter,
Depends,
File,
Form,
HTTPException,
Request,
UploadFile,
@@ -85,20 +87,33 @@ def has_access_to_file(
def upload_file(
request: Request,
file: UploadFile = File(...),
user=Depends(get_verified_user),
file_metadata: dict = None,
metadata: Optional[dict | str] = Form(None),
process: bool = Query(True),
internal: bool = False,
user=Depends(get_verified_user),
knowledge_id: Optional[str] = Form(None)
):
log.info(f"file.content_type: {file.content_type}")
file_metadata = file_metadata if file_metadata else {}
if isinstance(metadata, str):
try:
metadata = json.loads(metadata)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
)
file_metadata = metadata if metadata else {}
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
file_extension = os.path.splitext(filename)[1]
if request.app.state.config.ALLOWED_FILE_EXTENSIONS:
# Remove the leading dot from the file extension
file_extension = file_extension[1:] if file_extension else ""
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
]
@@ -146,21 +161,16 @@ def upload_file(
"video/webm"
}:
file_path = Storage.get_file(file_path)
result = transcribe(request, file_path)
result = transcribe(request, file_path, file_metadata)
process_file(
request,
ProcessFileForm(file_id=id, content=result.get("text", ""), knowledge_id=knowledge_id),
user=user,
)
elif file.content_type not in [
"image/png",
"image/jpeg",
"image/gif",
"video/mp4",
"video/ogg",
"video/quicktime",
]:
elif (not file.content_type.startswith(("image/", "video/"))) or (
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
):
process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
else:
log.info(
@@ -191,7 +201,7 @@ def upload_file(
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
)
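Because metadata arrives as a form field, clients send it as a JSON string that the endpoint decodes with json.loads; a hedged sketch (route path assumed):

import json
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/files/",  # assumed mount point
    headers={"Authorization": "Bearer <token>"},
    files={"file": open("notes.pdf", "rb")},
    data={"metadata": json.dumps({"language": "en"})},
)
print(resp.json())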

View File

@@ -1,5 +1,8 @@
import os
import re
import logging
import aiohttp
from pathlib import Path
from typing import Optional
@@ -9,12 +12,18 @@ from open_webui.models.functions import (
FunctionResponse,
Functions,
)
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.utils.plugin import (
load_function_module_by_id,
replace_imports,
get_function_module_from_cache,
)
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.env import SRC_LOG_LEVELS
from pydantic import BaseModel, HttpUrl
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -42,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
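Two worked examples of the rewrite, using a hypothetical public repository:

github_url_to_raw_url("https://github.com/org/repo/blob/main/tools/foo.py")
# -> "https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/foo.py"

github_url_to_raw_url("https://github.com/org/repo/tree/main/tools/foo")
# -> "https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/foo/main.py"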
@router.post("/load/url", response_model=Optional[dict])
async def load_function_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT a SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
function_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and file_name not in ("main.py", "index.py", "__init__.py")
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the function"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": function_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
############################
# SyncFunctions
############################
class SyncFunctionsForm(FunctionForm):
functions: list[FunctionModel] = []
@router.post("/sync", response_model=Optional[FunctionModel])
async def sync_functions(
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
):
return Functions.sync_functions(user.id, form_data.functions)
############################
# CreateNewFunction
############################
@@ -262,8 +362,9 @@ async def get_function_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -287,8 +388,9 @@ async def update_function_valves_by_id(
):
function = Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
@@ -347,8 +449,9 @@ async def get_function_user_valves_spec_by_id(
):
function = Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
@@ -368,8 +471,9 @@ async def update_function_user_valves_by_id(
function = Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
)
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves

View File

@@ -333,10 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
return [
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
]
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
# TODO - get models from comfyui
@@ -450,7 +451,7 @@ def load_url_image_data(url, headers=None):
return None
def upload_image(request, image_metadata, image_data, content_type, user):
def upload_image(request, image_data, content_type, metadata, user):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
@@ -459,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
"content-type": content_type,
},
)
file_item = upload_file(request, file, user, file_metadata=image_metadata)
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
return url
@@ -526,7 +527,7 @@ async def image_generations(
else:
image_data, content_type = load_b64_image_data(image["b64_json"])
url = upload_image(request, data, image_data, content_type, user)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
@@ -560,7 +561,7 @@ async def image_generations(
image_data, content_type = load_b64_image_data(
image["bytesBase64Encoded"]
)
url = upload_image(request, data, image_data, content_type, user)
url = upload_image(request, image_data, content_type, data, user)
images.append({"url": url})
return images
@@ -611,9 +612,9 @@ async def image_generations(
image_data, content_type = load_url_image_data(image["url"], headers)
url = upload_image(
request,
form_data.model_dump(exclude_none=True),
image_data,
content_type,
form_data.model_dump(exclude_none=True),
user,
)
images.append({"url": url})
@@ -664,9 +665,9 @@ async def image_generations(
image_data, content_type = load_b64_image_data(image)
url = upload_image(
request,
{**data, "info": res["info"]},
image_data,
content_type,
{**data, "info": res["info"]},
user,
)
images.append({"url": url})

View File

@@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
):
raise HTTPException(
@@ -159,9 +158,8 @@ async def update_note_by_id(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(
@@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if (
user.role != "admin"
and user.id != note.user_id
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
):
raise HTTPException(

View File

@@ -9,6 +9,8 @@ import os
import random
import re
import time
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import aiohttp
@@ -300,6 +302,22 @@ async def update_config(
}
def merge_ollama_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
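A worked example of the merge: a model reported by nodes 0 and 2 keeps a single entry whose urls field records both node indices:

merged = merge_ollama_models_lists([
    [{"model": "llama3:latest"}],  # node 0
    None,                          # node 1 unreachable
    [{"model": "llama3:latest"}],  # node 2
])
# -> [{"model": "llama3:latest", "urls": [0, 2]}]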
@cached(ttl=1)
async def get_all_models(request: Request, user: UserModel = None):
log.info("get_all_models()")
@@ -364,23 +382,8 @@ async def get_all_models(request: Request, user: UserModel = None):
if connection_type:
model["connection_type"] = connection_type
def merge_models_lists(model_lists):
merged_models = {}
for idx, model_list in enumerate(model_lists):
if model_list is not None:
for model in model_list:
id = model["model"]
if id not in merged_models:
model["urls"] = [idx]
merged_models[id] = model
else:
merged_models[id]["urls"].append(idx)
return list(merged_models.values())
models = {
"models": merge_models_lists(
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
@@ -388,6 +391,22 @@ async def get_all_models(request: Request, user: UserModel = None):
)
}
try:
loaded_models = await get_ollama_loaded_models(request, user=user)
expires_map = {
m["name"]: m["expires_at"]
for m in loaded_models["models"]
if "expires_at" in m
}
for m in models["models"]:
if m["name"] in expires_map:
# Parse ISO8601 datetime with offset, get unix timestamp as int
dt = datetime.fromisoformat(expires_map[m["name"]])
m["expires_at"] = int(dt.timestamp())
except Exception as e:
log.debug(f"Failed to get loaded models: {e}")
else:
models = {"models": []}
@@ -468,6 +487,68 @@ async def get_ollama_tags(
return models
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = []
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
):
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
else:
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
enable = api_config.get("enable", True)
key = api_config.get("key", None)
if enable:
request_tasks.append(
send_get_request(f"{url}/api/ps", key, user=user)
)
else:
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
)
prefix_id = api_config.get("prefix_id", None)
for model in response.get("models", []):
if prefix_id:
model["model"] = f"{prefix_id}.{model['model']}"
models = {
"models": merge_ollama_models_lists(
map(
lambda response: response.get("models", []) if response else None,
responses,
)
)
}
else:
models = {"models": []}
return models
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
@@ -541,36 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
return {"version": False}
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if request.app.state.config.ENABLE_OLLAMA_API:
request_tasks = [
send_get_request(
f"{url}/api/ps",
request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx),
request.app.state.config.OLLAMA_API_CONFIGS.get(
url, {}
), # Legacy support
).get("key", None),
user=user,
)
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
]
responses = await asyncio.gather(*request_tasks)
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
class ModelNameForm(BaseModel):
name: str
@router.post("/api/unload")
async def unload_model(
request: Request,
form_data: ModelNameForm,
user=Depends(get_admin_user),
):
model_name = form_data.name
if not model_name:
raise HTTPException(
status_code=400, detail="Missing 'name' of model to unload."
)
# Refresh/load models if needed, get mapping from name to URLs
await get_all_models(request, user=user)
models = request.app.state.OLLAMA_MODELS
# Canonicalize model name (if not supplied with version)
if ":" not in model_name:
model_name = f"{model_name}:latest"
if model_name not in models:
raise HTTPException(
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
)
url_indices = models[model_name]["urls"]
# Send unload to ALL url_indices
results = []
errors = []
for idx in url_indices:
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
)
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
prefix_id = api_config.get("prefix_id", None)
if prefix_id and model_name.startswith(f"{prefix_id}."):
model_name = model_name[len(f"{prefix_id}.") :]
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
try:
res = await send_post_request(
url=f"{url}/api/generate",
payload=json.dumps(payload),
stream=False,
key=key,
user=user,
)
results.append({"url_idx": idx, "success": True, "response": res})
except Exception as e:
log.exception(f"Failed to unload model on node {idx}: {e}")
errors.append({"url_idx": idx, "success": False, "error": str(e)})
if len(errors) > 0:
raise HTTPException(
status_code=500,
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
)
return {"status": True}
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
@@ -1164,13 +1283,14 @@ async def generate_chat_completion(
params = model_info.params.model_dump()
if params:
if payload.get("options") is None:
payload["options"] = {}
system = params.pop("system", None)
# Unlike OpenAI, Ollama does not support params directly in the body
payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"]
params, (payload.get("options", {}) or {})
)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@@ -1352,8 +1472,10 @@ async def generate_openai_chat_completion(
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if user.role == "user":

View File

@@ -715,8 +715,12 @@ async def generate_chat_completion(
model_id = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
if params:
system = params.pop("system", None)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
# Check if user has access to the model
if not bypass_filter and user.role == "user":
@@ -883,6 +887,88 @@ async def generate_chat_completion(
await session.close()
async def embeddings(request: Request, form_data: dict, user):
"""
Calls the embeddings endpoint for OpenAI-compatible providers.
Args:
request (Request): The FastAPI request context.
form_data (dict): OpenAI-compatible embeddings payload.
user (UserModel): The authenticated user.
Returns:
dict: OpenAI-compatible embeddings response.
"""
idx = 0
# Prepare payload/body
body = json.dumps(form_data)
# Find correct backend url/key based on model
await get_all_models(request, user=user)
model_id = form_data.get("model")
models = request.app.state.OPENAI_MODELS
if model_id in models:
idx = models[model_id]["urlIdx"]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method="POST",
url=f"{url}/embeddings",
data=body,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user
else {}
),
},
)
r.raise_for_status()
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
detail = None
if r is not None:
try:
res = await r.json()
if "error" in res:
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:
detail = f"External: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
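A hedged sketch of how other routers can invoke the helper; request and user come from the surrounding FastAPI context:

# Illustrative only: assumes an authenticated user and the FastAPI request object.
response = await embeddings(
    request,
    {"model": "text-embedding-3-small", "input": ["hello", "world"]},
    user,
)
# Non-streaming providers return the familiar OpenAI shape:
# {"object": "list", "data": [{"embedding": [...], "index": 0}, ...]}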
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"""

View File

@@ -254,6 +254,11 @@ async def get_embedding_config(request: Request, collectionForm: Optional[Collec
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
}),
"azure_openai_config": rag_config.get("azure_openai_config", {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
}),
}
@@ -267,9 +272,16 @@ class OllamaConfigForm(BaseModel):
key: str
class AzureOpenAIConfigForm(BaseModel):
url: str
key: str
version: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
ollama_config: Optional[OllamaConfigForm] = None
azure_openai_config: Optional[AzureOpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
embedding_batch_size: Optional[int] = 1
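A request-body sketch for switching the embedding engine to Azure; the values are placeholders matching the form fields above:

payload = {
    "embedding_engine": "azure_openai",
    "embedding_model": "my-embedding-deployment",  # Azure deployment name
    "embedding_batch_size": 32,
    "azure_openai_config": {
        "url": "https://my-resource.openai.azure.com",
        "key": "<api-key>",
        "version": "2024-02-01",
    },
}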
@@ -405,14 +417,18 @@ async def update_embedding_config(
gc.collect()
torch.cuda.empty_cache()
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
)
request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
"ollama",
"openai",
"azure_openai",
]:
if form_data.openai_config is not None:
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
form_data.openai_config.url
)
request.app.state.config.RAG_OPENAI_API_KEY = (
form_data.openai_config.key
)
if form_data.ollama_config is not None:
request.app.state.config.RAG_OLLAMA_BASE_URL = (
@@ -422,6 +438,61 @@ async def update_embedding_config(
form_data.ollama_config.key
)
if form_data.azure_openai_config is not None:
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
form_data.azure_openai_config.url
)
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
form_data.azure_openai_config.key
)
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
form_data.azure_openai_config.version
)
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
form_data.embedding_batch_size
)
@@ -440,14 +511,27 @@ async def update_embedding_config(
(
request.app.state.config.RAG_OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_BASE_URL
else (
request.app.state.config.RAG_OLLAMA_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
)
),
(
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
else (
request.app.state.config.RAG_OLLAMA_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
)
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
azure_api_version=(
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
else None
),
)
# add model to state for reloading on startup
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
@@ -470,9 +554,13 @@ async def update_embedding_config(
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
},
"azure_openai_config": {
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
"version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
},
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
"message": "Embedding configuration updated globally.",
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
@@ -508,9 +596,19 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
"TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER),
"RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD),
"HYBRID_BM25_WEIGHT": rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT),
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
"PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES),
"DATALAB_MARKER_API_KEY": rag_config.get("DATALAB_MARKER_API_KEY", request.app.state.config.DATALAB_MARKER_API_KEY),
"DATALAB_MARKER_LANGS": rag_config.get("DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS),
"DATALAB_MARKER_SKIP_CACHE": rag_config.get("DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE),
"DATALAB_MARKER_FORCE_OCR": rag_config.get("DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR),
"DATALAB_MARKER_PAGINATE": rag_config.get("DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE),
"DATALAB_MARKER_STRIP_EXISTING_OCR": rag_config.get("DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR),
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": rag_config.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION),
"DATALAB_MARKER_USE_LLM": rag_config.get("DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM),
"DATALAB_MARKER_OUTPUT_FORMAT": rag_config.get("DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT),
"EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL),
"EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY),
"TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL),
@@ -546,6 +644,7 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
"WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
"BYPASS_WEB_SEARCH_WEB_LOADER": web_config.get("BYPASS_WEB_SEARCH_WEB_LOADER", request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER),
"SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL),
"YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL),
"YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME),
@@ -570,6 +669,8 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
"BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
"EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY),
"PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY),
"PERPLEXITY_MODEL": web_config.get("PERPLEXITY_MODEL", request.app.state.config.PERPLEXITY_MODEL),
"PERPLEXITY_SEARCH_CONTEXT_USAGE": web_config.get("PERPLEXITY_SEARCH_CONTEXT_USAGE", request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE),
"SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID),
"SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK),
"WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE),
@@ -603,6 +704,7 @@ class WebConfig(BaseModel):
WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
SEARXNG_QUERY_URL: Optional[str] = None
YACY_QUERY_URL: Optional[str] = None
YACY_USERNAME: Optional[str] = None
@@ -627,6 +729,8 @@ class WebConfig(BaseModel):
BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
EXA_API_KEY: Optional[str] = None
PERPLEXITY_API_KEY: Optional[str] = None
PERPLEXITY_MODEL: Optional[str] = None
PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
SOUGOU_API_SID: Optional[str] = None
SOUGOU_API_SK: Optional[str] = None
WEB_LOADER_ENGINE: Optional[str] = None
@@ -656,10 +760,20 @@ class ConfigForm(BaseModel):
ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
TOP_K_RERANKER: Optional[int] = None
RELEVANCE_THRESHOLD: Optional[float] = None
HYBRID_BM25_WEIGHT: Optional[float] = None
# Content extraction settings
CONTENT_EXTRACTION_ENGINE: Optional[str] = None
PDF_EXTRACT_IMAGES: Optional[bool] = None
DATALAB_MARKER_API_KEY: Optional[str] = None
DATALAB_MARKER_LANGS: Optional[str] = None
DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
DATALAB_MARKER_PAGINATE: Optional[bool] = None
DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
DATALAB_MARKER_USE_LLM: Optional[bool] = None
DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
@@ -853,6 +967,11 @@ async def update_rag_config(
if form_data.RELEVANCE_THRESHOLD is not None
else request.app.state.config.RELEVANCE_THRESHOLD
)
request.app.state.config.HYBRID_BM25_WEIGHT = (
form_data.HYBRID_BM25_WEIGHT
if form_data.HYBRID_BM25_WEIGHT is not None
else request.app.state.config.HYBRID_BM25_WEIGHT
)
# Content extraction settings
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
@@ -865,6 +984,51 @@ async def update_rag_config(
if form_data.PDF_EXTRACT_IMAGES is not None
else request.app.state.config.PDF_EXTRACT_IMAGES
)
request.app.state.config.DATALAB_MARKER_API_KEY = (
form_data.DATALAB_MARKER_API_KEY
if form_data.DATALAB_MARKER_API_KEY is not None
else request.app.state.config.DATALAB_MARKER_API_KEY
)
request.app.state.config.DATALAB_MARKER_LANGS = (
form_data.DATALAB_MARKER_LANGS
if form_data.DATALAB_MARKER_LANGS is not None
else request.app.state.config.DATALAB_MARKER_LANGS
)
request.app.state.config.DATALAB_MARKER_SKIP_CACHE = (
form_data.DATALAB_MARKER_SKIP_CACHE
if form_data.DATALAB_MARKER_SKIP_CACHE is not None
else request.app.state.config.DATALAB_MARKER_SKIP_CACHE
)
request.app.state.config.DATALAB_MARKER_FORCE_OCR = (
form_data.DATALAB_MARKER_FORCE_OCR
if form_data.DATALAB_MARKER_FORCE_OCR is not None
else request.app.state.config.DATALAB_MARKER_FORCE_OCR
)
request.app.state.config.DATALAB_MARKER_PAGINATE = (
form_data.DATALAB_MARKER_PAGINATE
if form_data.DATALAB_MARKER_PAGINATE is not None
else request.app.state.config.DATALAB_MARKER_PAGINATE
)
request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = (
form_data.DATALAB_MARKER_STRIP_EXISTING_OCR
if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR is not None
else request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR
)
request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None
else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
)
request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = (
form_data.DATALAB_MARKER_OUTPUT_FORMAT
if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None
else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT
)
request.app.state.config.DATALAB_MARKER_USE_LLM = (
form_data.DATALAB_MARKER_USE_LLM
if form_data.DATALAB_MARKER_USE_LLM is not None
else request.app.state.config.DATALAB_MARKER_USE_LLM
)
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = (
form_data.EXTERNAL_DOCUMENT_LOADER_URL
if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None
@@ -1046,6 +1210,9 @@ async def update_rag_config(
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
)
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
)
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
@@ -1082,6 +1249,10 @@ async def update_rag_config(
)
request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL
request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
)
request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
@@ -1132,9 +1303,19 @@ async def update_rag_config(
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
"TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
"RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
"HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,
# Content extraction settings
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
"DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
"DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
"DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
"DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
"DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
"DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
"DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
"DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
@@ -1170,6 +1351,7 @@ async def update_rag_config(
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
@@ -1194,6 +1376,8 @@ async def update_rag_config(
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -1268,11 +1452,14 @@ def save_docs_to_vector_db(
embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
openai_api_base_url = rag_config.get("openai_config", {}).get("url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
openai_api_key = rag_config.get("openai_config", {}).get("url", request.app.state.config.RAG_OPENAI_API_KEY)
ollama_base_url = rag_config.get("ollama_config", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
ollama_api_key = rag_config.get("ollama_config", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
azure_openai_url = rag_config.get("azure_openai", {}).get("url", request.app.state.config.RAG_AZURE_OPENAI_BASE_URL)
azure_openai_key = rag_config.get("azure_openai", {}).get("key", request.app.state.config.RAG_AZURE_OPENAI_API_KEY)
azure_openai_version = rag_config.get("azure_openai", {}).get("version", request.app.state.config.RAG_AZURE_OPENAI_API_VERSION)
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
result = VECTOR_DB_CLIENT.query(
@@ -1354,20 +1541,29 @@ def save_docs_to_vector_db(
log.info(f"adding to collection {collection_name}")
embedding_function = get_embedding_function(
embedding_engine,
embedding_model,
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.ef,
(
openai_api_base_url
if embedding_engine == "openai"
else ollama_base_url
else (
ollama_base_url
if embedding_engine == "ollama"
else azure_openai_url
)
),
(
openai_api_key
if embedding_engine == "openai"
else ollama_api_key
request.app.state.config.RAG_OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.RAG_OLLAMA_API_KEY
),
embedding_batch_size,
azure_api_version=(
azure_openai_version
if embedding_engine == "azure_openai"
else None
),
)
embeddings = embedding_function(
@@ -1439,6 +1635,33 @@ def process_file(
content_extraction_engine = rag_config.get(
"CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE
)
datalab_marker_api_key = rag_config.get(
"DATALAB_MARKER_API_KEY", request.app.state.config.DATALAB_MARKER_API_KEY
)
datalab_marker_langs = rag_config.get(
"DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS
)
datalab_marker_skip_cache = rag_config.get(
"DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE
)
datalab_marker_force_ocr = rag_config.get(
"DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR
)
datalab_marker_paginate = rag_config.get(
"DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE
)
datalab_marker_strip_existing_ocr = rag_config.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR
)
datalab_marker_disable_image_extraction = rag_config.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
)
datalab_marker_use_llm = rag_config.get(
"DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM
)
datalab_marker_output_format = rag_config.get(
"DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT
)
external_document_loader_url = rag_config.get(
"EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
)
@@ -1537,6 +1760,15 @@ def process_file(
file_path = Storage.get_file(file_path)
loader = Loader(
engine=content_extraction_engine,
DATALAB_MARKER_API_KEY=datalab_marker_api_key,
DATALAB_MARKER_LANGS=datalab_marker_langs,
DATALAB_MARKER_SKIP_CACHE=datalab_marker_skip_cache,
DATALAB_MARKER_FORCE_OCR=datalab_marker_force_ocr,
DATALAB_MARKER_PAGINATE=datalab_marker_paginate,
DATALAB_MARKER_STRIP_EXISTING_OCR=datalab_marker_strip_existing_ocr,
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=datalab_marker_disable_image_extraction,
DATALAB_MARKER_USE_LLM=datalab_marker_use_llm,
DATALAB_MARKER_OUTPUT_FORMAT=datalab_marker_output_format,
EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url,
EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
TIKA_SERVER_URL=tika_server_url,
@@ -1961,19 +2193,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "exa":
return search_exa(
request.app.state.config.EXA_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "perplexity":
return search_perplexity(
request.app.state.config.PERPLEXITY_API_KEY,
query,
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
model=request.app.state.config.PERPLEXITY_MODEL,
search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
)
elif engine == "sougou":
if (
@@ -2052,13 +2279,29 @@ async def process_web_search(
)
try:
loader = get_web_loader(
urls,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
docs = [
Document(
page_content=result.snippet,
metadata={
"source": result.link,
"title": result.title,
"snippet": result.snippet,
"link": result.link,
},
)
for result in search_results
if hasattr(result, "snippet")
]
else:
loader = get_web_loader(
urls,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
)
docs = await loader.aload()
urls = [
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
] # only keep the urls returned by the loader
@@ -2148,6 +2391,11 @@ def query_doc_handler(
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
hybrid_bm25_weight=(
form_data.hybrid_bm25_weight
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
user=user,
)
else:
@@ -2174,6 +2422,7 @@ class QueryCollectionsForm(BaseModel):
k_reranker: Optional[int] = None
r: Optional[float] = None
hybrid: Optional[bool] = None
hybrid_bm25_weight: Optional[float] = None
@router.post("/query/collection")
@@ -2199,6 +2448,11 @@ def query_collection_handler(
if form_data.r
else request.app.state.config.RELEVANCE_THRESHOLD
),
hybrid_bm25_weight=(
form_data.hybrid_bm25_weight
if form_data.hybrid_bm25_weight
else request.app.state.config.HYBRID_BM25_WEIGHT
),
)
else:
return query_collection(

View File

@@ -9,6 +9,7 @@ import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
title_generation_template,
follow_up_generation_template,
query_generation_template,
image_prompt_generation_template,
autocomplete_generation_template,
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
from open_webui.config import (
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
ENABLE_AUTOCOMPLETE_GENERATION: bool
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
TAGS_GENERATION_PROMPT_TEMPLATE: str
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
ENABLE_FOLLOW_UP_GENERATION: bool
ENABLE_TAGS_GENERATION: bool
ENABLE_SEARCH_QUERY_GENERATION: bool
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@@ -94,6 +100,13 @@ async def update_task_config(
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
form_data.ENABLE_FOLLOW_UP_GENERATION
)
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
@@ -133,6 +146,8 @@ async def update_task_config(
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -231,6 +246,86 @@ async def generate_title(
)
@router.post("/follow_up/completions")
async def generate_follow_ups(
request: Request, form_data: dict, user=Depends(get_verified_user)
):
if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
return JSONResponse(
status_code=status.HTTP_200_OK,
content={"detail": "Follow-up generation is disabled"},
)
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
task_model_id = get_task_model_id(
model_id,
request.app.state.config.TASK_MODEL,
request.app.state.config.TASK_MODEL_EXTERNAL,
models,
)
log.debug(
f"generating follow-ups using model {task_model_id} for user {user.email}"
)
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
else:
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
content = follow_up_generation_template(
template,
form_data["messages"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
)
payload = {
"model": task_model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"metadata": {
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
"task": str(TASKS.FOLLOW_UP_GENERATION),
"task_body": form_data,
"chat_id": form_data.get("chat_id", None),
},
}
# Process the payload through the pipeline
try:
payload = await process_pipeline_inlet_filter(request, payload, user, models)
except Exception as e:
raise e
try:
return await generate_chat_completion(request, form_data=payload, user=user)
except Exception as e:
log.error("Exception occurred", exc_info=True)
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "An internal error has occurred."},
)
@router.post("/tags/completions")
async def generate_chat_tags(
request: Request, form_data: dict, user=Depends(get_verified_user)

View File

@@ -2,6 +2,9 @@ import logging
from pathlib import Path
from typing import Optional
import time
import re
import aiohttp
from pydantic import BaseModel, HttpUrl
from open_webui.models.tools import (
ToolForm,
@@ -21,6 +24,7 @@ from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.tools import get_tool_servers_data
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -51,11 +55,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
**{
"id": f"server:{server['idx']}",
"user_id": f"server:{server['idx']}",
"name": server["openapi"]
"name": server.get("openapi", {})
.get("info", {})
.get("title", "Tool Server"),
"meta": {
"description": server["openapi"]
"description": server.get("openapi", {})
.get("info", {})
.get("description", ""),
},
@@ -95,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
return tools
############################
# LoadFunctionFromLink
############################
class LoadUrlForm(BaseModel):
url: HttpUrl
def github_url_to_raw_url(url: str) -> str:
# Handle 'tree' (folder) URLs (add main.py at the end)
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
if m1:
org, repo, branch, path = m1.groups()
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
# Handle 'blob' (file) URLs
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
if m2:
org, repo, branch, path = m2.groups()
return (
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
)
# No match; return as-is
return url
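# Illustrative conversions (hypothetical org/repo/paths, following the patterns above):
#   https://github.com/org/repo/blob/main/tools/foo.py
#     -> https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/foo.py
#   https://github.com/org/repo/tree/main/tools/foo
#     -> https://raw.githubusercontent.com/org/repo/refs/heads/main/tools/foo/main.py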
@router.post("/load/url", response_model=Optional[dict])
async def load_tool_from_url(
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
# NOTE: This is NOT an SSRF vulnerability:
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
# and does NOT accept untrusted user input. Access is enforced by authentication.
url = str(form_data.url)
if not url:
raise HTTPException(status_code=400, detail="Please enter a valid URL")
url = github_url_to_raw_url(url)
url_parts = url.rstrip("/").split("/")
file_name = url_parts[-1]
tool_name = (
file_name[:-3]
if (
file_name.endswith(".py")
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
)
else url_parts[-2] if len(url_parts) > 1 else "function"
)
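# Illustrative derivation (hypothetical URLs): ".../tools/my_tool.py" yields "my_tool",
# while ".../tools/my_tool/main.py" falls back to the parent path segment "my_tool".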
try:
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers={"Content-Type": "application/json"}
) as resp:
if resp.status != 200:
raise HTTPException(
status_code=resp.status, detail="Failed to fetch the tool"
)
data = await resp.text()
if not data:
raise HTTPException(
status_code=400, detail="No data received from the URL"
)
return {
"name": tool_name,
"content": data,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
############################
# ExportTools
############################

View File

@@ -165,22 +165,6 @@ async def update_default_user_permissions(
return request.app.state.config.USER_PERMISSIONS
############################
# UpdateUserRole
############################
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
############################
# GetUserSettingsBySessionUser
############################
@@ -333,11 +317,22 @@ async def update_user_by_id(
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
if first_user and user_id == first_user.id and session_user.id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
if first_user:
if user_id == first_user.id:
if session_user.id != user_id:
# Another admin is trying to modify the primary admin; prohibit it
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
if form_data.role != "admin":
# If the primary admin is trying to change their own role, prevent it
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
except Exception as e:
log.error(f"Error checking primary admin status: {e}")
raise HTTPException(
@@ -365,6 +360,7 @@ async def update_user_by_id(
updated_user = Users.update_user_by_id(
user_id,
{
"role": form_data.role,
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,

View File

@@ -269,11 +269,6 @@ tbody + tbody {
margin-bottom: 0;
}
/* Reset the top margin of a <ul> that directly follows a <p> */
.markdown-section p + ul {
margin-top: 0;
}
/* List item styles */
.markdown-section li {
padding: 2px;

View File

@@ -2,6 +2,7 @@ import os
import shutil
import json
import logging
import re
from abc import ABC, abstractmethod
from typing import BinaryIO, Tuple, Dict
@@ -136,6 +137,11 @@ class S3StorageProvider(StorageProvider):
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
@staticmethod
def sanitize_tag_value(s: str) -> str:
"""Only include S3 allowed characters."""
return re.sub(r"[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]", "", s)
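# Illustrative example (hypothetical value): sanitize_tag_value("report:v1 (draft)!")
# drops the disallowed "(", ")", and "!" and returns "report:v1 draft".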
def upload_file(
self, file: BinaryIO, filename: str, tags: Dict[str, str]
) -> Tuple[bytes, str]:
@@ -145,7 +151,15 @@ class S3StorageProvider(StorageProvider):
try:
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
if S3_ENABLE_TAGGING and tags:
tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
sanitized_tags = {
self.sanitize_tag_value(k): self.sanitize_tag_value(v)
for k, v in tags.items()
}
tagging = {
"TagSet": [
{"Key": k, "Value": v} for k, v in sanitized_tags.items()
]
}
self.s3_client.put_object_tagging(
Bucket=self.bucket_name,
Key=s3_key,

View File

@@ -40,7 +40,10 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.models import get_all_models, check_model_access
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
@@ -317,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
extra_params = {
"__event_emitter__": get_event_emitter(metadata),
"__event_call__": get_event_call(metadata),
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -392,8 +390,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
}
)
function_module, _, _ = load_function_module_by_id(action_id)
request.app.state.FUNCTIONS[action_id] = function_module
function_module, _, _ = get_function_module_from_cache(request, action_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
@@ -422,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
__user__ = user.model_dump() if isinstance(user, UserModel) else {}
try:
if hasattr(function_module, "UserValves"):

View File

@@ -0,0 +1,90 @@
import random
import logging
import sys
from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import (
embeddings as ollama_embeddings,
GenerateEmbeddingsForm,
)
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def generate_embeddings(
request: Request,
form_data: dict,
user: UserModel,
bypass_filter: bool = False,
):
"""
Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
Args:
request (Request): The FastAPI request context.
form_data (dict): The input data sent to the endpoint.
user (UserModel): The authenticated user.
bypass_filter (bool): If True, disables access filtering (default False).
Returns:
dict: The embeddings response, following OpenAI API compatibility.
"""
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
# Attach extra metadata from request.state if present
if hasattr(request.state, "metadata"):
if "metadata" not in form_data:
form_data["metadata"] = request.state.metadata
else:
form_data["metadata"] = {
**form_data["metadata"],
**request.state.metadata,
}
# If "direct" flag present, use only that model
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
models = {
request.state.model["id"]: request.state.model,
}
else:
models = request.app.state.MODELS
model_id = form_data.get("model")
if model_id not in models:
raise Exception("Model not found")
model = models[model_id]
# Access filtering
if not getattr(request.state, "direct", False):
if not bypass_filter and user.role == "user":
check_model_access(user, model)
# Ollama backend
if model.get("owned_by") == "ollama":
ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
response = await ollama_embeddings(
request=request,
form_data=GenerateEmbeddingsForm(**ollama_payload),
user=user,
)
return convert_embedding_response_ollama_to_openai(response)
# Default: OpenAI or compatible backend
return await openai_embeddings(
request=request,
form_data=form_data,
user=user,
)
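# Illustrative call (model id and input are hypothetical):
#   result = await generate_embeddings(
#       request, {"model": "text-embedding-3-small", "input": ["hello world"]}, user
#   )
#   # -> {"object": "list", "data": [{"object": "embedding", ...}], "model": "..."}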

View File

@@ -1,7 +1,10 @@
import inspect
import logging
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.models.functions import Functions
from open_webui.env import SRC_LOG_LEVELS
@@ -9,14 +12,13 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module(request, function_id):
def get_function_module(request, function_id, load_from_db=True):
"""
Get the function module by its ID.
"""
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
function_module, _, _ = get_function_module_from_cache(
request, function_id, load_from_db
)
return function_module
@@ -37,14 +39,17 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
for function in Functions.get_functions_by_type("filter", active_only=True)
]
for filter_id in active_filter_ids:
def get_active_status(filter_id):
function_module = get_function_module(request, filter_id)
if getattr(function_module, "toggle", None) and (
filter_id not in enabled_filter_ids
):
active_filter_ids.remove(filter_id)
continue
if getattr(function_module, "toggle", None):
return filter_id in (enabled_filter_ids or [])
return True
active_filter_ids = [
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
]
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority)
@@ -63,7 +68,9 @@ async def process_filter_functions(
if not filter:
continue
function_module = get_function_module(request, filter_id)
function_module = get_function_module(
request, filter_id, load_from_db=(filter_type != "stream")
)
# Prepare handler function
handler = getattr(function_module, filter_type, None)
if not handler:

View File

@@ -32,6 +32,7 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
generate_queries,
generate_title,
generate_follow_ups,
generate_image_prompt,
generate_chat_tags,
)
@@ -41,6 +42,7 @@ from open_webui.routers.pipelines import (
process_pipeline_inlet_filter,
process_pipeline_outlet_filter,
)
from open_webui.routers.memories import query_memory, QueryMemoryForm
from open_webui.utils.webhook import post_webhook
@@ -251,7 +253,12 @@ async def chat_completion_tools_handler(
"name": (f"TOOL:{tool_name}"),
},
"document": [tool_result],
"metadata": [{"source": (f"TOOL:{tool_name}")}],
"metadata": [
{
"source": (f"TOOL:{tool_name}"),
"parameters": tool_function_params,
}
],
}
)
else:
@@ -290,6 +297,45 @@ async def chat_completion_tools_handler(
return body, {"sources": sources}
async def chat_memory_handler(
request: Request, form_data: dict, extra_params: dict, user
):
try:
results = await query_memory(
request,
QueryMemoryForm(
**{
"content": get_last_user_message(form_data["messages"]) or "",
"k": 3,
}
),
user,
)
except Exception as e:
log.debug(e)
results = None
user_context = ""
if results and hasattr(results, "documents"):
if results.documents and len(results.documents) > 0:
for doc_idx, doc in enumerate(results.documents[0]):
created_at_date = "Unknown Date"
if results.metadatas[0][doc_idx].get("created_at"):
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
created_at_date = time.strftime(
"%Y-%m-%d", time.localtime(created_at_timestamp)
)
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
form_data["messages"] = add_or_update_system_message(
f"User Context:\n{user_context}\n", form_data["messages"], append=True
)
return form_data
async def chat_web_search_handler(
request: Request, form_data: dict, extra_params: dict, user
):
@@ -389,6 +435,7 @@ async def chat_web_search_handler(
"name": ", ".join(queries),
"type": "web_search",
"urls": results["filenames"],
"queries": queries,
}
)
elif results.get("docs"):
@@ -400,6 +447,7 @@ async def chat_web_search_handler(
"name": ", ".join(queries),
"type": "web_search",
"urls": results["filenames"],
"queries": queries,
}
)
@@ -617,6 +665,7 @@ async def chat_completion_files_handler(
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=hybrid_search,
full_context=full_context,
),
@@ -631,6 +680,32 @@ async def chat_completion_files_handler(
def apply_params_to_form_data(form_data, model):
params = form_data.pop("params", {})
custom_params = params.pop("custom_params", {})
open_webui_params = {
"stream_response": bool,
"function_calling": str,
"system": str,
}
for key in list(params.keys()):
if key in open_webui_params:
del params[key]
if custom_params:
# Attempt to parse custom_params if they are strings
for key, value in custom_params.items():
if isinstance(value, str):
try:
# Attempt to parse the string as JSON
custom_params[key] = json.loads(value)
except json.JSONDecodeError:
# If it fails, keep the original string
pass
# If custom_params are provided, merge them into params
params = deep_update(params, custom_params)
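# Illustrative merge (hypothetical values): a custom_params entry {"mirostat": "2"}
# is parsed via json.loads to the integer 2 before deep_update folds it into params.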
if model.get("ollama"):
form_data["options"] = params
@@ -640,29 +715,10 @@ def apply_params_to_form_data(form_data, model):
if "keep_alive" in params:
form_data["keep_alive"] = params["keep_alive"]
else:
if "seed" in params and params["seed"] is not None:
form_data["seed"] = params["seed"]
if "stop" in params and params["stop"] is not None:
form_data["stop"] = params["stop"]
if "temperature" in params and params["temperature"] is not None:
form_data["temperature"] = params["temperature"]
if "max_tokens" in params and params["max_tokens"] is not None:
form_data["max_tokens"] = params["max_tokens"]
if "top_p" in params and params["top_p"] is not None:
form_data["top_p"] = params["top_p"]
if "frequency_penalty" in params and params["frequency_penalty"] is not None:
form_data["frequency_penalty"] = params["frequency_penalty"]
if "presence_penalty" in params and params["presence_penalty"] is not None:
form_data["presence_penalty"] = params["presence_penalty"]
if "reasoning_effort" in params and params["reasoning_effort"] is not None:
form_data["reasoning_effort"] = params["reasoning_effort"]
if isinstance(params, dict):
for key, value in params.items():
if value is not None:
form_data[key] = value
if "logit_bias" in params and params["logit_bias"] is not None:
try:
@@ -676,7 +732,6 @@ def apply_params_to_form_data(form_data, model):
async def process_chat_payload(request, form_data, user, metadata, model):
form_data = apply_params_to_form_data(form_data, model)
log.debug(f"form_data: {form_data}")
@@ -686,12 +741,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_call,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,
@@ -788,6 +838,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
features = form_data.pop("features", None)
if features:
if "memory" in features and features["memory"]:
form_data = await chat_memory_handler(
request, form_data, extra_params, user
)
if "web_search" in features and features["web_search"]:
form_data = await chat_web_search_handler(
request, form_data, extra_params, user
@@ -890,6 +945,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
for doc_context, doc_meta in zip(
source["document"], source["metadata"]
):
source_name = source.get("source", {}).get("name", None)
citation_id = (
doc_meta.get("source", None)
or source.get("source", {}).get("id", None)
@@ -897,7 +953,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
)
if citation_id not in citation_idx:
citation_idx[citation_id] = len(citation_idx) + 1
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
context_string += (
f'<source id="{citation_idx[citation_id]}"'
+ (f' name="{source_name}"' if source_name else "")
+ f">{doc_context}</source>\n"
)
context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
@@ -968,7 +1028,7 @@ async def process_chat_response(
message = message_map.get(metadata["message_id"]) if message_map else None
if message:
message_list = get_message_list(message_map, message.get("id"))
message_list = get_message_list(message_map, metadata["message_id"])
# Remove details tags and files from the messages.
# as get_message_list creates a new list, it does not affect
@@ -985,7 +1045,7 @@ async def process_chat_response(
if isinstance(content, str):
content = re.sub(
r"<details\b[^>]*>.*?<\/details>",
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
"",
content,
flags=re.S | re.I,
@@ -994,12 +1054,67 @@ async def process_chat_response(
messages.append(
{
**message,
"role": message["role"],
"role": message.get(
"role", "assistant"
), # Safe fallback for missing role
"content": content,
}
)
if tasks and messages:
if (
TASKS.FOLLOW_UP_GENERATION in tasks
and tasks[TASKS.FOLLOW_UP_GENERATION]
):
res = await generate_follow_ups(
request,
{
"model": message["model"],
"messages": messages,
"message_id": metadata["message_id"],
"chat_id": metadata["chat_id"],
},
user,
)
if res and isinstance(res, dict):
if len(res.get("choices", [])) == 1:
follow_ups_string = (
res.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
else:
follow_ups_string = ""
follow_ups_string = follow_ups_string[
follow_ups_string.find("{") : follow_ups_string.rfind("}")
+ 1
]
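# Example (illustrative model output): 'Sure! { "follow_ups": ["A?"] } Hope this helps'
# is trimmed to '{ "follow_ups": ["A?"] }' before JSON parsing.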
try:
follow_ups = json.loads(follow_ups_string).get(
"follow_ups", []
)
Chats.upsert_message_to_chat_by_id_and_message_id(
metadata["chat_id"],
metadata["message_id"],
{
"followUps": follow_ups,
},
)
await event_emitter(
{
"type": "chat:message:follow_ups",
"data": {
"follow_ups": follow_ups,
},
}
)
except Exception as e:
pass
if TASKS.TITLE_GENERATION in tasks:
if tasks[TASKS.TITLE_GENERATION]:
res = await generate_title(
@@ -1162,6 +1277,7 @@ async def process_chat_response(
metadata["chat_id"],
metadata["message_id"],
{
"role": "assistant",
"content": content,
},
)
@@ -1184,8 +1300,34 @@ async def process_chat_response(
await background_tasks_handler()
if events and isinstance(events, list) and isinstance(response, dict):
extra_response = {}
for event in events:
if isinstance(event, dict):
extra_response.update(event)
else:
extra_response[event] = True
response = {
**extra_response,
**response,
}
return response
else:
if events and isinstance(events, list) and isinstance(response, dict):
extra_response = {}
for event in events:
if isinstance(event, dict):
extra_response.update(event)
else:
extra_response[event] = True
response = {
**extra_response,
**response,
}
return response
# Non standard response
@@ -1198,12 +1340,7 @@ async def process_chat_response(
extra_params = {
"__event_emitter__": event_emitter,
"__event_call__": event_caller,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__request__": request,
"__model__": model,

View File

@@ -34,11 +34,15 @@ def get_message_list(messages, message_id):
:return: List of ordered messages starting from the root to the given message
"""
# Handle case where messages is None
if not messages:
return [] # Return empty list instead of None to prevent iteration errors
# Find the message by its id
current_message = messages.get(message_id)
if not current_message:
return None
return [] # Return empty list instead of None to prevent iteration errors
# Reconstruct the chain by following the parentId links
message_list = []
@@ -47,7 +51,7 @@ def get_message_list(messages, message_id):
message_list.insert(
0, current_message
) # Insert the message at the beginning of the list
parent_id = current_message["parentId"]
parent_id = current_message.get("parentId") # Use .get() for safety
current_message = messages.get(parent_id) if parent_id else None
return message_list
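# Illustrative chain (hypothetical ids): given {"m1": {...}, "m2": {"parentId": "m1", ...}},
# get_message_list(messages, "m2") walks the parentId links and returns [m1, m2].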
@@ -70,12 +74,12 @@ def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
def get_content_from_message(message: dict) -> Optional[str]:
if isinstance(message["content"], list):
if isinstance(message.get("content"), list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
else:
return message["content"]
return message.get("content")
return None
@@ -130,7 +134,9 @@ def prepend_to_first_user_message_content(
return messages
def add_or_update_system_message(content: str, messages: list[dict]):
def add_or_update_system_message(
content: str, messages: list[dict], append: bool = False
):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
@@ -141,7 +147,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
if append:
messages[0]["content"] = f"{messages[0]['content']}\n{content}"
else:
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})

View File

@@ -1,5 +1,6 @@
import time
import logging
import asyncio
import sys
from aiocache import cached
@@ -13,7 +14,10 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.access_control import has_access
@@ -30,35 +34,46 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def fetch_ollama_models(request: Request, user: UserModel = None):
raw_ollama_models = await ollama.get_all_models(request, user=user)
return [
{
"id": model["model"],
"name": model["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"connection_type": model.get("connection_type", "local"),
"tags": model.get("tags", []),
}
for model in raw_ollama_models["models"]
]
async def fetch_openai_models(request: Request, user: UserModel = None):
openai_response = await openai.get_all_models(request, user=user)
return openai_response["data"]
async def get_all_base_models(request: Request, user: UserModel = None):
function_models = []
openai_models = []
ollama_models = []
openai_task = (
fetch_openai_models(request, user)
if request.app.state.config.ENABLE_OPENAI_API
else asyncio.sleep(0, result=[])
)
ollama_task = (
fetch_ollama_models(request, user)
if request.app.state.config.ENABLE_OLLAMA_API
else asyncio.sleep(0, result=[])
)
function_task = get_function_models(request)
if request.app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models(request, user=user)
openai_models = openai_models["data"]
openai_models, ollama_models, function_models = await asyncio.gather(
openai_task, ollama_task, function_task
)
if request.app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models(request, user=user)
ollama_models = [
{
"id": model["model"],
"name": model["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "ollama",
"ollama": model,
"connection_type": model.get("connection_type", "local"),
"tags": model.get("tags", []),
}
for model in ollama_models["models"]
]
function_models = await get_function_models(request)
models = function_models + openai_models + ollama_models
return models
return function_models + openai_models + ollama_models
async def get_all_models(request, user: UserModel = None):
@@ -239,8 +254,7 @@ async def get_all_models(request, user: UserModel = None):
]
def get_function_module_by_id(function_id):
function_module, _, _ = load_function_module_by_id(function_id)
request.app.state.FUNCTIONS[function_id] = function_module
function_module, _, _ = get_function_module_from_cache(request, function_id)
return function_module
for model in models:

View File

@@ -536,5 +536,10 @@ class OAuthManager:
secure=WEBUI_AUTH_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
redirect_base_url = redirect_base_url[:-1]
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
return RedirectResponse(url=redirect_url, headers=response.headers)

View File

@@ -1,5 +1,6 @@
from open_webui.utils.task import prompt_template, prompt_variables_template
from open_webui.utils.misc import (
deep_update,
add_or_update_system_message,
)
@@ -9,9 +10,8 @@ import json
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(
params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
) -> dict:
system = params.get("system", None)
if not system:
return form_data
@@ -45,18 +45,64 @@ def apply_model_params_to_body(
if not params:
return form_data
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
for key, value in params.items():
if value is not None:
if key in mappings:
cast_func = mappings[key]
if isinstance(cast_func, Callable):
form_data[key] = cast_func(value)
else:
form_data[key] = value
return form_data
def remove_open_webui_params(params: dict) -> dict:
"""
Removes OpenWebUI specific parameters from the provided dictionary.
Args:
params (dict): The dictionary containing parameters.
Returns:
dict: The modified dictionary with OpenWebUI parameters removed.
"""
open_webui_params = {
"stream_response": bool,
"function_calling": str,
"system": str,
}
for key in list(params.keys()):
if key in open_webui_params:
del params[key]
return params
# inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
params = remove_open_webui_params(params)
custom_params = params.pop("custom_params", {})
if custom_params:
# Attempt to parse custom_params if they are strings
for key, value in custom_params.items():
if isinstance(value, str):
try:
# Attempt to parse the string as JSON
custom_params[key] = json.loads(value)
except json.JSONDecodeError:
# If it fails, keep the original string
pass
# If there are custom parameters, we need to apply them first
params = deep_update(params, custom_params)
mappings = {
"temperature": float,
"top_p": float,
"min_p": float,
"max_tokens": int,
"frequency_penalty": float,
"presence_penalty": float,
@@ -70,6 +116,23 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
params = remove_open_webui_params(params)
custom_params = params.pop("custom_params", {})
if custom_params:
# Attempt to parse custom_params if they are strings
for key, value in custom_params.items():
if isinstance(value, str):
try:
# Attempt to parse the string as JSON
custom_params[key] = json.loads(value)
except json.JSONDecodeError:
# If it fails, keep the original string
pass
# If there are custom parameters, we need to apply them first
params = deep_update(params, custom_params)
# Convert OpenAI parameter names to Ollama parameter names if needed.
name_differences = {
"max_tokens": "num_predict",
@@ -266,3 +329,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
ollama_payload["format"] = format
return ollama_payload
def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
"""
Convert an embeddings request payload from OpenAI format to Ollama format.
Args:
openai_payload (dict): The original payload designed for OpenAI API usage.
Returns:
dict: A payload compatible with the Ollama API embeddings endpoint.
"""
ollama_payload = {"model": openai_payload.get("model")}
input_value = openai_payload.get("input")
# Ollama expects 'input' as a list, and 'prompt' as a single string.
if isinstance(input_value, list):
ollama_payload["input"] = input_value
ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
else:
ollama_payload["input"] = [input_value]
ollama_payload["prompt"] = str(input_value)
# Optionally forward other fields if present
for optional_key in ("options", "truncate", "keep_alive"):
if optional_key in openai_payload:
ollama_payload[optional_key] = openai_payload[optional_key]
return ollama_payload
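# Illustrative conversion (hypothetical payload):
#   {"model": "m", "input": ["a", "b"]}
#   -> {"model": "m", "input": ["a", "b"], "prompt": "a\nb"}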

View File

@@ -115,7 +115,7 @@ def load_tool_module_by_id(tool_id, content=None):
os.unlink(temp_file.name)
def load_function_module_by_id(function_id, content=None):
def load_function_module_by_id(function_id: str, content: str | None = None):
if content is None:
function = Functions.get_function_by_id(function_id)
if not function:
@@ -166,6 +166,62 @@ def load_function_module_by_id(function_id, content=None):
os.unlink(temp_file.name)
def get_function_module_from_cache(request, function_id, load_from_db=True):
if load_from_db:
# Always load from the database by default
# This is useful for hooks like "inlet" or "outlet" where the content might change
# and we want to ensure the latest content is used.
function = Functions.get_function_by_id(function_id)
if not function:
raise Exception(f"Function not found: {function_id}")
content = function.content
new_content = replace_imports(content)
if new_content != content:
content = new_content
# Update the function content in the database
Functions.update_function_by_id(function_id, {"content": content})
if (
hasattr(request.app.state, "FUNCTION_CONTENTS")
and function_id in request.app.state.FUNCTION_CONTENTS
) and (
hasattr(request.app.state, "FUNCTIONS")
and function_id in request.app.state.FUNCTIONS
):
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id(
function_id, content
)
else:
# Load from cache (e.g. "stream" hook)
# This is useful for performance reasons
content = None  # content is not loaded on this path; see the guarded cache write below
if (
hasattr(request.app.state, "FUNCTIONS")
and function_id in request.app.state.FUNCTIONS
):
return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id(
function_id
)
if not hasattr(request.app.state, "FUNCTIONS"):
request.app.state.FUNCTIONS = {}
if not hasattr(request.app.state, "FUNCTION_CONTENTS"):
request.app.state.FUNCTION_CONTENTS = {}
request.app.state.FUNCTIONS[function_id] = function_module
if content is not None:
request.app.state.FUNCTION_CONTENTS[function_id] = content
return function_module, function_type, frontmatter
def install_frontmatter_requirements(requirements: str):
if requirements:
try:

View File

@@ -125,3 +125,64 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
yield line
yield "data: [DONE]\n\n"
def convert_embedding_response_ollama_to_openai(response) -> dict:
"""
Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
Args:
response (dict): The response from the Ollama API,
e.g. {"embedding": [...], "model": "..."}
or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
Returns:
dict: Response adapted to OpenAI's embeddings API format.
e.g. {
"object": "list",
"data": [
{"object": "embedding", "embedding": [...], "index": 0},
...
],
"model": "...",
}
"""
# Ollama batch-style output
if isinstance(response, dict) and "embeddings" in response:
openai_data = []
for i, emb in enumerate(response["embeddings"]):
openai_data.append(
{
"object": "embedding",
"embedding": emb.get("embedding"),
"index": emb.get("index", i),
}
)
return {
"object": "list",
"data": openai_data,
"model": response.get("model"),
}
# Ollama single output
elif isinstance(response, dict) and "embedding" in response:
return {
"object": "list",
"data": [
{
"object": "embedding",
"embedding": response["embedding"],
"index": 0,
}
],
"model": response.get("model"),
}
# Already OpenAI-compatible?
elif (
isinstance(response, dict)
and "data" in response
and isinstance(response["data"], list)
):
return response
# Fallback: return as is if unrecognized
return response

View File

@@ -22,7 +22,7 @@ def get_task_model_id(
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if models[task_model_id].get("owned_by") == "ollama":
if models[task_model_id].get("connection_type") == "local":
if task_model and task_model in models:
task_model_id = task_model
else:
@@ -207,6 +207,24 @@ def title_generation_template(
return template
def follow_up_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
prompt = get_last_user_message(messages)
template = replace_prompt_variable(template, prompt)
template = replace_messages_variable(template, messages)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
)
return template
def tags_generation_template(
template: str, messages: list[dict], user: Optional[dict] = None
) -> str:

View File

@@ -160,7 +160,7 @@ def get_tools(
# TODO: Fix hack for OpenAI API
# Some times breaks OpenAI but others don't. Leaving the comment
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
if val.get("type") == "str":
val["type"] = "string"
# Remove internal reserved parameters (e.g. __id__, __user__)
@@ -490,8 +490,19 @@ async def get_tool_servers_data(
server_entries = []
for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"):
url_path = server.get("path", "openapi.json")
full_url = f"{server.get('url')}/{url_path}"
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
openapi_path = server.get("path", "openapi.json")
if "://" in openapi_path:
# If it contains "://", it's a full URL
full_url = openapi_path
else:
if not openapi_path.startswith("/"):
# Ensure the path starts with a slash
openapi_path = f"/{openapi_path}"
full_url = f"{server.get('url')}{openapi_path}"
info = server.get("info", {})
auth_type = server.get("auth_type", "bearer")
token = None
@@ -500,26 +511,37 @@ async def get_tool_servers_data(
token = server.get("key", "")
elif auth_type == "session":
token = session_token
server_entries.append((idx, server, full_url, token))
server_entries.append((idx, server, full_url, info, token))
# Create async tasks to fetch data
tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
tasks = [
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
]
# Execute tasks concurrently
responses = await asyncio.gather(*tasks, return_exceptions=True)
# Build final results with index and server metadata
results = []
for (idx, server, url, _), response in zip(server_entries, responses):
for (idx, server, url, info, _), response in zip(server_entries, responses):
if isinstance(response, Exception):
log.error(f"Failed to connect to {url} OpenAPI tool server")
continue
openapi_data = response.get("openapi", {})
if info and isinstance(openapi_data, dict):
if "name" in info:
openapi_data["info"]["title"] = info.get("name", "Tool Server")
if "description" in info:
openapi_data["info"]["description"] = info.get("description", "")
results.append(
{
"idx": idx,
"url": server.get("url"),
"openapi": response.get("openapi"),
"openapi": openapi_data,
"info": response.get("info"),
"specs": response.get("specs"),
}