Merge branch 'open-webui:main' into next
@@ -347,6 +347,24 @@ MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_CLIENT_LOGIN_BASE_URL = PersistentConfig(
    "MICROSOFT_CLIENT_LOGIN_BASE_URL",
    "oauth.microsoft.login_base_url",
    os.environ.get(
        "MICROSOFT_CLIENT_LOGIN_BASE_URL", "https://login.microsoftonline.com"
    ),
)

MICROSOFT_CLIENT_PICTURE_URL = PersistentConfig(
    "MICROSOFT_CLIENT_PICTURE_URL",
    "oauth.microsoft.picture_url",
    os.environ.get(
        "MICROSOFT_CLIENT_PICTURE_URL",
        "https://graph.microsoft.com/v1.0/me/photo/$value",
    ),
)


MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
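Note: the configurable login base URL is what enables national-cloud tenants. A minimal sketch of how the override feeds the metadata URL assembled in load_oauth_providers() below; the .us host and tenant ID are illustrative assumptions, not values from this commit:

    import os

    # Hypothetical override for Azure Government; the default stays login.microsoftonline.com.
    os.environ["MICROSOFT_CLIENT_LOGIN_BASE_URL"] = "https://login.microsoftonline.us"

    base_url = os.environ.get(
        "MICROSOFT_CLIENT_LOGIN_BASE_URL", "https://login.microsoftonline.com"
    )
    tenant_id = "00000000-0000-0000-0000-000000000000"  # placeholder tenant
    print(f"{base_url}/{tenant_id}/v2.0/.well-known/openid-configuration")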
@@ -542,7 +560,7 @@ def load_oauth_providers():
                name="microsoft",
                client_id=MICROSOFT_CLIENT_ID.value,
                client_secret=MICROSOFT_CLIENT_SECRET.value,
                server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
                server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
                client_kwargs={
                    "scope": MICROSOFT_OAUTH_SCOPE.value,
                },
@@ -551,7 +569,7 @@ def load_oauth_providers():

        OAUTH_PROVIDERS["microsoft"] = {
            "redirect_uri": MICROSOFT_REDIRECT_URI.value,
            "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value",
            "picture_url": MICROSOFT_CLIENT_PICTURE_URL.value,
            "register": microsoft_oauth_register,
        }

@@ -901,9 +919,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig(
####################################

WEBUI_URL = PersistentConfig(
    "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)
WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", ""))


ENABLE_SIGNUP = PersistentConfig(
@@ -1247,12 +1263,6 @@ if THREAD_POOL_SIZE is not None and isinstance(THREAD_POOL_SIZE, str):
        THREAD_POOL_SIZE = None


def validate_cors_origins(origins):
    for origin in origins:
        if origin != "*":
            validate_cors_origin(origin)


def validate_cors_origin(origin):
    parsed_url = urlparse(origin)

@@ -1272,16 +1282,17 @@ def validate_cors_origin(origin):
# To test CORS_ALLOW_ORIGIN locally, you can set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file depending on your frontend port, 5173 in this case.
CORS_ALLOW_ORIGIN = os.environ.get(
    "CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080"
).split(";")
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

if "*" in CORS_ALLOW_ORIGIN:
if CORS_ALLOW_ORIGIN == ["*"]:
    log.warning(
        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
    )

    validate_cors_origins(CORS_ALLOW_ORIGIN)
else:
    # You have to pick between a single wildcard or a list of origins.
    # Doing both will result in CORS errors in the browser.
    for origin in CORS_ALLOW_ORIGIN:
        validate_cors_origin(origin)


class BannerModel(BaseModel):
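The effect of the new branch: a deployment either keeps the bare wildcard (and gets the warning) or supplies an explicit semicolon-separated list that is validated origin by origin; mixing "*" with concrete origins is what browsers reject. A minimal sketch of the parsing, with hypothetical hostnames:

    import os

    os.environ["CORS_ALLOW_ORIGIN"] = "https://chat.example.com;https://admin.example.com"
    origins = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")
    assert origins == ["https://chat.example.com", "https://admin.example.com"]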
@@ -1413,6 +1424,35 @@ Strictly return in JSON format:
{{MESSAGES:END:6}}
</chat_history>"""


FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
    "task.follow_up.prompt_template",
    os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
)

DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
### Guidelines:
- Write all follow-up questions from the user’s point of view, directed to the assistant.
- Make questions concise, clear, and directly related to the discussed topic(s).
- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
- Use the conversation's primary language; default to English if multilingual.
- Response must be a JSON array of strings, no extra text or formatting.
### Output:
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""

ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
    "ENABLE_FOLLOW_UP_GENERATION",
    "task.follow_up.enable",
    os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
)

ENABLE_TAGS_GENERATION = PersistentConfig(
    "ENABLE_TAGS_GENERATION",
    "task.tags.enable",
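Since the template instructs the task model to return strict JSON, the backend side of the contract is a plain json.loads. A sketch of parsing a conforming response (the questions are invented):

    import json

    raw = '{ "follow_ups": ["How does this compare to X?", "Can you show an example?"] }'
    follow_ups = json.loads(raw).get("follow_ups", [])
    print(follow_ups)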
@@ -1786,6 +1826,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
    os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)

PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
    raise ValueError(
        "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
    )

# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
@@ -1946,6 +1993,40 @@ DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
    os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
)

DOCLING_PICTURE_DESCRIPTION_MODE = PersistentConfig(
    "DOCLING_PICTURE_DESCRIPTION_MODE",
    "rag.docling_picture_description_mode",
    os.getenv("DOCLING_PICTURE_DESCRIPTION_MODE", ""),
)


docling_picture_description_local = os.getenv("DOCLING_PICTURE_DESCRIPTION_LOCAL", "")
try:
    docling_picture_description_local = json.loads(docling_picture_description_local)
except json.JSONDecodeError:
    docling_picture_description_local = {}


DOCLING_PICTURE_DESCRIPTION_LOCAL = PersistentConfig(
    "DOCLING_PICTURE_DESCRIPTION_LOCAL",
    "rag.docling_picture_description_local",
    docling_picture_description_local,
)

docling_picture_description_api = os.getenv("DOCLING_PICTURE_DESCRIPTION_API", "")
try:
    docling_picture_description_api = json.loads(docling_picture_description_api)
except json.JSONDecodeError:
    docling_picture_description_api = {}


DOCLING_PICTURE_DESCRIPTION_API = PersistentConfig(
    "DOCLING_PICTURE_DESCRIPTION_API",
    "rag.docling_picture_description_api",
    docling_picture_description_api,
)


DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
    "DOCUMENT_INTELLIGENCE_ENDPOINT",
    "rag.document_intelligence_endpoint",
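Both DOCLING_PICTURE_DESCRIPTION_LOCAL and DOCLING_PICTURE_DESCRIPTION_API are read as JSON blobs and fall back to {} on parse errors. A sketch of setting them; the option key shown is an assumption about Docling's picture-description settings, not something defined by this commit:

    import json
    import os

    os.environ["DOCLING_PICTURE_DESCRIPTION_LOCAL"] = json.dumps(
        {"repo_id": "HuggingFaceTB/SmolVLM-256M-Instruct"}  # hypothetical local model
    )
    os.environ["DOCLING_PICTURE_DESCRIPTION_API"] = "not valid json"  # falls back to {}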
@@ -2445,6 +2526,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
    os.getenv("PERPLEXITY_API_KEY", ""),
)

PERPLEXITY_MODEL = PersistentConfig(
    "PERPLEXITY_MODEL",
    "rag.web.search.perplexity_model",
    os.getenv("PERPLEXITY_MODEL", "sonar"),
)

PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
    "PERPLEXITY_SEARCH_CONTEXT_USAGE",
    "rag.web.search.perplexity_search_context_usage",
    os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
)

SOUGOU_API_SID = PersistentConfig(
    "SOUGOU_API_SID",
    "rag.web.search.sougou_api_sid",
@@ -111,6 +111,7 @@ class TASKS(str, Enum):

    DEFAULT = lambda task="": f"{task if task else 'generation'}"
    TITLE_GENERATION = "title_generation"
    FOLLOW_UP_GENERATION = "follow_up_generation"
    TAGS_GENERATION = "tags_generation"
    EMOJI_GENERATION = "emoji_generation"
    QUERY_GENERATION = "query_generation"

@@ -5,6 +5,7 @@ import os
import pkgutil
import sys
import shutil
from uuid import uuid4
from pathlib import Path

import markdown
@@ -130,6 +131,7 @@ else:
    PACKAGE_DATA = {"version": "0.0.0"}

VERSION = PACKAGE_DATA["version"]
INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4()))


# Function to parse each section

@@ -25,6 +25,7 @@ from open_webui.socket.main import (
)


from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models

@@ -227,12 +228,7 @@ async def generate_function_chat_completion(
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
        "__request__": request,
    }
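Replacing the hand-built four-key dict with user.model_dump() hands pipe functions the full serialized UserModel. A standalone sketch of the behavior (the UserModel here is a simplified stand-in, not the real model):

    from pydantic import BaseModel

    class UserModel(BaseModel):  # simplified stand-in for open_webui.models.users.UserModel
        id: str
        email: str
        name: str
        role: str

    user = UserModel(id="u1", email="ada@example.com", name="Ada", role="admin")
    extra_params = {"__user__": user.model_dump() if isinstance(user, UserModel) else {}}
    print(extra_params["__user__"])  # every field, not just the four hand-picked keys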
@@ -8,6 +8,8 @@ import shutil
import sys
import time
import random
from uuid import uuid4


from contextlib import asynccontextmanager
from urllib.parse import urlencode, parse_qs, urlparse
@@ -19,6 +21,7 @@ from aiocache import cached
import aiohttp
import anyio.to_thread
import requests
from redis import Redis


from fastapi import (
@@ -37,7 +40,7 @@ from fastapi import (
from fastapi.openapi.docs import get_swagger_ui_html

from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles

from starlette_compress import CompressMiddleware
@@ -231,6 +234,9 @@ from open_webui.config import (
    DOCLING_OCR_ENGINE,
    DOCLING_OCR_LANG,
    DOCLING_DO_PICTURE_DESCRIPTION,
    DOCLING_PICTURE_DESCRIPTION_MODE,
    DOCLING_PICTURE_DESCRIPTION_LOCAL,
    DOCLING_PICTURE_DESCRIPTION_API,
    DOCUMENT_INTELLIGENCE_ENDPOINT,
    DOCUMENT_INTELLIGENCE_KEY,
    MISTRAL_OCR_API_KEY,
@@ -268,6 +274,8 @@ from open_webui.config import (
    BRAVE_SEARCH_API_KEY,
    EXA_API_KEY,
    PERPLEXITY_API_KEY,
    PERPLEXITY_MODEL,
    PERPLEXITY_SEARCH_CONTEXT_USAGE,
    SOUGOU_API_SID,
    SOUGOU_API_SK,
    KAGI_SEARCH_API_KEY,
@@ -359,10 +367,12 @@ from open_webui.config import (
    TASK_MODEL_EXTERNAL,
    ENABLE_TAGS_GENERATION,
    ENABLE_TITLE_GENERATION,
    ENABLE_FOLLOW_UP_GENERATION,
    ENABLE_SEARCH_QUERY_GENERATION,
    ENABLE_RETRIEVAL_QUERY_GENERATION,
    ENABLE_AUTOCOMPLETE_GENERATION,
    TITLE_GENERATION_PROMPT_TEMPLATE,
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
    TAGS_GENERATION_PROMPT_TEMPLATE,
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -384,6 +394,7 @@ from open_webui.env import (
    SAFE_MODE,
    SRC_LOG_LEVELS,
    VERSION,
    INSTANCE_ID,
    WEBUI_BUILD_HASH,
    WEBUI_SECRET_KEY,
    WEBUI_SESSION_COOKIE_SAME_SITE,
@@ -411,6 +422,7 @@ from open_webui.utils.chat import (
    chat_completed as chat_completed_handler,
    chat_action as chat_action_handler,
)
from open_webui.utils.embeddings import generate_embeddings
from open_webui.utils.middleware import process_chat_payload, process_chat_response
from open_webui.utils.access_control import has_access

@@ -424,8 +436,10 @@ from open_webui.utils.auth import (
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import OAuthManager
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection

from open_webui.tasks import (
    redis_task_command_listener,
    list_task_ids_by_chat_id,
    stop_task,
    list_tasks,
@@ -477,7 +491,9 @@ https://github.com/open-webui/open-webui

@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.instance_id = INSTANCE_ID
    start_logger()

    if RESET_CONFIG_ON_START:
        reset_config()

@@ -489,6 +505,19 @@ async def lifespan(app: FastAPI):
        log.info("Installing external dependencies of functions and tools...")
        install_tool_and_function_dependencies()

    app.state.redis = get_redis_connection(
        redis_url=REDIS_URL,
        redis_sentinels=get_sentinels_from_env(
            REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
        ),
        async_mode=True,
    )

    if app.state.redis is not None:
        app.state.redis_task_command_listener = asyncio.create_task(
            redis_task_command_listener(app)
        )

    if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0:
        limiter = anyio.to_thread.current_default_thread_limiter()
        limiter.total_tokens = THREAD_POOL_SIZE
@@ -497,6 +526,9 @@ async def lifespan(app: FastAPI):

    yield

    if hasattr(app.state, "redis_task_command_listener"):
        app.state.redis_task_command_listener.cancel()


app = FastAPI(
    title="Open WebUI",
@@ -508,10 +540,12 @@ app = FastAPI(

oauth_manager = OAuthManager(app)

app.state.instance_id = None
app.state.config = AppConfig(
    redis_url=REDIS_URL,
    redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
)
app.state.redis = None

app.state.WEBUI_NAME = WEBUI_NAME
app.state.LICENSE_METADATA = None
@@ -696,6 +730,9 @@ app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE
app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL
app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
@@ -771,6 +808,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
@@ -959,6 +998,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION


app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -966,6 +1006,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
)
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
)

app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
@@ -1197,6 +1240,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
    return {"data": models}


##################################
# Embeddings
##################################


@app.post("/api/embeddings")
async def embeddings(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    OpenAI-compatible embeddings endpoint.

    This handler:
    - Performs user/model checks and dispatches to the correct backend.
    - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.

    Args:
        request (Request): Request context.
        form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
        user (UserModel): Authenticated user.

    Returns:
        dict: OpenAI-compatible embeddings response.
    """
    # Make sure models are loaded in app state
    if not request.app.state.MODELS:
        await get_all_models(request, user=user)
    # Use generic dispatcher in utils.embeddings
    return await generate_embeddings(request, form_data, user)


@app.post("/api/chat/completions")
async def chat_completion(
    request: Request,
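A sketch of calling the new endpoint from a client; the host, API key, and model name are placeholders, not values from this commit:

    import requests

    resp = requests.post(
        "http://localhost:8080/api/embeddings",
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder key
        json={"model": "text-embedding-3-small", "input": ["hello world"]},
    )
    print(resp.json())  # OpenAI-compatible embeddings response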
@@ -1338,26 +1412,30 @@ async def chat_action(


@app.post("/api/tasks/stop/{task_id}")
async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
async def stop_task_endpoint(
    request: Request, task_id: str, user=Depends(get_verified_user)
):
    try:
        result = await stop_task(task_id)
        result = await stop_task(request, task_id)
        return result
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))


@app.get("/api/tasks")
async def list_tasks_endpoint(user=Depends(get_verified_user)):
    return {"tasks": list_tasks()}
async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)):
    return {"tasks": await list_tasks(request)}


@app.get("/api/tasks/chat/{chat_id}")
async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)):
async def list_tasks_by_chat_id_endpoint(
    request: Request, chat_id: str, user=Depends(get_verified_user)
):
    chat = Chats.get_chat_by_id(chat_id)
    if chat is None or chat.user_id != user.id:
        return {"task_ids": []}

    task_ids = list_task_ids_by_chat_id(chat_id)
    task_ids = await list_task_ids_by_chat_id(request, chat_id)

    print(f"Task IDs for chat {chat_id}: {task_ids}")
    return {"task_ids": task_ids}
@@ -1628,7 +1706,20 @@ async def healthcheck_with_db():


app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")


@app.get("/cache/{path:path}")
async def serve_cache_file(
    path: str,
    user=Depends(get_verified_user),
):
    file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
    # prevent path traversal
    if not file_path.startswith(os.path.abspath(CACHE_DIR)):
        raise HTTPException(status_code=404, detail="File not found")
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path)


def swagger_ui_html(*args, **kwargs):

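The abspath check is what turns the authenticated /cache route into a safe replacement for the unauthenticated static mount: any .. segments are resolved before the prefix comparison. A sketch under an assumed CACHE_DIR:

    import os

    CACHE_DIR = "/app/backend/data/cache"  # assumed location for illustration
    candidate = os.path.abspath(os.path.join(CACHE_DIR, "../../etc/passwd"))
    print(candidate)  # /app/backend/etc/passwd, which escapes CACHE_DIR
    print(candidate.startswith(os.path.abspath(CACHE_DIR)))  # False -> 404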
@@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel):


class UserUpdateForm(BaseModel):
    role: str
    name: str
    email: str
    profile_image_url: str
@@ -369,7 +370,7 @@ class UsersTable:
        except Exception:
            return False

    def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
    def update_user_api_key_by_id(self, id: str, api_key: str) -> bool:
        try:
            with get_db() as db:
                result = db.query(User).filter_by(id=id).update({"api_key": api_key})

@@ -2,6 +2,7 @@ import requests
import logging
import ftfy
import sys
import json

from langchain_community.document_loaders import (
    AzureAIDocumentIntelligenceLoader,
@@ -76,7 +77,6 @@ known_source_ext = [
    "swift",
    "vue",
    "svelte",
    "msg",
    "ex",
    "exs",
    "erl",
@@ -147,17 +147,32 @@ class DoclingLoader:
                )
            }

            params = {
                "image_export_mode": "placeholder",
                "table_mode": "accurate",
            }
            params = {"image_export_mode": "placeholder", "table_mode": "accurate"}

            if self.params:
                if self.params.get("do_picture_classification"):
                    params["do_picture_classification"] = self.params.get(
                        "do_picture_classification"
                if self.params.get("do_picture_description"):
                    params["do_picture_description"] = self.params.get(
                        "do_picture_description"
                    )

                picture_description_mode = self.params.get(
                    "picture_description_mode", ""
                ).lower()

                if picture_description_mode == "local" and self.params.get(
                    "picture_description_local", {}
                ):
                    params["picture_description_local"] = self.params.get(
                        "picture_description_local", {}
                    )

                elif picture_description_mode == "api" and self.params.get(
                    "picture_description_api", {}
                ):
                    params["picture_description_api"] = self.params.get(
                        "picture_description_api", {}
                    )

                if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
                    params["ocr_engine"] = self.params.get("ocr_engine")
                    params["ocr_lang"] = [
@@ -285,17 +300,20 @@ class Loader:
            if self._is_text_file(file_ext, file_content_type):
                loader = TextLoader(file_path, autodetect_encoding=True)
            else:
                # Build params for DoclingLoader
                params = self.kwargs.get("DOCLING_PARAMS", {})
                if not isinstance(params, dict):
                    try:
                        params = json.loads(params)
                    except json.JSONDecodeError:
                        log.error("Invalid DOCLING_PARAMS format, expected JSON object")
                        params = {}

                loader = DoclingLoader(
                    url=self.kwargs.get("DOCLING_SERVER_URL"),
                    file_path=file_path,
                    mime_type=file_content_type,
                    params={
                        "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
                        "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
                        "do_picture_classification": self.kwargs.get(
                            "DOCLING_DO_PICTURE_DESCRIPTION"
                        ),
                    },
                    params=params,
                )
        elif (
            self.engine == "document_intelligence"

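Loader now forwards a single consolidated DOCLING_PARAMS dict instead of assembling keys inline. A sketch of what such a blob might look like; the key names mirror the DoclingLoader code above, the values are examples only:

    docling_params = {
        "ocr_engine": "tesseract",          # example engine
        "ocr_lang": "eng,fra",              # example languages
        "do_picture_description": True,
        "picture_description_mode": "api",  # "local" or "api"
        "picture_description_api": {},      # provider-specific JSON, see the config section above
    }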
@@ -20,6 +20,14 @@ class MistralLoader:
    """
    Enhanced Mistral OCR loader with both sync and async support.
    Loads documents by processing them through the Mistral OCR API.

    Performance Optimizations:
    - Differentiated timeouts for different operations
    - Intelligent retry logic with exponential backoff
    - Memory-efficient file streaming for large files
    - Connection pooling and keepalive optimization
    - Semaphore-based concurrency control for batch processing
    - Enhanced error handling with retryable error classification
    """

    BASE_API_URL = "https://api.mistral.ai/v1"
@@ -53,17 +61,40 @@ class MistralLoader:
        self.max_retries = max_retries
        self.debug = enable_debug_logging

        # Pre-compute file info for performance
        # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
        # This prevents long-running OCR operations from affecting quick operations
        # and improves user experience by failing fast on operations that should be quick
        self.upload_timeout = min(
            timeout, 120
        )  # Cap upload at 2 minutes - prevents hanging on large files
        self.url_timeout = (
            30  # URL requests should be fast - fail quickly if API is slow
        )
        self.ocr_timeout = (
            timeout  # OCR can take the full timeout - this is the heavy operation
        )
        self.cleanup_timeout = (
            30  # Cleanup should be quick - don't hang on file deletion
        )

        # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
        # This avoids multiple os.path.basename() and os.path.getsize() calls during processing
        self.file_name = os.path.basename(file_path)
        self.file_size = os.path.getsize(file_path)

        # ENHANCEMENT: Added User-Agent for better API tracking and debugging
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "User-Agent": "OpenWebUI-MistralLoader/2.0",
            "User-Agent": "OpenWebUI-MistralLoader/2.0",  # Helps API provider track usage
        }

    def _debug_log(self, message: str, *args) -> None:
        """Conditional debug logging for performance."""
        """
        PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.

        Only processes debug messages when debug mode is enabled, avoiding
        string formatting overhead in production environments.
        """
        if self.debug:
            log.debug(message, *args)

@@ -115,53 +146,118 @@ class MistralLoader:
            log.error(f"Unexpected error processing response: {e}")
            raise

    def _is_retryable_error(self, error: Exception) -> bool:
        """
        ENHANCEMENT: Intelligent error classification for retry logic.

        Determines if an error is retryable based on its type and status code.
        This prevents wasting time retrying errors that will never succeed
        (like authentication errors) while ensuring transient errors are retried.

        Retryable errors:
        - Network connection errors (temporary network issues)
        - Timeouts (server might be temporarily overloaded)
        - Server errors (5xx status codes - server-side issues)
        - Rate limiting (429 status - temporary throttling)

        Non-retryable errors:
        - Authentication errors (401, 403 - won't fix with retry)
        - Bad request errors (400 - malformed request)
        - Not found errors (404 - resource doesn't exist)
        """
        if isinstance(error, requests.exceptions.ConnectionError):
            return True  # Network issues are usually temporary
        if isinstance(error, requests.exceptions.Timeout):
            return True  # Timeouts might resolve on retry
        if isinstance(error, requests.exceptions.HTTPError):
            # Only retry on server errors (5xx) or rate limits (429)
            if hasattr(error, "response") and error.response is not None:
                status_code = error.response.status_code
                return status_code >= 500 or status_code == 429
            return False
        if isinstance(
            error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
        ):
            return True  # Async network/timeout errors are retryable
        if isinstance(error, aiohttp.ClientResponseError):
            return error.status >= 500 or error.status == 429
        return False  # All other errors are non-retryable

    def _retry_request_sync(self, request_func, *args, **kwargs):
        """Synchronous retry logic with exponential backoff."""
        """
        ENHANCEMENT: Synchronous retry logic with intelligent error classification.

        Uses exponential backoff with jitter to avoid thundering herd problems.
        The wait time increases exponentially but is capped at 30 seconds to
        prevent excessive delays. Only retries errors that are likely to succeed
        on subsequent attempts.
        """
        for attempt in range(self.max_retries):
            try:
                return request_func(*args, **kwargs)
            except (requests.exceptions.RequestException, Exception) as e:
                if attempt == self.max_retries - 1:
            except Exception as e:
                if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
                    raise

                wait_time = (2**attempt) + 0.5
                # PERFORMANCE OPTIMIZATION: Exponential backoff with cap
                # Prevents overwhelming the server while ensuring reasonable retry delays
                wait_time = min((2**attempt) + 0.5, 30)  # Cap at 30 seconds
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                    f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
                    f"Retrying in {wait_time}s..."
                )
                time.sleep(wait_time)

    async def _retry_request_async(self, request_func, *args, **kwargs):
        """Async retry logic with exponential backoff."""
        """
        ENHANCEMENT: Async retry logic with intelligent error classification.

        Async version of retry logic that doesn't block the event loop during
        wait periods. Uses the same exponential backoff strategy as sync version.
        """
        for attempt in range(self.max_retries):
            try:
                return await request_func(*args, **kwargs)
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt == self.max_retries - 1:
            except Exception as e:
                if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
                    raise

                wait_time = (2**attempt) + 0.5
                # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
                wait_time = min((2**attempt) + 0.5, 30)  # Cap at 30 seconds
                log.warning(
                    f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
                    f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
                    f"Retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)
                await asyncio.sleep(wait_time)  # Non-blocking wait

    def _upload_file(self) -> str:
        """Uploads the file to Mistral for OCR processing (sync version)."""
        """
        PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.

        Uploads the file to Mistral for OCR processing (sync version).
        Uses context manager for file handling to ensure proper resource cleanup.
        Although streaming is not enabled for this endpoint, the file is opened
        in a context manager to minimize memory usage duration.
        """
        log.info("Uploading file to Mistral API")
        url = f"{self.BASE_API_URL}/files"
        file_name = os.path.basename(self.file_path)

        def upload_request():
            # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
            # This ensures the file is closed immediately after reading, reducing memory usage
            with open(self.file_path, "rb") as f:
                files = {"file": (file_name, f, "application/pdf")}
                files = {"file": (self.file_name, f, "application/pdf")}
                data = {"purpose": "ocr"}

                # NOTE: stream=False is required for this endpoint
                # The Mistral API doesn't support chunked uploads for this endpoint
                response = requests.post(
                    url,
                    headers=self.headers,
                    files=files,
                    data=data,
                    timeout=self.timeout,
                    timeout=self.upload_timeout,  # Use specialized upload timeout
                    stream=False,  # Keep as False for this endpoint
                )

                return self._handle_response(response)

||||
url,
|
||||
data=writer,
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
@@ -231,7 +327,7 @@ class MistralLoader:
|
||||
|
||||
def url_request():
|
||||
response = requests.get(
|
||||
url, headers=signed_url_headers, params=params, timeout=self.timeout
|
||||
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
@@ -261,7 +357,7 @@ class MistralLoader:
|
||||
url,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
timeout=aiohttp.ClientTimeout(total=self.url_timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
@@ -294,7 +390,7 @@ class MistralLoader:
|
||||
|
||||
def ocr_request():
|
||||
response = requests.post(
|
||||
url, headers=ocr_headers, json=payload, timeout=self.timeout
|
||||
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
@@ -336,7 +432,7 @@ class MistralLoader:
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
|
||||
) as response:
|
||||
ocr_response = await self._handle_response_async(response)
|
||||
|
||||
@@ -353,7 +449,9 @@ class MistralLoader:
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}"
|
||||
|
||||
try:
|
||||
response = requests.delete(url, headers=self.headers, timeout=30)
|
||||
response = requests.delete(
|
||||
url, headers=self.headers, timeout=self.cleanup_timeout
|
||||
)
|
||||
delete_response = self._handle_response(response)
|
||||
log.info(f"File deleted successfully: {delete_response}")
|
||||
except Exception as e:
|
||||
@@ -372,7 +470,7 @@ class MistralLoader:
|
||||
url=f"{self.BASE_API_URL}/files/{file_id}",
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=30
|
||||
total=self.cleanup_timeout
|
||||
), # Shorter timeout for cleanup
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
@@ -388,29 +486,39 @@ class MistralLoader:
|
||||
async def _get_session(self):
|
||||
"""Context manager for HTTP session with optimized settings."""
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=10, # Total connection limit
|
||||
limit_per_host=5, # Per-host connection limit
|
||||
ttl_dns_cache=300, # DNS cache TTL
|
||||
limit=20, # Increased total connection limit for better throughput
|
||||
limit_per_host=10, # Increased per-host limit for API endpoints
|
||||
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
|
||||
use_dns_cache=True,
|
||||
keepalive_timeout=30,
|
||||
keepalive_timeout=60, # Increased keepalive for connection reuse
|
||||
enable_cleanup_closed=True,
|
||||
force_close=False, # Allow connection reuse
|
||||
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self.timeout,
|
||||
connect=30, # Connection timeout
|
||||
sock_read=60, # Socket read timeout
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
timeout=timeout,
|
||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
||||
raise_for_status=False, # We handle status codes manually
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||
"""Process OCR results into Document objects with enhanced metadata."""
|
||||
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
||||
pages_data = ocr_response.get("pages")
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found", metadata={"error": "no_pages"}
|
||||
page_content="No text content found",
|
||||
metadata={"error": "no_pages", "file_name": self.file_name},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -418,41 +526,44 @@ class MistralLoader:
|
||||
total_pages = len(pages_data)
|
||||
skipped_pages = 0
|
||||
|
||||
# Process pages in a memory-efficient way
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
|
||||
if page_content is not None and page_index is not None:
|
||||
# Clean up content efficiently
|
||||
cleaned_content = (
|
||||
page_content.strip()
|
||||
if isinstance(page_content, str)
|
||||
else str(page_content)
|
||||
)
|
||||
|
||||
if cleaned_content: # Only add non-empty pages
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index
|
||||
+ 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
else:
|
||||
if page_content is None or page_index is None:
|
||||
skipped_pages += 1
|
||||
self._debug_log(
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Clean up content efficiently with early exit for empty content
|
||||
if isinstance(page_content, str):
|
||||
cleaned_content = page_content.strip()
|
||||
else:
|
||||
cleaned_content = str(page_content).strip()
|
||||
|
||||
if not cleaned_content:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
continue
|
||||
|
||||
# Create document with optimized metadata
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index + 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
"content_length": len(cleaned_content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if skipped_pages > 0:
|
||||
log.info(
|
||||
@@ -467,7 +578,11 @@ class MistralLoader:
|
||||
return [
|
||||
Document(
|
||||
page_content="No valid text content found in document",
|
||||
metadata={"error": "no_valid_pages", "total_pages": total_pages},
|
||||
metadata={
|
||||
"error": "no_valid_pages",
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -585,12 +700,14 @@ class MistralLoader:
|
||||
@staticmethod
|
||||
async def load_multiple_async(
|
||||
loaders: List["MistralLoader"],
|
||||
max_concurrent: int = 5, # Limit concurrent requests
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
Process multiple files concurrently for maximum performance.
|
||||
Process multiple files concurrently with controlled concurrency.
|
||||
|
||||
Args:
|
||||
loaders: List of MistralLoader instances
|
||||
max_concurrent: Maximum number of concurrent requests
|
||||
|
||||
Returns:
|
||||
List of document lists, one for each loader
|
||||
@@ -598,11 +715,20 @@ class MistralLoader:
|
||||
if not loaders:
|
||||
return []
|
||||
|
||||
log.info(f"Starting concurrent processing of {len(loaders)} files")
|
||||
log.info(
|
||||
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
# Process all files concurrently
|
||||
tasks = [loader.load_async() for loader in loaders]
|
||||
# Use semaphore to control concurrency
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
|
||||
async with semaphore:
|
||||
return await loader.load_async()
|
||||
|
||||
# Process all files with controlled concurrency
|
||||
tasks = [process_with_semaphore(loader) for loader in loaders]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle any exceptions in results
|
||||
@@ -624,10 +750,18 @@ class MistralLoader:
            else:
                processed_results.append(result)

        # MONITORING: Log comprehensive batch processing statistics
        total_time = time.time() - start_time
        total_docs = sum(len(docs) for docs in processed_results)
        success_count = sum(
            1 for result in results if not isinstance(result, Exception)
        )
        failure_count = len(results) - success_count

        log.info(
            f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
            f"Batch processing completed in {total_time:.2f}s: "
            f"{success_count} files succeeded, {failure_count} files failed, "
            f"produced {total_docs} total documents"
        )

        return processed_results

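A sketch of driving the batch API with the new concurrency cap; the file paths and API key are placeholders, and the constructor arguments assume the api_key/file_path parameters visible in the hunks above:

    import asyncio

    async def main():
        loaders = [
            MistralLoader(api_key="YOUR_API_KEY", file_path=p)  # placeholders
            for p in ["a.pdf", "b.pdf", "c.pdf"]
        ]
        results = await MistralLoader.load_multiple_async(loaders, max_concurrent=2)
        print(sum(len(docs) for docs in results), "documents")

    asyncio.run(main())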
@@ -1,4 +1,5 @@
import logging
from xml.etree.ElementTree import ParseError

from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
@@ -93,7 +94,6 @@ class YoutubeLoader:
                "http": self.proxy_url,
                "https": self.proxy_url,
            }
            # Don't log complete URL because it might contain secrets
            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
        else:
            youtube_proxies = None
@@ -110,11 +110,37 @@ class YoutubeLoader:
        for lang in self.language:
            try:
                transcript = transcript_list.find_transcript([lang])
                if transcript.is_generated:
                    log.debug(f"Found generated transcript for language '{lang}'")
                    try:
                        transcript = transcript_list.find_manually_created_transcript(
                            [lang]
                        )
                        log.debug(f"Found manual transcript for language '{lang}'")
                    except NoTranscriptFound:
                        log.debug(
                            f"No manual transcript found for language '{lang}', using generated"
                        )
                        pass

                log.debug(f"Found transcript for language '{lang}'")
                transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
                try:
                    transcript_pieces: List[Dict[str, Any]] = transcript.fetch()
                except ParseError:
                    log.debug(f"Empty or invalid transcript for language '{lang}'")
                    continue

                if not transcript_pieces:
                    log.debug(f"Empty transcript for language '{lang}'")
                    continue

                transcript_text = " ".join(
                    map(
                        lambda transcript_piece: transcript_piece.text.strip(" "),
                        lambda transcript_piece: (
                            transcript_piece.text.strip(" ")
                            if hasattr(transcript_piece, "text")
                            else ""
                        ),
                        transcript_pieces,
                    )
                )
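The hardened join now tolerates transcript pieces that lack a .text attribute instead of raising. A standalone sketch of that guard, with mocked pieces:

    # One piece with .text, one without; the second contributes an empty string.
    pieces = [type("Piece", (), {"text": " hello "})(), object()]
    joined = " ".join(
        p.text.strip(" ") if hasattr(p, "text") else "" for p in pieces
    )
    print(repr(joined))  # 'hello '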
@@ -131,6 +157,4 @@ class YoutubeLoader:
        log.warning(
            f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
        )
        raise NoTranscriptFound(
            f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed."
        )
        raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

@@ -1,12 +1,16 @@
from typing import Optional, List, Dict, Any
import logging
import json
from sqlalchemy import (
    func,
    literal,
    cast,
    column,
    create_engine,
    Column,
    Integer,
    MetaData,
    LargeBinary,
    select,
    text,
    Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
    SearchResult,
    GetResult,
)
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
from open_webui.config import (
    PGVECTOR_DB_URL,
    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
    PGVECTOR_PGCRYPTO,
    PGVECTOR_PGCRYPTO_KEY,
)

from open_webui.env import SRC_LOG_LEVELS

@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def pgcrypto_encrypt(val, key):
    return func.pgp_sym_encrypt(val, literal(key))


def pgcrypto_decrypt(col, key, outtype="text"):
    return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)


class DocumentChunk(Base):
    __tablename__ = "document_chunk"

    id = Column(Text, primary_key=True)
    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)
    text = Column(Text, nullable=True)
    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)

    if PGVECTOR_PGCRYPTO:
        text = Column(LargeBinary, nullable=True)
        vmetadata = Column(LargeBinary, nullable=True)
    else:
        text = Column(Text, nullable=True)
        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)


class PgvectorClient(VectorDBBase):
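The two helpers wrap pgcrypto's symmetric functions so queries can decrypt inside SQL. A sketch of what they render to, using the names defined in this module; the key is a hypothetical placeholder:

    # pgcrypto_decrypt(DocumentChunk.text, key, Text) renders roughly to:
    #   CAST(pgp_sym_decrypt(document_chunk.text, 'key') AS TEXT)
    # and pgcrypto_encrypt(value, key) to:
    #   pgp_sym_encrypt('value', 'key')
    stmt = select(
        DocumentChunk.id,
        pgcrypto_decrypt(DocumentChunk.text, "my-secret-key", Text).label("text"),
    )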
@@ -70,6 +92,15 @@ class PgvectorClient(VectorDBBase):
|
||||
# Ensure the pgvector extension is available
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Ensure the pgcrypto extension is available for encryption
|
||||
self.session.execute(text("CREATE EXTENSION IF NOT EXISTS pgcrypto;"))
|
||||
|
||||
if not PGVECTOR_PGCRYPTO_KEY:
|
||||
raise ValueError(
|
||||
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
|
||||
)
|
||||
|
||||
# Check vector length consistency
|
||||
self.check_vector_length()
|
||||
|
||||
@@ -147,44 +178,39 @@ class PgvectorClient(VectorDBBase):
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = item["metadata"]
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
# Use raw SQL for BYTEA/pgcrypto
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO document_chunk
|
||||
(id, vector, collection_name, text, vmetadata)
|
||||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
|
||||
|
||||
else:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
@@ -192,11 +218,78 @@ class PgvectorClient(VectorDBBase):
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO document_chunk
|
||||
(id, vector, collection_name, text, vmetadata)
|
||||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
vector = EXCLUDED.vector,
|
||||
collection_name = EXCLUDED.collection_name,
|
||||
text = EXCLUDED.text,
|
||||
vmetadata = EXCLUDED.vmetadata
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
||||
else:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = item["metadata"]
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=item["metadata"],
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
@@ -230,16 +323,32 @@ class PgvectorClient(VectorDBBase):
|
||||
.alias("query_vectors")
|
||||
)
|
||||
|
||||
result_fields = [
|
||||
DocumentChunk.id,
|
||||
]
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text")
|
||||
)
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata")
|
||||
)
|
||||
else:
|
||||
result_fields.append(DocumentChunk.text)
|
||||
result_fields.append(DocumentChunk.vmetadata)
|
||||
result_fields.append(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
)
|
||||
)
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
subq = (
|
||||
select(
|
||||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
).label("distance"),
|
||||
)
|
||||
select(*result_fields)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
@@ -299,17 +408,43 @@ class PgvectorClient(VectorDBBase):
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Build where clause for vmetadata filter
|
||||
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||
for key, value in filter.items():
|
||||
# decrypt then check key: JSON filter after decryption
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
).where(*where_clauses)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
results = self.session.execute(stmt).all()
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
results = query.all()
|
||||
results = query.all()
|
||||
|
||||
if not results:
|
||||
return None
|
||||
@@ -331,20 +466,38 @@ class PgvectorClient(VectorDBBase):
        self, collection_name: str, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if limit is not None:
                query = query.limit(limit)
            if PGVECTOR_PGCRYPTO:
                stmt = select(
                    DocumentChunk.id,
                    pgcrypto_decrypt(
                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
                    ).label("text"),
                    pgcrypto_decrypt(
                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                    ).label("vmetadata"),
                ).where(DocumentChunk.collection_name == collection_name)
                if limit is not None:
                    stmt = stmt.limit(limit)
                results = self.session.execute(stmt).all()
                ids = [[row.id for row in results]]
                documents = [[row.text for row in results]]
                metadatas = [[row.vmetadata for row in results]]
            else:

                results = query.all()
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if limit is not None:
                    query = query.limit(limit)

                if not results:
                    return None
                results = query.all()

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]
                if not results:
                    return None

                ids = [[result.id for result in results]]
                documents = [[result.text for result in results]]
                metadatas = [[result.vmetadata for result in results]]

            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
        except Exception as e:
@@ -358,17 +511,33 @@ class PgvectorClient(VectorDBBase):
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        try:
            query = self.session.query(DocumentChunk).filter(
                DocumentChunk.collection_name == collection_name
            )
            if ids:
                query = query.filter(DocumentChunk.id.in_(ids))
            if filter:
                for key, value in filter.items():
                    query = query.filter(
                        DocumentChunk.vmetadata[key].astext == str(value)
                    )
            deleted = query.delete(synchronize_session=False)
            if PGVECTOR_PGCRYPTO:
                wheres = [DocumentChunk.collection_name == collection_name]
                if ids:
                    wheres.append(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        wheres.append(
                            pgcrypto_decrypt(
                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
                            )[key].astext
                            == str(value)
                        )
                stmt = DocumentChunk.__table__.delete().where(*wheres)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(
                    DocumentChunk.collection_name == collection_name
                )
                if ids:
                    query = query.filter(DocumentChunk.id.in_(ids))
                if filter:
                    for key, value in filter.items():
                        query = query.filter(
                            DocumentChunk.vmetadata[key].astext == str(value)
                        )
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:

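The pgcrypto branches above all share one decrypt-then-filter shape. As a hedged illustration (the table, key, and metadata names below are placeholders, not Open WebUI's actual schema), the SQL they ultimately generate is roughly:

# Hedged sketch of the generated SQL; identifiers and the symmetric key
# are illustrative placeholders.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg://user:pass@localhost/db")
with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT id, pgp_sym_decrypt(text, :key) AS text "
            "FROM document_chunk "
            "WHERE collection_name = :collection "
            # decrypt vmetadata, cast back to jsonb, then filter on a key
            "AND (pgp_sym_decrypt(vmetadata, :key)::jsonb ->> :mkey) = :mval"
        ),
        {"key": "secret", "collection": "docs", "mkey": "source", "mval": "a.pdf"},
    ).all()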
@@ -3,10 +3,19 @@ import logging
import time  # for measuring elapsed time
from pinecone import Pinecone, ServerlessSpec

# Add gRPC support for better performance (Pinecone best practice)
try:
    from pinecone.grpc import PineconeGRPC

    GRPC_AVAILABLE = True
except ImportError:
    GRPC_AVAILABLE = False

import asyncio  # for async upserts
import functools  # for partial binding in async tasks

import concurrent.futures  # for parallel batch upserts
import random  # for jitter in retry backoff

from open_webui.retrieval.vector.main import (
    VectorDBBase,
@@ -47,7 +56,24 @@ class PineconeClient(VectorDBBase):
        self.cloud = PINECONE_CLOUD

        # Initialize Pinecone client for improved performance
        self.client = Pinecone(api_key=self.api_key)
        if GRPC_AVAILABLE:
            # Use gRPC client for better performance (Pinecone recommendation)
            self.client = PineconeGRPC(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = True
            log.info("Using Pinecone gRPC client for optimal performance")
        else:
            # Fallback to HTTP client with enhanced connection pooling
            self.client = Pinecone(
                api_key=self.api_key,
                pool_threads=20,  # Improved connection pool size
                timeout=30,  # Reasonable timeout for operations
            )
            self.using_grpc = False
            log.info("Using Pinecone HTTP client (gRPC not available)")

        # Persistent executor for batch operations
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -91,12 +117,53 @@ class PineconeClient(VectorDBBase):
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)
            self.index = self.client.Index(
                self.index_name,
                pool_threads=20,  # Enhanced connection pool for index operations
            )

        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}")

    def _retry_pinecone_operation(self, operation_func, max_retries=3):
        """Retry Pinecone operations with exponential backoff for rate limits and network issues."""
        for attempt in range(max_retries):
            try:
                return operation_func()
            except Exception as e:
                error_str = str(e).lower()
                # Check if it's a retryable error (rate limits, network issues, timeouts)
                is_retryable = any(
                    keyword in error_str
                    for keyword in [
                        "rate limit",
                        "quota",
                        "timeout",
                        "network",
                        "connection",
                        "unavailable",
                        "internal error",
                        "429",
                        "500",
                        "502",
                        "503",
                        "504",
                    ]
                )

                if not is_retryable or attempt == max_retries - 1:
                    # Don't retry for non-retryable errors or on final attempt
                    raise

                # Exponential backoff with jitter
                delay = (2**attempt) + random.uniform(0, 1)
                log.warning(
                    f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
                    f"retrying in {delay:.2f}s: {e}"
                )
                time.sleep(delay)

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
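A minimal sketch of how the retry helper above might wrap an index call; functools.partial (imported at the top of the module) binds the arguments so the exact same operation can be re-invoked on each attempt. The batch variable is illustrative:

# Hedged usage sketch for _retry_pinecone_operation; `self` is a
# PineconeClient instance and `batch` an already-built list of point dicts.
op = functools.partial(self.index.upsert, vectors=batch)
# Retries only errors matching the keyword list above (rate limits,
# timeouts, 5xx), sleeping 2**attempt + jitter between attempts.
self._retry_pinecone_operation(op, max_retries=3)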
@@ -223,7 +290,8 @@ class PineconeClient(VectorDBBase):
        elapsed = time.time() - start_time
        log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
            f"Successfully inserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -254,7 +322,8 @@ class PineconeClient(VectorDBBase):
        elapsed = time.time() - start_time
        log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
        log.info(
            f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
            f"Successfully upserted {len(points)} vectors in parallel batches "
            f"into '{collection_name_with_prefix}'"
        )

    async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -285,7 +354,8 @@ class PineconeClient(VectorDBBase):
                log.error(f"Error in async insert batch: {result}")
                raise result
        log.info(
            f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
            f"Successfully async inserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )

    async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -316,7 +386,8 @@ class PineconeClient(VectorDBBase):
                log.error(f"Error in async upsert batch: {result}")
                raise result
        log.info(
            f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
            f"Successfully async upserted {len(points)} vectors in batches "
            f"into '{collection_name_with_prefix}'"
        )

    def search(
@@ -457,10 +528,12 @@ class PineconeClient(VectorDBBase):
                # This is a limitation of Pinecone - be careful with ID uniqueness
                self.index.delete(ids=batch_ids)
                log.debug(
                    f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    f"Deleted batch of {len(batch_ids)} vectors by ID "
                    f"from '{collection_name_with_prefix}'"
                )
            log.info(
                f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                f"Successfully deleted {len(ids)} vectors by ID "
                f"from '{collection_name_with_prefix}'"
            )

        elif filter:

@@ -1,10 +1,20 @@
import logging
from typing import Optional, List
from typing import Optional, Literal
import requests

from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

MODELS = Literal[
    "sonar",
    "sonar-pro",
    "sonar-reasoning",
    "sonar-reasoning-pro",
    "sonar-deep-research",
]
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

@@ -14,6 +24,8 @@ def search_perplexity(
    query: str,
    count: int,
    filter_list: Optional[list[str]] = None,
    model: MODELS = "sonar",
    search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
) -> list[SearchResult]:
    """Search using Perplexity API and return the results as a list of SearchResult objects.

@@ -21,6 +33,9 @@ def search_perplexity(
        api_key (str): A Perplexity API key
        query (str): The query to search for
        count (int): Maximum number of results to return
        filter_list (Optional[list[str]]): List of domains to filter results
        model (str): The Perplexity model to use (sonar, sonar-pro)
        search_context_usage (str): Search context usage level (low, medium, high)

    """

@@ -33,7 +48,7 @@ def search_perplexity(

    # Create payload for the API call
    payload = {
        "model": "sonar",
        "model": model,
        "messages": [
            {
                "role": "system",
@@ -43,6 +58,9 @@ def search_perplexity(
        ],
        "temperature": 0.2,  # Lower temperature for more factual responses
        "stream": False,
        "web_search_options": {
            "search_context_usage": search_context_usage,
        },
    }

    headers = {

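A hedged usage sketch of the extended signature; all values are placeholders, and "sonar-pro" is just one member of the MODELS Literal above:

results = search_perplexity(
    api_key="pplx-...",            # placeholder key
    query="open-webui release notes",
    count=3,
    filter_list=["github.com"],    # optional domain filter
    model="sonar-pro",
    search_context_usage="high",
)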
@@ -420,7 +420,7 @@ def load_b64_image_data(b64_str):
    try:
        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            mime_type = header.split(";")[0]
            mime_type = header.split(";")[0].removeprefix("data:")
            img_data = base64.b64decode(encoded)
        else:
            mime_type = "image/png"
@@ -428,7 +428,7 @@ def load_b64_image_data(b64_str):
        return img_data, mime_type
    except Exception as e:
        log.exception(f"Error loading image data: {e}")
        return None
        return None, None


def load_url_image_data(url, headers=None):

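A small sketch of the data-URL parsing the helper performs (assuming the removeprefix() form above); the payload is plain base64 text rather than a real image, which is enough to exercise the split and decode:

b64_str = "data:image/png;base64,aGVsbG8="  # "aGVsbG8=" decodes to b"hello"
img_data, mime_type = load_b64_image_data(b64_str)
# img_data == b"hello", mime_type == "image/png"; on a malformed string the
# helper now returns the (None, None) pair instead of a bare None.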
@@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
    if user.role != "admin" and (
        user.id != note.user_id
        and not has_access(user.id, type="read", access_control=note.access_control)
    ):
        raise HTTPException(
@@ -159,9 +158,8 @@ async def update_note_by_id(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
    if user.role != "admin" and (
        user.id != note.user_id
        and not has_access(user.id, type="write", access_control=note.access_control)
    ):
        raise HTTPException(
@@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if (
        user.role != "admin"
        and user.id != note.user_id
    if user.role != "admin" and (
        user.id != note.user_id
        and not has_access(user.id, type="write", access_control=note.access_control)
    ):
        raise HTTPException(

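For context, has_access() evaluates the note's access_control document. A hedged sketch of the shape such a grant document commonly takes; the exact fields below are an assumption, not taken from this diff:

# Illustrative access_control document; group/user ids are placeholders.
access_control = {
    "read": {"group_ids": ["team-a"], "user_ids": ["u-123"]},
    "write": {"group_ids": [], "user_ids": ["u-123"]},
}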
@@ -1232,6 +1232,9 @@ class GenerateChatCompletionForm(BaseModel):
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None
    tools: Optional[list[dict]] = None
    model_config = ConfigDict(
        extra="allow",
    )


async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
@@ -1269,7 +1272,9 @@ async def generate_chat_completion(
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True)}
    if isinstance(form_data, BaseModel):
        payload = {**form_data.model_dump(exclude_none=True)}

    if "metadata" in payload:
        del payload["metadata"]

@@ -1285,11 +1290,7 @@ async def generate_chat_completion(
    if params:
        system = params.pop("system", None)

        # Unlike OpenAI, Ollama does not support params directly in the body
        payload["options"] = apply_model_params_to_body_ollama(
            params, (payload.get("options", {}) or {})
        )

        payload = apply_model_params_to_body_ollama(params, payload)
        payload = apply_model_system_prompt_to_body(system, payload, metadata, user)

    # Check if user has access to the model
@@ -1323,7 +1324,7 @@ async def generate_chat_completion(
        prefix_id = api_config.get("prefix_id", None)
        if prefix_id:
            payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
        # payload["keep_alive"] = -1  # keep alive forever

        return await send_post_request(
            url=f"{url}/api/chat",
            payload=json.dumps(payload),

@@ -887,6 +887,88 @@ async def generate_chat_completion(
        await session.close()


async def embeddings(request: Request, form_data: dict, user):
    """
    Calls the embeddings endpoint for OpenAI-compatible providers.

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): OpenAI-compatible embeddings payload.
        user (UserModel): The authenticated user.

    Returns:
        dict: OpenAI-compatible embeddings response.
    """
    idx = 0
    # Prepare payload/body
    body = json.dumps(form_data)
    # Find correct backend url/key based on model
    await get_all_models(request, user=user)
    model_id = form_data.get("model")
    models = request.app.state.OPENAI_MODELS
    if model_id in models:
        idx = models[model_id]["urlIdx"]
    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]
    r = None
    session = None
    streaming = False
    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/embeddings",
            data=body,
            headers={
                "Authorization": f"Bearer {key}",
                "Content-Type": "application/json",
                **(
                    {
                        "X-OpenWebUI-User-Name": user.name,
                        "X-OpenWebUI-User-Id": user.id,
                        "X-OpenWebUI-User-Email": user.email,
                        "X-OpenWebUI-User-Role": user.role,
                    }
                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
                    else {}
                ),
            },
        )
        r.raise_for_status()
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        detail = None
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                detail = f"External: {e}"
        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()


@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """

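A sketch of the OpenAI-compatible body the new embeddings() handler forwards verbatim to {url}/embeddings; the model name is illustrative:

form_data = {
    "model": "text-embedding-3-small",  # placeholder model id
    "input": ["first chunk", "second chunk"],
}
# The upstream reply is returned as-is, e.g.
# {"object": "list", "data": [{"index": 0, "embedding": [...]}, ...]}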
@@ -414,6 +414,9 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
        "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
        "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
        "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -467,6 +470,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
        "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
        "EXA_API_KEY": request.app.state.config.EXA_API_KEY,
        "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
        "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
        "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
        "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
        "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
        "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -520,6 +525,8 @@ class WebConfig(BaseModel):
    BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
    EXA_API_KEY: Optional[str] = None
    PERPLEXITY_API_KEY: Optional[str] = None
    PERPLEXITY_MODEL: Optional[str] = None
    PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
    SOUGOU_API_SID: Optional[str] = None
    SOUGOU_API_SK: Optional[str] = None
    WEB_LOADER_ENGINE: Optional[str] = None
@@ -571,6 +578,9 @@ class ConfigForm(BaseModel):
    DOCLING_OCR_ENGINE: Optional[str] = None
    DOCLING_OCR_LANG: Optional[str] = None
    DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
    DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None
    DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None
    DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None
    DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
    DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
    MISTRAL_OCR_API_KEY: Optional[str] = None
@@ -744,6 +754,22 @@ async def update_rag_config(
        else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
    )

    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = (
        form_data.DOCLING_PICTURE_DESCRIPTION_MODE
        if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None
        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE
    )
    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = (
        form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL
        if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None
        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL
    )
    request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = (
        form_data.DOCLING_PICTURE_DESCRIPTION_API
        if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None
        else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API
    )

    request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
        form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
        if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
@@ -907,6 +933,10 @@ async def update_rag_config(
        )
        request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
        request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
        request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL
        request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
            form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
        )
        request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
        request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK

@@ -977,6 +1007,9 @@ async def update_rag_config(
        "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
        "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
        "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
        "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
        "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
        "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
@@ -1030,6 +1063,8 @@ async def update_rag_config(
        "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
        "EXA_API_KEY": request.app.state.config.EXA_API_KEY,
        "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
        "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
        "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
        "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
        "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
        "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
@@ -1321,9 +1356,14 @@ def process_file(
                EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
                TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
                DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
                DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
                DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
                DOCLING_PARAMS={
                    "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE,
                    "ocr_lang": request.app.state.config.DOCLING_OCR_LANG,
                    "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
                    "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE,
                    "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL,
                    "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API,
                },
                PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
@@ -1740,19 +1780,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
                request.app.state.config.WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
    elif engine == "exa":
        return search_exa(
            request.app.state.config.EXA_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "perplexity":
        return search_perplexity(
            request.app.state.config.PERPLEXITY_API_KEY,
            query,
            request.app.state.config.WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
            model=request.app.state.config.PERPLEXITY_MODEL,
            search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
        )
    elif engine == "sougou":
        if (

@@ -9,6 +9,7 @@ import re
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.task import (
    title_generation_template,
    follow_up_generation_template,
    query_generation_template,
    image_prompt_generation_template,
    autocomplete_generation_template,
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id

from open_webui.config import (
    DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
    DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
    ENABLE_AUTOCOMPLETE_GENERATION: bool
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
    TAGS_GENERATION_PROMPT_TEMPLATE: str
    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
    ENABLE_FOLLOW_UP_GENERATION: bool
    ENABLE_TAGS_GENERATION: bool
    ENABLE_SEARCH_QUERY_GENERATION: bool
    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@@ -94,6 +100,13 @@ async def update_task_config(
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )

    request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
        form_data.ENABLE_FOLLOW_UP_GENERATION
    )
    request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
        form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    )

    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
    )
@@ -133,6 +146,8 @@ async def update_task_config(
        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -231,6 +246,86 @@ async def generate_title(
    )


@router.post("/follow_up/completions")
async def generate_follow_ups(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):

    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
        return JSONResponse(
            status_code=status.HTTP_200_OK,
            content={"detail": "Follow-up generation is disabled"},
        )

    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating follow-ups using model {task_model_id} for user {user.email}"
    )

    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE

    content = follow_up_generation_template(
        template,
        form_data["messages"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.FOLLOW_UP_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline
    try:
        payload = await process_pipeline_inlet_filter(request, payload, user, models)
    except Exception as e:
        raise e

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error("Exception occurred", exc_info=True)
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": "An internal error has occurred."},
        )


@router.post("/tags/completions")
async def generate_chat_tags(
    request: Request, form_data: dict, user=Depends(get_verified_user)

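A hedged sketch of a request body for the new /follow_up/completions route; the chat content and ids are illustrative:

form_data = {
    "model": "gpt-4o",  # resolved to the configured task model above
    "chat_id": "abc123",
    "messages": [
        {"role": "user", "content": "How do I enable pgcrypto?"},
        {"role": "assistant", "content": "Set PGVECTOR_PGCRYPTO=true..."},
    ],
}
# The task model is expected to answer with JSON such as
# {"follow_ups": ["...", "..."]}, parsed later in process_chat_response().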
@@ -165,22 +165,6 @@ async def update_default_user_permissions(
    return request.app.state.config.USER_PERMISSIONS


############################
# UpdateUserRole
############################


@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
    if user.id != form_data.id and form_data.id != Users.get_first_user().id:
        return Users.update_user_role_by_id(form_data.id, form_data.role)

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
    )


############################
# GetUserSettingsBySessionUser
############################
@@ -333,11 +317,22 @@ async def update_user_by_id(
    # Prevent modification of the primary admin user by other admins
    try:
        first_user = Users.get_first_user()
        if first_user and user_id == first_user.id and session_user.id != user_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
            )
        if first_user:
            if user_id == first_user.id:
                if session_user.id != user_id:
                    # Only the primary admin may modify the primary admin account
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
                    )

                if form_data.role != "admin":
                    # Prevent the primary admin from demoting their own role
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=ERROR_MESSAGES.ACTION_PROHIBITED,
                    )

    except Exception as e:
        log.error(f"Error checking primary admin status: {e}")
        raise HTTPException(
@@ -365,6 +360,7 @@ async def update_user_by_id(
    updated_user = Users.update_user_by_id(
        user_id,
        {
            "role": form_data.role,
            "name": form_data.name,
            "email": form_data.email.lower(),
            "profile_image_url": form_data.profile_image_url,

@@ -34,7 +34,7 @@ class CodeForm(BaseModel):


@router.post("/code/format")
async def format_code(form_data: CodeForm, user=Depends(get_verified_user)):
async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
    try:
        formatted_code = black.format_str(form_data.code, mode=black.Mode())
        return {"code": formatted_code}

@@ -2,16 +2,87 @@
import asyncio
from typing import Dict
from uuid import uuid4
import json
from redis.asyncio import Redis
from fastapi import Request
from typing import Dict, List, Optional

# A dictionary to keep track of active tasks
tasks: Dict[str, asyncio.Task] = {}
chat_tasks = {}


def cleanup_task(task_id: str, id=None):
REDIS_TASKS_KEY = "open-webui:tasks"
REDIS_CHAT_TASKS_KEY = "open-webui:tasks:chat"
REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands"


def is_redis(request: Request) -> bool:
    # Called everywhere a request is available to check Redis
    return hasattr(request.app.state, "redis") and (request.app.state.redis is not None)


async def redis_task_command_listener(app):
    redis: Redis = app.state.redis
    pubsub = redis.pubsub()
    await pubsub.subscribe(REDIS_PUBSUB_CHANNEL)

    async for message in pubsub.listen():
        if message["type"] != "message":
            continue
        try:
            command = json.loads(message["data"])
            if command.get("action") == "stop":
                task_id = command.get("task_id")
                local_task = tasks.get(task_id)
                if local_task:
                    local_task.cancel()
        except Exception as e:
            print(f"Error handling distributed task command: {e}")


### ------------------------------
### REDIS-ENABLED HANDLERS
### ------------------------------


async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    pipe = redis.pipeline()
    pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "")
    if chat_id:
        pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
    await pipe.execute()


async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]):
    pipe = redis.pipeline()
    pipe.hdel(REDIS_TASKS_KEY, task_id)
    if chat_id:
        pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id)
        if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0:
            pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")  # Remove if empty set
    await pipe.execute()


async def redis_list_tasks(redis: Redis) -> List[str]:
    return list(await redis.hkeys(REDIS_TASKS_KEY))


async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]:
    return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}"))


async def redis_send_command(redis: Redis, command: dict):
    await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command))


async def cleanup_task(request, task_id: str, id=None):
    """
    Remove a completed or canceled task from the global `tasks` dictionary.
    """
    if is_redis(request):
        await redis_cleanup_task(request.app.state.redis, task_id, id)

    tasks.pop(task_id, None)  # Remove the task if it exists

    # If an ID is provided, remove the task from the chat_tasks dictionary
@@ -21,7 +92,7 @@ def cleanup_task(task_id: str, id=None):
        chat_tasks.pop(id, None)


def create_task(coroutine, id=None):
async def create_task(request, coroutine, id=None):
    """
    Create a new asyncio task and add it to the global task dictionary.
    """
@@ -29,7 +100,9 @@ def create_task(coroutine, id=None):
    task = asyncio.create_task(coroutine)  # Create the task

    # Add a done callback for cleanup
    task.add_done_callback(lambda t: cleanup_task(task_id, id))
    task.add_done_callback(
        lambda t: asyncio.create_task(cleanup_task(request, task_id, id))
    )
    tasks[task_id] = task

    # If an ID is provided, associate the task with that ID
@@ -38,34 +111,46 @@ def create_task(coroutine, id=None):
    else:
        chat_tasks[id] = [task_id]

    if is_redis(request):
        await redis_save_task(request.app.state.redis, task_id, id)

    return task_id, task


def get_task(task_id: str):
    """
    Retrieve a task by its task ID.
    """
    return tasks.get(task_id)


def list_tasks():
async def list_tasks(request):
    """
    List all currently active task IDs.
    """
    if is_redis(request):
        return await redis_list_tasks(request.app.state.redis)
    return list(tasks.keys())


def list_task_ids_by_chat_id(id):
async def list_task_ids_by_chat_id(request, id):
    """
    List all tasks associated with a specific ID.
    """
    if is_redis(request):
        return await redis_list_chat_tasks(request.app.state.redis, id)
    return chat_tasks.get(id, [])


async def stop_task(task_id: str):
async def stop_task(request, task_id: str):
    """
    Cancel a running task and remove it from the global task list.
    """
    if is_redis(request):
        # PUBSUB: All instances check if they have this task, and stop if so.
        await redis_send_command(
            request.app.state.redis,
            {
                "action": "stop",
                "task_id": task_id,
            },
        )
        # Optionally check if task_id still in Redis a few moments later for feedback?
        return {"status": True, "message": f"Stop signal sent for {task_id}"}

    task = tasks.get(task_id)
    if not task:
        raise ValueError(f"Task with ID {task_id} not found.")

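A sketch of the distributed flow these helpers enable, assuming every worker runs redis_task_command_listener() at startup; the coroutine and ids below are illustrative:

# Worker A registers the task locally and mirrors it in Redis:
task_id, _ = await create_task(request, post_handler(), id=chat_id)

# Worker B (or A) asks for it to stop; only the worker whose local
# `tasks` dict holds task_id actually cancels the asyncio task:
await stop_task(request, task_id)
# -> publishes {"action": "stop", "task_id": task_id} on REDIS_PUBSUB_CHANNEL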
@@ -23,6 +23,7 @@ from open_webui.env import (
    TRUSTED_SIGNATURE_KEY,
    STATIC_DIR,
    SRC_LOG_LEVELS,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
)

from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
@@ -157,6 +158,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):

def get_current_user(
    request: Request,
    response: Response,
    background_tasks: BackgroundTasks,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
@@ -225,6 +227,19 @@ def get_current_user(
                detail=ERROR_MESSAGES.INVALID_TOKEN,
            )
        else:
            if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
                trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER)
                if trusted_email and user.email != trusted_email:
                    # Delete the token cookie
                    response.delete_cookie("token")
                    # Delete OAuth token if present
                    if request.cookies.get("oauth_id_token"):
                        response.delete_cookie("oauth_id_token")
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="User mismatch. Please sign in again.",
                    )

            # Add user info to current span
            current_span = trace.get_current_span()
            if current_span:

@@ -320,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
    extra_params = {
        "__event_emitter__": get_event_emitter(metadata),
        "__event_call__": get_event_call(metadata),
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
@@ -424,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
                params[key] = value

        if "__user__" in sig.parameters:
            __user__ = {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            }
            __user__ = user.model_dump() if isinstance(user, UserModel) else {}

            try:
                if hasattr(function_module, "UserValves"):

90  backend/open_webui/utils/embeddings.py  Normal file
@@ -0,0 +1,90 @@
import random
import logging
import sys

from fastapi import Request
from open_webui.models.users import UserModel
from open_webui.models.models import Models
from open_webui.utils.models import check_model_access
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL

from open_webui.routers.openai import embeddings as openai_embeddings
from open_webui.routers.ollama import (
    embeddings as ollama_embeddings,
    GenerateEmbeddingsForm,
)


from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
from open_webui.utils.response import convert_embedding_response_ollama_to_openai

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): The input data sent to the endpoint.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, disables access filtering (default False).

    Returns:
        dict: The embeddings response, following OpenAI API compatibility.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    # Attach extra metadata from request.state if present
    if hasattr(request.state, "metadata"):
        if "metadata" not in form_data:
            form_data["metadata"] = request.state.metadata
        else:
            form_data["metadata"] = {
                **form_data["metadata"],
                **request.state.metadata,
            }

    # If "direct" flag present, use only that model
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering
    if not getattr(request.state, "direct", False):
        if not bypass_filter and user.role == "user":
            check_model_access(user, model)

    # Ollama backend
    if model.get("owned_by") == "ollama":
        ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
        response = await ollama_embeddings(
            request=request,
            form_data=GenerateEmbeddingsForm(**ollama_payload),
            user=user,
        )
        return convert_embedding_response_ollama_to_openai(response)

    # Default: OpenAI or compatible backend
    return await openai_embeddings(
        request=request,
        form_data=form_data,
        user=user,
    )
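A hedged sketch of calling the dispatcher from another router; the model id is a placeholder. Ollama-owned models go through the payload/response converters, so callers always see an OpenAI-shaped result:

response = await generate_embeddings(
    request=request,
    form_data={"model": "nomic-embed-text", "input": ["hello world"]},
    user=user,
)
vector = response["data"][0]["embedding"]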
@@ -32,11 +32,17 @@ from open_webui.socket.main import (
from open_webui.routers.tasks import (
    generate_queries,
    generate_title,
    generate_follow_ups,
    generate_image_prompt,
    generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.images import image_generations, GenerateImageForm
from open_webui.routers.images import (
    load_b64_image_data,
    image_generations,
    GenerateImageForm,
    upload_image,
)
from open_webui.routers.pipelines import (
    process_pipeline_inlet_filter,
    process_pipeline_outlet_filter,
@@ -692,13 +698,8 @@ def apply_params_to_form_data(form_data, model):
        params = deep_update(params, custom_params)

    if model.get("ollama"):
        # Ollama specific parameters
        form_data["options"] = params

        if "format" in params:
            form_data["format"] = params["format"]

        if "keep_alive" in params:
            form_data["keep_alive"] = params["keep_alive"]
    else:
        if isinstance(params, dict):
            for key, value in params.items():
@@ -726,12 +727,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
@@ -1061,6 +1057,59 @@ async def process_chat_response(
            )

            if tasks and messages:
                if (
                    TASKS.FOLLOW_UP_GENERATION in tasks
                    and tasks[TASKS.FOLLOW_UP_GENERATION]
                ):
                    res = await generate_follow_ups(
                        request,
                        {
                            "model": message["model"],
                            "messages": messages,
                            "message_id": metadata["message_id"],
                            "chat_id": metadata["chat_id"],
                        },
                        user,
                    )

                    if res and isinstance(res, dict):
                        if len(res.get("choices", [])) == 1:
                            follow_ups_string = (
                                res.get("choices", [])[0]
                                .get("message", {})
                                .get("content", "")
                            )
                        else:
                            follow_ups_string = ""

                        follow_ups_string = follow_ups_string[
                            follow_ups_string.find("{") : follow_ups_string.rfind("}")
                            + 1
                        ]

                        try:
                            follow_ups = json.loads(follow_ups_string).get(
                                "follow_ups", []
                            )
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
                                {
                                    "followUps": follow_ups,
                                },
                            )

                            await event_emitter(
                                {
                                    "type": "chat:message:follow_ups",
                                    "data": {
                                        "follow_ups": follow_ups,
                                    },
                                }
                            )
                        except Exception as e:
                            pass

                if TASKS.TITLE_GENERATION in tasks:
                    if tasks[TASKS.TITLE_GENERATION]:
                        res = await generate_title(
@@ -1286,12 +1335,7 @@ async def process_chat_response(
        extra_params = {
            "__event_emitter__": event_emitter,
            "__event_call__": event_caller,
            "__user__": {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            },
            "__user__": user.model_dump() if isinstance(user, UserModel) else {},
            "__metadata__": metadata,
            "__request__": request,
            "__model__": model,
@@ -1835,9 +1879,11 @@ async def process_chat_response(

                        value = delta.get("content")

                        reasoning_content = delta.get(
                            "reasoning_content"
                        ) or delta.get("reasoning")
                        reasoning_content = (
                            delta.get("reasoning_content")
                            or delta.get("reasoning")
                            or delta.get("thinking")
                        )
                        if reasoning_content:
                            if (
                                not content_blocks
|
||||
stdoutLines = stdout.split("\n")
|
||||
for idx, line in enumerate(stdoutLines):
|
||||
if "data:image/png;base64" in line:
|
||||
id = str(uuid4())
|
||||
|
||||
# ensure the path exists
|
||||
os.makedirs(
|
||||
os.path.join(CACHE_DIR, "images"),
|
||||
exist_ok=True,
|
||||
image_url = ""
|
||||
# Extract base64 image data from the line
|
||||
image_data, content_type = (
|
||||
load_b64_image_data(line)
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
CACHE_DIR,
|
||||
f"images/{id}.png",
|
||||
)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(
|
||||
base64.b64decode(
|
||||
line.split(",")[1]
|
||||
)
|
||||
if image_data is not None:
|
||||
image_url = upload_image(
|
||||
request,
|
||||
image_data,
|
||||
content_type,
|
||||
metadata,
|
||||
user,
|
||||
)
|
||||
|
||||
stdoutLines[idx] = (
|
||||
f""
|
||||
f""
|
||||
)
|
||||
|
||||
output["stdout"] = "\n".join(stdoutLines)
|
||||
@@ -2262,30 +2301,22 @@ async def process_chat_response(
|
||||
resultLines = result.split("\n")
|
||||
for idx, line in enumerate(resultLines):
|
||||
if "data:image/png;base64" in line:
|
||||
id = str(uuid4())
|
||||
|
||||
# ensure the path exists
|
||||
os.makedirs(
|
||||
os.path.join(CACHE_DIR, "images"),
|
||||
exist_ok=True,
|
||||
image_url = ""
|
||||
# Extract base64 image data from the line
|
||||
image_data, content_type = (
|
||||
load_b64_image_data(line)
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
CACHE_DIR,
|
||||
f"images/{id}.png",
|
||||
)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(
|
||||
base64.b64decode(
|
||||
line.split(",")[1]
|
||||
)
|
||||
if image_data is not None:
|
||||
image_url = upload_image(
|
||||
request,
|
||||
image_data,
|
||||
content_type,
|
||||
metadata,
|
||||
user,
|
||||
)
|
||||
|
||||
resultLines[idx] = (
|
||||
f""
|
||||
f""
|
||||
)
|
||||
|
||||
output["result"] = "\n".join(resultLines)
|
||||
except Exception as e:
|
||||
output = str(e)
|
||||
@@ -2394,8 +2425,8 @@ async def process_chat_response(
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(post_response_handler, response, events)
|
||||
task_id, _ = create_task(
|
||||
post_response_handler(response, events), id=metadata["chat_id"]
|
||||
task_id, _ = await create_task(
|
||||
request, post_response_handler(response, events), id=metadata["chat_id"]
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
|
||||
@@ -208,6 +208,7 @@ def openai_chat_message_template(model: str):
def openai_chat_chunk_message_template(
    model: str,
    content: Optional[str] = None,
    reasoning_content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
) -> dict:
@@ -220,6 +221,9 @@ def openai_chat_chunk_message_template(
    if content:
        template["choices"][0]["delta"]["content"] = content

    if reasoning_content:
        template["choices"][0]["delta"]["reasoning_content"] = reasoning_content

    if tool_calls:
        template["choices"][0]["delta"]["tool_calls"] = tool_calls

@@ -234,6 +238,7 @@ def openai_chat_chunk_message_template(
def openai_chat_completion_message_template(
    model: str,
    message: Optional[str] = None,
    reasoning_content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
) -> dict:
@@ -241,8 +246,9 @@ def openai_chat_completion_message_template(
    template["object"] = "chat.completion"
    if message is not None:
        template["choices"][0]["message"] = {
            "content": message,
            "role": "assistant",
            "content": message,
            **({"reasoning_content": reasoning_content} if reasoning_content else {}),
            **({"tool_calls": tool_calls} if tool_calls else {}),
        }

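Roughly, a chunk built with the new parameter carries the extra delta field; the values below are illustrative:

chunk = openai_chat_chunk_message_template(
    model="my-model",
    reasoning_content="Weighing both options...",
)
# chunk["choices"][0]["delta"] now includes
# {"reasoning_content": "Weighing both options..."}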
@@ -538,7 +538,7 @@ class OAuthManager:
        # Redirect back to the frontend with the JWT token

        redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
        if redirect_base_url.endswith("/"):
        if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
            redirect_base_url = redirect_base_url[:-1]
        redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"

@@ -175,16 +175,32 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
"num_thread": int,
|
||||
}
|
||||
|
||||
# Extract keep_alive from options if it exists
|
||||
if "options" in form_data and "keep_alive" in form_data["options"]:
|
||||
form_data["keep_alive"] = form_data["options"]["keep_alive"]
|
||||
del form_data["options"]["keep_alive"]
|
||||
def parse_json(value: str) -> dict:
|
||||
"""
|
||||
Parses a JSON string into a dictionary, handling potential JSONDecodeError.
|
||||
"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except Exception as e:
|
||||
return value
|
||||
|
||||
if "options" in form_data and "format" in form_data["options"]:
|
||||
form_data["format"] = form_data["options"]["format"]
|
||||
del form_data["options"]["format"]
|
||||
ollama_root_params = {
|
||||
"format": lambda x: parse_json(x),
|
||||
"keep_alive": lambda x: parse_json(x),
|
||||
"think": bool,
|
||||
}
|
||||
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
for key, value in ollama_root_params.items():
|
||||
if (param := params.get(key, None)) is not None:
|
||||
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
|
||||
form_data[key] = value(param)
|
||||
del params[key]
|
||||
|
||||
# Unlike OpenAI, Ollama does not support params directly in the body
|
||||
form_data["options"] = apply_model_params_to_body(
|
||||
params, (form_data.get("options", {}) or {}), mappings
|
||||
)
|
||||
return form_data
 
 
 def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:

@@ -279,36 +295,48 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
         openai_payload.get("messages")
     )
     ollama_payload["stream"] = openai_payload.get("stream", False)
 
     if "tools" in openai_payload:
         ollama_payload["tools"] = openai_payload["tools"]
 
     if "format" in openai_payload:
         ollama_payload["format"] = openai_payload["format"]
 
     # If there are advanced parameters in the payload, format them in Ollama's options field
     if openai_payload.get("options"):
-        ollama_payload["options"] = openai_payload["options"]
         ollama_options = openai_payload["options"]
 
+        def parse_json(value: str) -> dict:
+            """
+            Parses a JSON string into a dictionary; returns the value unchanged if parsing fails.
+            """
+            try:
+                return json.loads(value)
+            except Exception:
+                return value
+
+        ollama_root_params = {
+            "format": lambda x: parse_json(x),
+            "keep_alive": lambda x: parse_json(x),
+            "think": bool,
+        }
+
+        # Ollama's options field can contain parameters that should be at the root level.
+        for key, value in ollama_root_params.items():
+            if (param := ollama_options.get(key, None)) is not None:
+                # Promote the parameter to the payload root, then delete it from options to prevent an Ollama invalid-option warning.
+                ollama_payload[key] = value(param)
+                del ollama_options[key]
 
         # Re-mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
         if "max_tokens" in ollama_options:
             ollama_options["num_predict"] = ollama_options["max_tokens"]
-            del ollama_options[
-                "max_tokens"
-            ]  # To prevent Ollama warning of invalid option provided
+            del ollama_options["max_tokens"]
 
         # Ollama lacks a "system" prompt option. It has to be provided as a direct
         # parameter, so we copy it down to the payload root for compatibility.
         if "system" in ollama_options:
             ollama_payload["system"] = ollama_options["system"]
-            del ollama_options[
-                "system"
-            ]  # To prevent Ollama warning of invalid option provided
+            del ollama_options["system"]
 
-        # Extract keep_alive from options if it exists
-        if "keep_alive" in ollama_options:
-            ollama_payload["keep_alive"] = ollama_options["keep_alive"]
-            del ollama_options["keep_alive"]
+        ollama_payload["options"] = ollama_options
 
     # If there is a "stop" parameter in the openai_payload, remap it to ollama_payload["options"]
     if "stop" in openai_payload:
 
@@ -329,3 +357,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
         ollama_payload["format"] = format
 
     return ollama_payload
+
+
+def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
+    """
+    Convert an embeddings request payload from OpenAI format to Ollama format.
+
+    Args:
+        openai_payload (dict): The original payload designed for OpenAI API usage.
+
+    Returns:
+        dict: A payload compatible with the Ollama API embeddings endpoint.
+    """
+    ollama_payload = {"model": openai_payload.get("model")}
+    input_value = openai_payload.get("input")
+
+    # Ollama expects 'input' as a list, and 'prompt' as a single string.
+    if isinstance(input_value, list):
+        ollama_payload["input"] = input_value
+        ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
+    else:
+        ollama_payload["input"] = [input_value]
+        ollama_payload["prompt"] = str(input_value)
+
+    # Optionally forward other fields if present
+    for optional_key in ("options", "truncate", "keep_alive"):
+        if optional_key in openai_payload:
+            ollama_payload[optional_key] = openai_payload[optional_key]
+
+    return ollama_payload
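For reference, a quick sketch of the two input branches of convert_embedding_payload_openai_to_ollama (the model name and inputs are illustrative placeholders):

# List input: both 'input' and a newline-joined 'prompt' are produced.
convert_embedding_payload_openai_to_ollama(
    {"model": "nomic-embed-text", "input": ["hello", "world"], "keep_alive": "5m"}
)
# -> {"model": "nomic-embed-text", "input": ["hello", "world"],
#     "prompt": "hello\nworld", "keep_alive": "5m"}

# Scalar input: wrapped into a single-element list.
convert_embedding_payload_openai_to_ollama({"model": "nomic-embed-text", "input": "hello"})
# -> {"model": "nomic-embed-text", "input": ["hello"], "prompt": "hello"}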
 
 
@@ -1,7 +1,6 @@
 import socketio
 import redis
-from redis import asyncio as aioredis
 from urllib.parse import urlparse
 from typing import Optional
 
 
@@ -18,23 +17,46 @@ def parse_redis_service_url(redis_url):
     }
 
 
-def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
-    if redis_sentinels:
-        redis_config = parse_redis_service_url(redis_url)
-        sentinel = redis.sentinel.Sentinel(
-            redis_sentinels,
-            port=redis_config["port"],
-            db=redis_config["db"],
-            username=redis_config["username"],
-            password=redis_config["password"],
-            decode_responses=decode_responses,
-        )
+def get_redis_connection(
+    redis_url, redis_sentinels, async_mode=False, decode_responses=True
+):
+    if async_mode:
+        import redis.asyncio as redis
 
-        # Get a master connection from Sentinel
-        return sentinel.master_for(redis_config["service"])
+        # If using sentinel in async mode
+        if redis_sentinels:
+            redis_config = parse_redis_service_url(redis_url)
+            sentinel = redis.sentinel.Sentinel(
+                redis_sentinels,
+                port=redis_config["port"],
+                db=redis_config["db"],
+                username=redis_config["username"],
+                password=redis_config["password"],
+                decode_responses=decode_responses,
+            )
+            return sentinel.master_for(redis_config["service"])
+        elif redis_url:
+            return redis.from_url(redis_url, decode_responses=decode_responses)
+        else:
+            return None
     else:
-        # Standard Redis connection
-        return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+        import redis
+
+        if redis_sentinels:
+            redis_config = parse_redis_service_url(redis_url)
+            sentinel = redis.sentinel.Sentinel(
+                redis_sentinels,
+                port=redis_config["port"],
+                db=redis_config["db"],
+                username=redis_config["username"],
+                password=redis_config["password"],
+                decode_responses=decode_responses,
+            )
+            return sentinel.master_for(redis_config["service"])
+        elif redis_url:
+            return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
+        else:
+            return None
 
 
 def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
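With the new async_mode flag, the caller picks the client family at call time instead of the module fixing it at import time. A usage sketch (the URL is a placeholder; the async client's commands are coroutines):

# Synchronous client, e.g. for blocking startup work.
r = get_redis_connection("redis://localhost:6379/0", redis_sentinels=None)
r.set("key", "value")

# Asynchronous client: same call site, but every command must be awaited.
ar = get_redis_connection(
    "redis://localhost:6379/0", redis_sentinels=None, async_mode=True
)
# await ar.ping()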
 
 
@@ -83,6 +83,7 @@ def convert_ollama_usage_to_openai(data: dict) -> dict:
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
     message_content = ollama_response.get("message", {}).get("content", "")
+    reasoning_content = ollama_response.get("message", {}).get("thinking", None)
     tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
     openai_tool_calls = None
 
@@ -94,7 +95,7 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     usage = convert_ollama_usage_to_openai(data)
 
     response = openai_chat_completion_message_template(
-        model, message_content, openai_tool_calls, usage
+        model, message_content, reasoning_content, openai_tool_calls, usage
    )
     return response
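The extra argument threads Ollama's "thinking" field through to the OpenAI-shaped response. A sketch of the mapping, assuming the template exposes it on the message (the exact OpenAI-side field name is determined by openai_chat_completion_message_template, not shown here):

# Hypothetical Ollama response carrying a "thinking" trace.
ollama_response = {
    "model": "qwen3",
    "message": {"content": "4", "thinking": "2 + 2 equals 4."},
    "done": True,
}
# convert_response_ollama_to_openai(...) now forwards message["thinking"] as the
# template's reasoning argument, so the converted message is expected to expose
# the trace alongside "content" rather than dropping it.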
 
 
@@ -105,6 +106,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
 
         model = data.get("model", "ollama")
         message_content = data.get("message", {}).get("content", None)
+        reasoning_content = data.get("message", {}).get("thinking", None)
         tool_calls = data.get("message", {}).get("tool_calls", None)
         openai_tool_calls = None
 
@@ -118,10 +120,71 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
         usage = convert_ollama_usage_to_openai(data)
 
         data = openai_chat_chunk_message_template(
-            model, message_content, openai_tool_calls, usage
+            model, message_content, reasoning_content, openai_tool_calls, usage
         )
 
         line = f"data: {json.dumps(data)}\n\n"
         yield line
 
     yield "data: [DONE]\n\n"
+
+
+def convert_embedding_response_ollama_to_openai(response) -> dict:
+    """
+    Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
+
+    Args:
+        response (dict): The response from the Ollama API,
+            e.g. {"embedding": [...], "model": "..."}
+            or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
+
+    Returns:
+        dict: Response adapted to OpenAI's embeddings API format.
+            e.g. {
+                "object": "list",
+                "data": [
+                    {"object": "embedding", "embedding": [...], "index": 0},
+                    ...
+                ],
+                "model": "...",
+            }
+    """
+    # Ollama batch-style output
+    if isinstance(response, dict) and "embeddings" in response:
+        openai_data = []
+        for i, emb in enumerate(response["embeddings"]):
+            openai_data.append(
+                {
+                    "object": "embedding",
+                    "embedding": emb.get("embedding"),
+                    "index": emb.get("index", i),
+                }
+            )
+        return {
+            "object": "list",
+            "data": openai_data,
+            "model": response.get("model"),
+        }
+    # Ollama single output
+    elif isinstance(response, dict) and "embedding" in response:
+        return {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "embedding": response["embedding"],
+                    "index": 0,
+                }
+            ],
+            "model": response.get("model"),
+        }
+    # Already OpenAI-compatible?
+    elif (
+        isinstance(response, dict)
+        and "data" in response
+        and isinstance(response["data"], list)
+    ):
+        return response
+
+    # Fallback: return as is if unrecognized
+    return response
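A quick sanity check of the three branches, with vectors shortened for readability:

# Batch-style Ollama output.
convert_embedding_response_ollama_to_openai(
    {"embeddings": [{"embedding": [0.1, 0.2], "index": 0}], "model": "m"}
)
# -> {"object": "list",
#     "data": [{"object": "embedding", "embedding": [0.1, 0.2], "index": 0}],
#     "model": "m"}

# Single-vector output.
convert_embedding_response_ollama_to_openai({"embedding": [0.1, 0.2], "model": "m"})
# -> same shape, with a single entry at index 0

# Input that already carries an OpenAI-style "data" list is returned unchanged.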
 
 
@@ -207,6 +207,24 @@ def title_generation_template(
     return template
 
 
+def follow_up_generation_template(
+    template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+    prompt = get_last_user_message(messages)
+    template = replace_prompt_variable(template, prompt)
+    template = replace_messages_variable(template, messages)
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "user_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
+
+
 def tags_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:
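The new helper mirrors the existing title/tags template functions. A minimal sketch of how it might be called; the {{prompt}} placeholder syntax is an assumption about what replace_prompt_variable substitutes (replace_messages_variable would fill a corresponding messages placeholder if the template used one):

# Hypothetical follow-up template and conversation.
template = "Suggest 3 follow-up questions for: {{prompt}}"
messages = [{"role": "user", "content": "How do transformers work?"}]

follow_up_generation_template(template, messages)
# -> "Suggest 3 follow-up questions for: How do transformers work?"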


@@ -1,5 +1,5 @@
 fastapi==0.115.7
-uvicorn[standard]==0.34.0
+uvicorn[standard]==0.34.2
 pydantic==2.10.6
 python-multipart==0.0.20
 
@@ -7,14 +7,13 @@ python-socketio==5.13.0
 python-jose==3.4.0
 passlib[bcrypt]==1.7.4
 
-requests==2.32.3
+requests==2.32.4
 aiohttp==3.11.11
 async-timeout
 aiocache
 aiofiles
 starlette-compress==1.6.0
-
 
 sqlalchemy==2.0.38
 alembic==1.14.0
 peewee==3.18.1
 
@@ -76,13 +75,13 @@ pandas==2.2.3
 openpyxl==3.1.5
 pyxlsb==1.0.10
 xlrd==2.0.1
-validators==0.34.0
+validators==0.35.0
 psutil
 sentencepiece
 soundfile==0.13.1
-azure-ai-documentintelligence==1.0.0
+azure-ai-documentintelligence==1.0.2
 
-pillow==11.1.0
+pillow==11.2.1
 opencv-python-headless==4.11.0.86
 rapidocr-onnxruntime==1.4.4
 rank-bm25==0.2.2


@@ -14,7 +14,11 @@ if [[ "${WEB_LOADER_ENGINE,,}" == "playwright" ]]; then
     python -c "import nltk; nltk.download('punkt_tab')"
 fi
 
-KEY_FILE=.webui_secret_key
+if [ -n "${WEBUI_SECRET_KEY_FILE}" ]; then
+  KEY_FILE="${WEBUI_SECRET_KEY_FILE}"
+else
+  KEY_FILE=".webui_secret_key"
+fi
 
 PORT="${PORT:-8080}"
 HOST="${HOST:-0.0.0.0}"


@@ -18,6 +18,10 @@ IF /I "%WEB_LOADER_ENGINE%" == "playwright" (
 )
 
 SET "KEY_FILE=.webui_secret_key"
+IF NOT "%WEBUI_SECRET_KEY_FILE%" == "" (
+    SET "KEY_FILE=%WEBUI_SECRET_KEY_FILE%"
+)
+
 IF "%PORT%"=="" SET PORT=8080
 IF "%HOST%"=="" SET HOST=0.0.0.0
 SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"