mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' of https://github.com/open-webui/open-webui into Dev-Individual-RAG-Config
This commit is contained in:
@@ -901,9 +901,7 @@ TOOL_SERVER_CONNECTIONS = PersistentConfig(
|
||||
####################################
|
||||
|
||||
|
||||
WEBUI_URL = PersistentConfig(
|
||||
"WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
|
||||
)
|
||||
WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", ""))
|
||||
|
||||
|
||||
ENABLE_SIGNUP = PersistentConfig(
|
||||
@@ -1413,6 +1411,35 @@ Strictly return in JSON format:
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
|
||||
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
|
||||
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
|
||||
"task.follow_up.prompt_template",
|
||||
os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
|
||||
)
|
||||
|
||||
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
|
||||
Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion.
|
||||
### Guidelines:
|
||||
- Write all follow-up questions from the user’s point of view, directed to the assistant.
|
||||
- Make questions concise, clear, and directly related to the discussed topic(s).
|
||||
- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered.
|
||||
- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask.
|
||||
- Use the conversation's primary language; default to English if multilingual.
|
||||
- Response must be a JSON array of strings, no extra text or formatting.
|
||||
### Output:
|
||||
JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
|
||||
### Chat History:
|
||||
<chat_history>
|
||||
{{MESSAGES:END:6}}
|
||||
</chat_history>"""
|
||||
|
||||
ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
|
||||
"ENABLE_FOLLOW_UP_GENERATION",
|
||||
"task.follow_up.enable",
|
||||
os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
|
||||
)
|
||||
|
||||
ENABLE_TAGS_GENERATION = PersistentConfig(
|
||||
"ENABLE_TAGS_GENERATION",
|
||||
"task.tags.enable",
|
||||
@@ -1848,6 +1875,61 @@ CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_API_KEY = PersistentConfig(
|
||||
"DATALAB_MARKER_API_KEY",
|
||||
"rag.datalab_marker_api_key",
|
||||
os.environ.get("DATALAB_MARKER_API_KEY", ""),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_LANGS = PersistentConfig(
|
||||
"DATALAB_MARKER_LANGS",
|
||||
"rag.datalab_marker_langs",
|
||||
os.environ.get("DATALAB_MARKER_LANGS", ""),
|
||||
)
|
||||
|
||||
DATALAB_MARKER_USE_LLM = PersistentConfig(
|
||||
"DATALAB_MARKER_USE_LLM",
|
||||
"rag.DATALAB_MARKER_USE_LLM",
|
||||
os.environ.get("DATALAB_MARKER_USE_LLM", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_SKIP_CACHE = PersistentConfig(
|
||||
"DATALAB_MARKER_SKIP_CACHE",
|
||||
"rag.datalab_marker_skip_cache",
|
||||
os.environ.get("DATALAB_MARKER_SKIP_CACHE", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_FORCE_OCR = PersistentConfig(
|
||||
"DATALAB_MARKER_FORCE_OCR",
|
||||
"rag.datalab_marker_force_ocr",
|
||||
os.environ.get("DATALAB_MARKER_FORCE_OCR", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_PAGINATE = PersistentConfig(
|
||||
"DATALAB_MARKER_PAGINATE",
|
||||
"rag.datalab_marker_paginate",
|
||||
os.environ.get("DATALAB_MARKER_PAGINATE", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR = PersistentConfig(
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR",
|
||||
"rag.datalab_marker_strip_existing_ocr",
|
||||
os.environ.get("DATALAB_MARKER_STRIP_EXISTING_OCR", "false").lower() == "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION",
|
||||
"rag.datalab_marker_disable_image_extraction",
|
||||
os.environ.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", "false").lower()
|
||||
== "true",
|
||||
)
|
||||
|
||||
DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT",
|
||||
"rag.datalab_marker_output_format",
|
||||
os.environ.get("DATALAB_MARKER_OUTPUT_FORMAT", "markdown"),
|
||||
)
|
||||
|
||||
EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig(
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL",
|
||||
"rag.external_document_loader_url",
|
||||
@@ -1928,6 +2010,11 @@ RAG_RELEVANCE_THRESHOLD = PersistentConfig(
|
||||
"rag.relevance_threshold",
|
||||
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
|
||||
)
|
||||
RAG_HYBRID_BM25_WEIGHT = PersistentConfig(
|
||||
"RAG_HYBRID_BM25_WEIGHT",
|
||||
"rag.hybrid_bm25_weight",
|
||||
float(os.environ.get("RAG_HYBRID_BM25_WEIGHT", "0.5")),
|
||||
)
|
||||
|
||||
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
|
||||
"ENABLE_RAG_HYBRID_SEARCH",
|
||||
@@ -2124,6 +2211,22 @@ RAG_OPENAI_API_KEY = PersistentConfig(
|
||||
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
|
||||
)
|
||||
|
||||
RAG_AZURE_OPENAI_BASE_URL = PersistentConfig(
|
||||
"RAG_AZURE_OPENAI_BASE_URL",
|
||||
"rag.azure_openai.base_url",
|
||||
os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""),
|
||||
)
|
||||
RAG_AZURE_OPENAI_API_KEY = PersistentConfig(
|
||||
"RAG_AZURE_OPENAI_API_KEY",
|
||||
"rag.azure_openai.api_key",
|
||||
os.getenv("RAG_AZURE_OPENAI_API_KEY", ""),
|
||||
)
|
||||
RAG_AZURE_OPENAI_API_VERSION = PersistentConfig(
|
||||
"RAG_AZURE_OPENAI_API_VERSION",
|
||||
"rag.azure_openai.api_version",
|
||||
os.getenv("RAG_AZURE_OPENAI_API_VERSION", ""),
|
||||
)
|
||||
|
||||
RAG_OLLAMA_BASE_URL = PersistentConfig(
|
||||
"RAG_OLLAMA_BASE_URL",
|
||||
"rag.ollama.url",
|
||||
@@ -2213,6 +2316,12 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig(
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER",
|
||||
"rag.web.search.bypass_web_loader",
|
||||
os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true",
|
||||
)
|
||||
|
||||
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||
"WEB_SEARCH_RESULT_COUNT",
|
||||
"rag.web.search.result_count",
|
||||
@@ -2238,6 +2347,7 @@ WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
|
||||
WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
@@ -2397,6 +2507,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
|
||||
os.getenv("PERPLEXITY_API_KEY", ""),
|
||||
)
|
||||
|
||||
PERPLEXITY_MODEL = PersistentConfig(
|
||||
"PERPLEXITY_MODEL",
|
||||
"rag.web.search.perplexity_model",
|
||||
os.getenv("PERPLEXITY_MODEL", "sonar"),
|
||||
)
|
||||
|
||||
PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig(
|
||||
"PERPLEXITY_SEARCH_CONTEXT_USAGE",
|
||||
"rag.web.search.perplexity_search_context_usage",
|
||||
os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"),
|
||||
)
|
||||
|
||||
SOUGOU_API_SID = PersistentConfig(
|
||||
"SOUGOU_API_SID",
|
||||
"rag.web.search.sougou_api_sid",
|
||||
|
||||
@@ -111,6 +111,7 @@ class TASKS(str, Enum):
|
||||
|
||||
DEFAULT = lambda task="": f"{task if task else 'generation'}"
|
||||
TITLE_GENERATION = "title_generation"
|
||||
FOLLOW_UP_GENERATION = "follow_up_generation"
|
||||
TAGS_GENERATION = "tags_generation"
|
||||
EMOJI_GENERATION = "emoji_generation"
|
||||
QUERY_GENERATION = "query_generation"
|
||||
|
||||
@@ -349,6 +349,10 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||
)
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None
|
||||
)
|
||||
|
||||
|
||||
BYPASS_MODEL_ACCESS_CONTROL = (
|
||||
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
||||
|
||||
@@ -25,10 +25,14 @@ from open_webui.socket.main import (
|
||||
)
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
@@ -53,9 +57,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
# Check if function is already loaded
|
||||
function_module, _, _ = load_function_module_by_id(pipe_id)
|
||||
request.app.state.FUNCTIONS[pipe_id] = function_module
|
||||
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
@@ -226,12 +228,7 @@ async def generate_function_chat_completion(
|
||||
"__task__": __task__,
|
||||
"__task_body__": __task_body__,
|
||||
"__files__": files,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
}
|
||||
@@ -252,8 +249,13 @@ async def generate_function_chat_completion(
|
||||
form_data["model"] = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user)
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_model_system_prompt_to_body(
|
||||
system, form_data, metadata, user
|
||||
)
|
||||
|
||||
pipe_id = get_pipe_id(form_data)
|
||||
function_module = get_function_module_by_id(request, pipe_id)
|
||||
|
||||
@@ -37,9 +37,11 @@ from fastapi import (
|
||||
from fastapi.openapi.docs import get_swagger_ui_html
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import FileResponse, JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from starlette_compress import CompressMiddleware
|
||||
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
@@ -196,17 +198,32 @@ from open_webui.config import (
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_BATCH_SIZE,
|
||||
RAG_TOP_K,
|
||||
RAG_TOP_K_RERANKER,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_HYBRID_BM25_WEIGHT,
|
||||
RAG_ALLOWED_FILE_EXTENSIONS,
|
||||
RAG_FILE_MAX_COUNT,
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
RAG_OPENAI_API_KEY,
|
||||
RAG_AZURE_OPENAI_BASE_URL,
|
||||
RAG_AZURE_OPENAI_API_KEY,
|
||||
RAG_AZURE_OPENAI_API_VERSION,
|
||||
RAG_OLLAMA_BASE_URL,
|
||||
RAG_OLLAMA_API_KEY,
|
||||
CHUNK_OVERLAP,
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
DATALAB_MARKER_API_KEY,
|
||||
DATALAB_MARKER_LANGS,
|
||||
DATALAB_MARKER_SKIP_CACHE,
|
||||
DATALAB_MARKER_FORCE_OCR,
|
||||
DATALAB_MARKER_PAGINATE,
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR,
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
|
||||
DATALAB_MARKER_OUTPUT_FORMAT,
|
||||
DATALAB_MARKER_USE_LLM,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
TIKA_SERVER_URL,
|
||||
@@ -217,8 +234,6 @@ from open_webui.config import (
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
MISTRAL_OCR_API_KEY,
|
||||
RAG_TOP_K,
|
||||
RAG_TOP_K_RERANKER,
|
||||
RAG_TEXT_SPLITTER,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
PDF_EXTRACT_IMAGES,
|
||||
@@ -233,6 +248,7 @@ from open_webui.config import (
|
||||
ENABLE_WEB_SEARCH,
|
||||
WEB_SEARCH_ENGINE,
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
WEB_SEARCH_RESULT_COUNT,
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
WEB_SEARCH_TRUST_ENV,
|
||||
@@ -257,6 +273,8 @@ from open_webui.config import (
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
EXA_API_KEY,
|
||||
PERPLEXITY_API_KEY,
|
||||
PERPLEXITY_MODEL,
|
||||
PERPLEXITY_SEARCH_CONTEXT_USAGE,
|
||||
SOUGOU_API_SID,
|
||||
SOUGOU_API_SK,
|
||||
KAGI_SEARCH_API_KEY,
|
||||
@@ -348,10 +366,12 @@ from open_webui.config import (
|
||||
TASK_MODEL_EXTERNAL,
|
||||
ENABLE_TAGS_GENERATION,
|
||||
ENABLE_TITLE_GENERATION,
|
||||
ENABLE_FOLLOW_UP_GENERATION,
|
||||
ENABLE_SEARCH_QUERY_GENERATION,
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
@@ -400,6 +420,7 @@ from open_webui.utils.chat import (
|
||||
chat_completed as chat_completed_handler,
|
||||
chat_action as chat_action_handler,
|
||||
)
|
||||
from open_webui.utils.embeddings import generate_embeddings
|
||||
from open_webui.utils.middleware import process_chat_payload, process_chat_response
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
@@ -638,8 +659,12 @@ app.state.WEBUI_AUTH_SIGNOUT_REDIRECT_URL = WEBUI_AUTH_SIGNOUT_REDIRECT_URL
|
||||
app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL
|
||||
|
||||
app.state.USER_COUNT = None
|
||||
|
||||
app.state.TOOLS = {}
|
||||
app.state.TOOL_CONTENTS = {}
|
||||
|
||||
app.state.FUNCTIONS = {}
|
||||
app.state.FUNCTION_CONTENTS = {}
|
||||
|
||||
########################################
|
||||
#
|
||||
@@ -651,6 +676,7 @@ app.state.FUNCTIONS = {}
|
||||
app.state.config.TOP_K = RAG_TOP_K
|
||||
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
|
||||
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
app.state.config.HYBRID_BM25_WEIGHT = RAG_HYBRID_BM25_WEIGHT
|
||||
app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
|
||||
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
||||
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
@@ -662,6 +688,17 @@ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY
|
||||
app.state.config.DATALAB_MARKER_LANGS = DATALAB_MARKER_LANGS
|
||||
app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE
|
||||
app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR
|
||||
app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE
|
||||
app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR
|
||||
app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
)
|
||||
app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM
|
||||
app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT
|
||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
|
||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
@@ -693,6 +730,10 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL
|
||||
app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY
|
||||
app.state.config.RAG_AZURE_OPENAI_API_VERSION = RAG_AZURE_OPENAI_API_VERSION
|
||||
|
||||
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
||||
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
||||
|
||||
@@ -712,6 +753,7 @@ app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||
@@ -739,6 +781,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
|
||||
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
|
||||
app.state.config.EXA_API_KEY = EXA_API_KEY
|
||||
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
|
||||
app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL
|
||||
app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE
|
||||
app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
|
||||
app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
|
||||
app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL
|
||||
@@ -789,9 +833,13 @@ try:
|
||||
else app.state.config.RAG_OLLAMA_API_KEY
|
||||
),
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
azure_api_version=(
|
||||
app.state.config.RAG_AZURE_OPENAI_API_VERSION
|
||||
if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Load all reranking models that are currently in use
|
||||
# Load all reranking models that are currently in use
|
||||
for engine, model_list in app.state.config.LOADED_RERANKING_MODELS.items():
|
||||
for model in model_list:
|
||||
app.state.rf[model["RAG_RERANKING_MODEL"]] = get_rf(
|
||||
@@ -800,11 +848,10 @@ try:
|
||||
model["RAG_EXTERNAL_RERANKER_URL"],
|
||||
model["RAG_EXTERNAL_RERANKER_API_KEY"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error updating models: {e}")
|
||||
pass
|
||||
|
||||
|
||||
########################################
|
||||
#
|
||||
# CODE EXECUTION
|
||||
@@ -923,6 +970,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
|
||||
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
|
||||
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
||||
app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
|
||||
app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
|
||||
|
||||
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
@@ -930,6 +978,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
|
||||
app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
|
||||
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
@@ -973,6 +1024,7 @@ class RedirectMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
|
||||
# Add the middleware to the app
|
||||
app.add_middleware(CompressMiddleware)
|
||||
app.add_middleware(RedirectMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
@@ -1160,6 +1212,37 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
##################################
|
||||
# Embeddings
|
||||
##################################
|
||||
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
async def embeddings(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
"""
|
||||
OpenAI-compatible embeddings endpoint.
|
||||
|
||||
This handler:
|
||||
- Performs user/model checks and dispatches to the correct backend.
|
||||
- Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider.
|
||||
|
||||
Args:
|
||||
request (Request): Request context.
|
||||
form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]})
|
||||
user (UserModel): Authenticated user.
|
||||
|
||||
Returns:
|
||||
dict: OpenAI-compatible embeddings response.
|
||||
"""
|
||||
# Make sure models are loaded in app state
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
# Use generic dispatcher in utils.embeddings
|
||||
return await generate_embeddings(request, form_data, user)
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def chat_completion(
|
||||
request: Request,
|
||||
@@ -1591,7 +1674,20 @@ async def healthcheck_with_db():
|
||||
|
||||
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
||||
|
||||
|
||||
@app.get("/cache/{path:path}")
|
||||
async def serve_cache_file(
|
||||
path: str,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
|
||||
# prevent path traversal
|
||||
if not file_path.startswith(os.path.abspath(CACHE_DIR)):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
if not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def swagger_ui_html(*args, **kwargs):
|
||||
|
||||
@@ -129,12 +129,16 @@ class AuthsTable:
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
|
||||
if auth:
|
||||
if verify_password(password, auth.password):
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
@@ -155,8 +159,8 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_trusted_header: {email}")
|
||||
def authenticate_user_by_email(self, email: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_email: {email}")
|
||||
try:
|
||||
with get_db() as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
|
||||
@@ -377,22 +377,47 @@ class ChatTable:
|
||||
return False
|
||||
|
||||
def get_archived_chat_list_by_user_id(
|
||||
self, user_id: str, skip: int = 0, limit: int = 50
|
||||
self,
|
||||
user_id: str,
|
||||
filter: Optional[dict] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[ChatModel]:
|
||||
|
||||
with get_db() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id, archived=True)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
# .limit(limit).offset(skip)
|
||||
.all()
|
||||
)
|
||||
query = db.query(Chat).filter_by(user_id=user_id, archived=True)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
if query_key:
|
||||
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
if order_by and direction and getattr(Chat, order_by):
|
||||
if direction.lower() == "asc":
|
||||
query = query.order_by(getattr(Chat, order_by).asc())
|
||||
elif direction.lower() == "desc":
|
||||
query = query.order_by(getattr(Chat, order_by).desc())
|
||||
else:
|
||||
raise ValueError("Invalid direction for ordering")
|
||||
else:
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
all_chats = query.all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_chat_list_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
include_archived: bool = False,
|
||||
filter: Optional[dict] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
) -> list[ChatModel]:
|
||||
@@ -401,7 +426,23 @@ class ChatTable:
|
||||
if not include_archived:
|
||||
query = query.filter_by(archived=False)
|
||||
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
if query_key:
|
||||
query = query.filter(Chat.title.ilike(f"%{query_key}%"))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
if order_by and direction and getattr(Chat, order_by):
|
||||
if direction.lower() == "asc":
|
||||
query = query.order_by(getattr(Chat, order_by).asc())
|
||||
elif direction.lower() == "desc":
|
||||
query = query.order_by(getattr(Chat, order_by).desc())
|
||||
else:
|
||||
raise ValueError("Invalid direction for ordering")
|
||||
else:
|
||||
query = query.order_by(Chat.updated_at.desc())
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
@@ -542,7 +583,9 @@ class ChatTable:
|
||||
search_text = search_text.lower().strip()
|
||||
|
||||
if not search_text:
|
||||
return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit)
|
||||
return self.get_chat_list_by_user_id(
|
||||
user_id, include_archived, filter={}, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
search_text_words = search_text.split(" ")
|
||||
|
||||
|
||||
@@ -108,6 +108,54 @@ class FunctionsTable:
|
||||
log.exception(f"Error creating a new function: {e}")
|
||||
return None
|
||||
|
||||
def sync_functions(
|
||||
self, user_id: str, functions: list[FunctionModel]
|
||||
) -> list[FunctionModel]:
|
||||
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
|
||||
try:
|
||||
with get_db() as db:
|
||||
# Get existing functions
|
||||
existing_functions = db.query(Function).all()
|
||||
existing_ids = {func.id for func in existing_functions}
|
||||
|
||||
# Prepare a set of new function IDs
|
||||
new_function_ids = {func.id for func in functions}
|
||||
|
||||
# Update or insert functions
|
||||
for func in functions:
|
||||
if func.id in existing_ids:
|
||||
db.query(Function).filter_by(id=func.id).update(
|
||||
{
|
||||
**func.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_func = Function(
|
||||
**{
|
||||
**func.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_func)
|
||||
|
||||
# Remove functions that are no longer present
|
||||
for func in existing_functions:
|
||||
if func.id not in new_function_ids:
|
||||
db.delete(func)
|
||||
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
FunctionModel.model_validate(func)
|
||||
for func in db.query(Function).all()
|
||||
]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing functions for user {user_id}: {e}")
|
||||
return []
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -207,5 +207,43 @@ class GroupTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_user_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str]
|
||||
) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
group_ids = [group.id for group in groups]
|
||||
|
||||
# Remove user from groups not in the new list
|
||||
existing_groups = self.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in existing_groups:
|
||||
if group.id not in group_ids:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
# Add user to new groups
|
||||
for group in groups:
|
||||
if user_id not in group.user_ids:
|
||||
group.user_ids.append(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
Groups = GroupTable()
|
||||
|
||||
@@ -95,6 +95,7 @@ class UserRoleUpdateForm(BaseModel):
|
||||
|
||||
|
||||
class UserUpdateForm(BaseModel):
|
||||
role: str
|
||||
name: str
|
||||
email: str
|
||||
profile_image_url: str
|
||||
|
||||
251
backend/open_webui/retrieval/loaders/datalab_marker.py
Normal file
251
backend/open_webui/retrieval/loaders/datalab_marker.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatalabMarkerLoader:
    """Document loader backed by the Datalab Marker conversion API.

    Workflow (all in :meth:`load`):
      1. POST the file to ``https://www.datalab.to/api/v1/marker``.
      2. Poll the returned ``request_check_url`` until the conversion
         completes, fails, or times out (~10 minutes).
      3. Serialize the converted payload, best-effort persist it to disk,
         and return it as a single langchain ``Document``.

    All failures are surfaced as FastAPI ``HTTPException``s so the caller
    can relay them directly to the client.
    """

    # Polling cadence: POLL_INTERVAL seconds between checks, at most
    # MAX_POLLS attempts (2s * 300 = up to 10 minutes).
    POLL_INTERVAL = 2
    MAX_POLLS = 300

    def __init__(
        self,
        file_path: str,
        api_key: str,
        langs: Optional[str] = None,
        use_llm: bool = False,
        skip_cache: bool = False,
        force_ocr: bool = False,
        paginate: bool = False,
        strip_existing_ocr: bool = False,
        disable_image_extraction: bool = False,
        output_format: str = "markdown",
    ):
        """Store the request configuration.

        Args:
            file_path: Local path of the file to convert.
            api_key: Datalab API key (sent as ``X-Api-Key``).
            langs: Optional comma-separated language hints for OCR.
            use_llm: Ask Marker to post-process with an LLM.
            skip_cache: Bypass Marker's server-side result cache.
            force_ocr: Force OCR even when a text layer exists.
            paginate: Request paginated output.
            strip_existing_ocr: Discard any pre-existing OCR text layer.
            disable_image_extraction: Skip extracting embedded images.
            output_format: One of ``markdown`` / ``json`` / ``html``.
                Defaults to ``markdown``; a falsy value also falls back to
                ``markdown`` so :meth:`load` never dereferences ``None``.
        """
        self.file_path = file_path
        self.api_key = api_key
        self.langs = langs
        self.use_llm = use_llm
        self.skip_cache = skip_cache
        self.force_ocr = force_ocr
        self.paginate = paginate
        self.strip_existing_ocr = strip_existing_ocr
        self.disable_image_extraction = disable_image_extraction
        # Previously a ``None`` default crashed later in ``load()`` at
        # ``self.output_format.lower()``; normalize here instead.
        self.output_format = output_format or "markdown"

    def _get_mime_type(self, filename: str) -> str:
        """Return the MIME type for *filename* based on its extension.

        Unknown extensions (and names without a dot) fall back to
        ``application/octet-stream``.
        """
        ext = filename.rsplit(".", 1)[-1].lower()
        mime_map = {
            "pdf": "application/pdf",
            "xls": "application/vnd.ms-excel",
            "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            "ods": "application/vnd.oasis.opendocument.spreadsheet",
            "doc": "application/msword",
            "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "odt": "application/vnd.oasis.opendocument.text",
            "ppt": "application/vnd.ms-powerpoint",
            "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
            "odp": "application/vnd.oasis.opendocument.presentation",
            "html": "text/html",
            "epub": "application/epub+zip",
            "png": "image/png",
            "jpeg": "image/jpeg",
            "jpg": "image/jpeg",
            "webp": "image/webp",
            "gif": "image/gif",
            "tiff": "image/tiff",
        }
        return mime_map.get(ext, "application/octet-stream")

    def check_marker_request_status(self, request_id: str) -> dict:
        """Fetch the current status of a Marker request by its ID.

        Raises:
            HTTPException: 502 when the status endpoint errors or returns
                invalid JSON.
        """
        url = f"https://www.datalab.to/api/v1/marker/{request_id}"
        headers = {"X-Api-Key": self.api_key}
        try:
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            result = response.json()
            log.info(f"Marker API status check for request {request_id}: {result}")
            return result
        except requests.HTTPError as e:
            log.error(f"Error checking Marker request status: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Failed to check Marker request: {e}",
            )
        except ValueError as e:
            log.error(f"Invalid JSON checking Marker request: {e}")
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
            )

    def _submit_request(self, filename: str, headers: dict) -> dict:
        """POST the file to the Marker endpoint and return the parsed JSON reply."""
        url = "https://www.datalab.to/api/v1/marker"
        mime_type = self._get_mime_type(filename)

        # Booleans are sent as lowercase strings, matching the API's
        # form-encoded expectations.
        form_data = {
            "langs": self.langs,
            "use_llm": str(self.use_llm).lower(),
            "skip_cache": str(self.skip_cache).lower(),
            "force_ocr": str(self.force_ocr).lower(),
            "paginate": str(self.paginate).lower(),
            "strip_existing_ocr": str(self.strip_existing_ocr).lower(),
            "disable_image_extraction": str(self.disable_image_extraction).lower(),
            "output_format": self.output_format,
        }

        log.info(
            f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
        )

        try:
            with open(self.file_path, "rb") as f:
                files = {"file": (filename, f, mime_type)}
                response = requests.post(
                    url, data=form_data, files=files, headers=headers
                )
                response.raise_for_status()
                return response.json()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {e}",
            )
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
            )
        except Exception as e:
            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

    def _poll_for_result(self, check_url: str, headers: dict) -> dict:
        """Poll *check_url* until Marker reports completion; return the payload.

        Raises:
            HTTPException: 502 on polling errors, 400 when Marker reports a
                failed conversion, 504 when MAX_POLLS is exhausted.
        """
        for _ in range(self.MAX_POLLS):  # up to ~10 minutes
            time.sleep(self.POLL_INTERVAL)
            try:
                poll_response = requests.get(check_url, headers=headers)
                poll_response.raise_for_status()
                poll_result = poll_response.json()
            except (requests.HTTPError, ValueError) as e:
                # Both exceptions can only occur after poll_response is bound.
                raw_body = poll_response.text
                log.error(f"Polling error: {e}, response body: {raw_body}")
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
                )

            status_val = poll_result.get("status")
            success_val = poll_result.get("success")

            if status_val == "complete":
                summary = {
                    k: poll_result.get(k)
                    for k in (
                        "status",
                        "output_format",
                        "success",
                        "error",
                        "page_count",
                        "total_cost",
                    )
                }
                log.info(
                    f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
                )
                return poll_result

            if status_val == "failed" or success_val is False:
                log.error(
                    f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
                )
                error_msg = (
                    poll_result.get("error")
                    or "Marker returned failure without error message"
                )
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"Marker processing failed: {error_msg}",
                )

        raise HTTPException(
            status.HTTP_504_GATEWAY_TIMEOUT, detail="Marker processing timed out"
        )

    def _extract_text(self, poll_result: dict) -> str:
        """Serialize the converted payload for the configured output format."""
        content_key = self.output_format.lower()
        raw_content = poll_result.get(content_key)

        if content_key == "json":
            full_text = json.dumps(raw_content, indent=2)
        elif content_key in {"markdown", "html"}:
            full_text = str(raw_content).strip()
        else:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported output format: {self.output_format}",
            )

        if not full_text:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Datalab Marker returned empty content",
            )
        return full_text

    def _save_output(self, filename: str, full_text: str) -> None:
        """Best-effort persistence of the converted output under the uploads dir.

        Failures are logged but never abort the load — the caller still gets
        the in-memory document.
        """
        marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
        os.makedirs(marker_output_dir, exist_ok=True)

        file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
        file_ext = file_ext_map.get(self.output_format.lower(), "txt")
        output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
        output_path = os.path.join(marker_output_dir, output_filename)

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(full_text)
            log.info(f"Saved Marker output to: {output_path}")
        except Exception as e:
            log.warning(f"Failed to write marker output to disk: {e}")

    def _build_metadata(self, filename: str, poll_result: dict, request_id) -> dict:
        """Assemble Document metadata; coerce dict/list values to JSON strings
        and ``None`` to ``""`` so vector stores receive only scalars."""
        metadata = {
            "source": filename,
            "output_format": poll_result.get("output_format", self.output_format),
            "page_count": poll_result.get("page_count", 0),
            "processed_with_llm": self.use_llm,
            "request_id": request_id or "",
        }

        images = poll_result.get("images", {})
        if images:
            metadata["image_count"] = len(images)
            metadata["images"] = json.dumps(list(images.keys()))

        for k, v in metadata.items():
            if isinstance(v, (dict, list)):
                metadata[k] = json.dumps(v)
            elif v is None:
                metadata[k] = ""
        return metadata

    def load(self) -> List[Document]:
        """Run the full Marker conversion workflow.

        Returns:
            A single-element list with the converted ``Document``.

        Raises:
            HTTPException: on upload failure, polling failure, conversion
                failure, unsupported output format, or timeout.
        """
        filename = os.path.basename(self.file_path)
        headers = {"X-Api-Key": self.api_key}

        result = self._submit_request(filename, headers)

        if not result.get("success"):
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
            )

        check_url = result.get("request_check_url")
        request_id = result.get("request_id")
        if not check_url:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY, detail="No request_check_url returned."
            )

        poll_result = self._poll_for_result(check_url, headers)

        # Defensive re-check: "complete" status with success=False still fails.
        if not poll_result.get("success", False):
            error_msg = poll_result.get("error") or "Unknown processing error"
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Final processing failed: {error_msg}",
            )

        full_text = self._extract_text(poll_result)
        self._save_output(filename, full_text)
        metadata = self._build_metadata(filename, poll_result, request_id)

        return [Document(page_content=full_text, metadata=metadata)]
|
||||
@@ -21,9 +21,11 @@ from langchain_community.document_loaders import (
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
|
||||
|
||||
from open_webui.retrieval.loaders.mistral import MistralLoader
|
||||
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
|
||||
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
|
||||
@@ -74,7 +76,6 @@ known_source_ext = [
|
||||
"swift",
|
||||
"vue",
|
||||
"svelte",
|
||||
"msg",
|
||||
"ex",
|
||||
"exs",
|
||||
"erl",
|
||||
@@ -145,15 +146,12 @@ class DoclingLoader:
|
||||
)
|
||||
}
|
||||
|
||||
params = {
|
||||
"image_export_mode": "placeholder",
|
||||
"table_mode": "accurate",
|
||||
}
|
||||
params = {"image_export_mode": "placeholder", "table_mode": "accurate"}
|
||||
|
||||
if self.params:
|
||||
if self.params.get("do_picture_classification"):
|
||||
params["do_picture_classification"] = self.params.get(
|
||||
"do_picture_classification"
|
||||
if self.params.get("do_picture_description"):
|
||||
params["do_picture_description"] = self.params.get(
|
||||
"do_picture_description"
|
||||
)
|
||||
|
||||
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
|
||||
@@ -236,6 +234,49 @@ class Loader:
|
||||
mime_type=file_content_type,
|
||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
||||
)
|
||||
elif (
|
||||
self.engine == "datalab_marker"
|
||||
and self.kwargs.get("DATALAB_MARKER_API_KEY")
|
||||
and file_ext
|
||||
in [
|
||||
"pdf",
|
||||
"xls",
|
||||
"xlsx",
|
||||
"ods",
|
||||
"doc",
|
||||
"docx",
|
||||
"odt",
|
||||
"ppt",
|
||||
"pptx",
|
||||
"odp",
|
||||
"html",
|
||||
"epub",
|
||||
"png",
|
||||
"jpeg",
|
||||
"jpg",
|
||||
"webp",
|
||||
"gif",
|
||||
"tiff",
|
||||
]
|
||||
):
|
||||
loader = DatalabMarkerLoader(
|
||||
file_path=file_path,
|
||||
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
|
||||
langs=self.kwargs.get("DATALAB_MARKER_LANGS"),
|
||||
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
|
||||
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
|
||||
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
|
||||
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
|
||||
strip_existing_ocr=self.kwargs.get(
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
|
||||
),
|
||||
disable_image_extraction=self.kwargs.get(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
|
||||
),
|
||||
output_format=self.kwargs.get(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
|
||||
),
|
||||
)
|
||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
@@ -247,7 +288,7 @@ class Loader:
|
||||
params={
|
||||
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
|
||||
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
|
||||
"do_picture_classification": self.kwargs.get(
|
||||
"do_picture_description": self.kwargs.get(
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION"
|
||||
),
|
||||
},
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
@@ -14,18 +18,37 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
class MistralLoader:
|
||||
"""
|
||||
Enhanced Mistral OCR loader with both sync and async support.
|
||||
Loads documents by processing them through the Mistral OCR API.
|
||||
|
||||
Performance Optimizations:
|
||||
- Differentiated timeouts for different operations
|
||||
- Intelligent retry logic with exponential backoff
|
||||
- Memory-efficient file streaming for large files
|
||||
- Connection pooling and keepalive optimization
|
||||
- Semaphore-based concurrency control for batch processing
|
||||
- Enhanced error handling with retryable error classification
|
||||
"""
|
||||
|
||||
BASE_API_URL = "https://api.mistral.ai/v1"
|
||||
|
||||
def __init__(self, api_key: str, file_path: str):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
file_path: str,
|
||||
timeout: int = 300, # 5 minutes default
|
||||
max_retries: int = 3,
|
||||
enable_debug_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the loader.
|
||||
Initializes the loader with enhanced features.
|
||||
|
||||
Args:
|
||||
api_key: Your Mistral API key.
|
||||
file_path: The local path to the PDF file to process.
|
||||
timeout: Request timeout in seconds.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
enable_debug_logging: Enable detailed debug logs.
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("API key cannot be empty.")
|
||||
@@ -34,7 +57,46 @@ class MistralLoader:
|
||||
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.debug = enable_debug_logging
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
||||
# This prevents long-running OCR operations from affecting quick operations
|
||||
# and improves user experience by failing fast on operations that should be quick
|
||||
self.upload_timeout = min(
|
||||
timeout, 120
|
||||
) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = (
|
||||
30 # URL requests should be fast - fail quickly if API is slow
|
||||
)
|
||||
self.ocr_timeout = (
|
||||
timeout # OCR can take the full timeout - this is the heavy operation
|
||||
)
|
||||
self.cleanup_timeout = (
|
||||
30 # Cleanup should be quick - don't hang on file deletion
|
||||
)
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
||||
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
||||
self.file_name = os.path.basename(file_path)
|
||||
self.file_size = os.path.getsize(file_path)
|
||||
|
||||
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
|
||||
}
|
||||
|
||||
def _debug_log(self, message: str, *args) -> None:
|
||||
"""
|
||||
PERFORMANCE OPTIMIZATION: Conditional debug logging for performance.
|
||||
|
||||
Only processes debug messages when debug mode is enabled, avoiding
|
||||
string formatting overhead in production environments.
|
||||
"""
|
||||
if self.debug:
|
||||
log.debug(message, *args)
|
||||
|
||||
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
|
||||
"""Checks response status and returns JSON content."""
|
||||
@@ -54,24 +116,154 @@ class MistralLoader:
|
||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
||||
raise # Re-raise after logging
|
||||
|
||||
def _upload_file(self) -> str:
|
||||
"""Uploads the file to Mistral for OCR processing."""
|
||||
log.info("Uploading file to Mistral API")
|
||||
url = f"{self.BASE_API_URL}/files"
|
||||
file_name = os.path.basename(self.file_path)
|
||||
|
||||
async def _handle_response_async(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> Dict[str, Any]:
|
||||
"""Async version of response handling with better error info."""
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (file_name, f, "application/pdf")}
|
||||
data = {"purpose": "ocr"}
|
||||
response.raise_for_status()
|
||||
|
||||
upload_headers = self.headers.copy() # Avoid modifying self.headers
|
||||
|
||||
response = requests.post(
|
||||
url, headers=upload_headers, files=files, data=data
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
if response.status == 204:
|
||||
return {}
|
||||
text = await response.text()
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
||||
)
|
||||
|
||||
response_data = self._handle_response(response)
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
error_text = await response.text() if response else "No response"
|
||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"Client error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing response: {e}")
|
||||
raise
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
"""
|
||||
ENHANCEMENT: Intelligent error classification for retry logic.
|
||||
|
||||
Determines if an error is retryable based on its type and status code.
|
||||
This prevents wasting time retrying errors that will never succeed
|
||||
(like authentication errors) while ensuring transient errors are retried.
|
||||
|
||||
Retryable errors:
|
||||
- Network connection errors (temporary network issues)
|
||||
- Timeouts (server might be temporarily overloaded)
|
||||
- Server errors (5xx status codes - server-side issues)
|
||||
- Rate limiting (429 status - temporary throttling)
|
||||
|
||||
Non-retryable errors:
|
||||
- Authentication errors (401, 403 - won't fix with retry)
|
||||
- Bad request errors (400 - malformed request)
|
||||
- Not found errors (404 - resource doesn't exist)
|
||||
"""
|
||||
if isinstance(error, requests.exceptions.ConnectionError):
|
||||
return True # Network issues are usually temporary
|
||||
if isinstance(error, requests.exceptions.Timeout):
|
||||
return True # Timeouts might resolve on retry
|
||||
if isinstance(error, requests.exceptions.HTTPError):
|
||||
# Only retry on server errors (5xx) or rate limits (429)
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
status_code = error.response.status_code
|
||||
return status_code >= 500 or status_code == 429
|
||||
return False
|
||||
if isinstance(
|
||||
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
|
||||
):
|
||||
return True # Async network/timeout errors are retryable
|
||||
if isinstance(error, aiohttp.ClientResponseError):
|
||||
return error.status >= 500 or error.status == 429
|
||||
return False # All other errors are non-retryable
|
||||
|
||||
def _retry_request_sync(self, request_func, *args, **kwargs):
|
||||
"""
|
||||
ENHANCEMENT: Synchronous retry logic with intelligent error classification.
|
||||
|
||||
Uses exponential backoff with jitter to avoid thundering herd problems.
|
||||
The wait time increases exponentially but is capped at 30 seconds to
|
||||
prevent excessive delays. Only retries errors that are likely to succeed
|
||||
on subsequent attempts.
|
||||
"""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return request_func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
|
||||
raise
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Exponential backoff with cap
|
||||
# Prevents overwhelming the server while ensuring reasonable retry delays
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
async def _retry_request_async(self, request_func, *args, **kwargs):
|
||||
"""
|
||||
ENHANCEMENT: Async retry logic with intelligent error classification.
|
||||
|
||||
Async version of retry logic that doesn't block the event loop during
|
||||
wait periods. Uses the same exponential backoff strategy as sync version.
|
||||
"""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await request_func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if attempt == self.max_retries - 1 or not self._is_retryable_error(e):
|
||||
raise
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time) # Non-blocking wait
|
||||
|
||||
def _upload_file(self) -> str:
|
||||
"""
|
||||
PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration.
|
||||
|
||||
Uploads the file to Mistral for OCR processing (sync version).
|
||||
Uses context manager for file handling to ensure proper resource cleanup.
|
||||
Although streaming is not enabled for this endpoint, the file is opened
|
||||
in a context manager to minimize memory usage duration.
|
||||
"""
|
||||
log.info("Uploading file to Mistral API")
|
||||
url = f"{self.BASE_API_URL}/files"
|
||||
|
||||
def upload_request():
|
||||
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
|
||||
# This ensures the file is closed immediately after reading, reducing memory usage
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (self.file_name, f, "application/pdf")}
|
||||
data = {"purpose": "ocr"}
|
||||
|
||||
# NOTE: stream=False is required for this endpoint
|
||||
# The Mistral API doesn't support chunked uploads for this endpoint
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.headers,
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=self.upload_timeout, # Use specialized upload timeout
|
||||
stream=False, # Keep as False for this endpoint
|
||||
)
|
||||
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(upload_request)
|
||||
file_id = response_data.get("id")
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
@@ -81,16 +273,66 @@ class MistralLoader:
|
||||
log.error(f"Failed to upload file: {e}")
|
||||
raise
|
||||
|
||||
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
||||
"""Async file upload with streaming for better memory efficiency."""
|
||||
url = f"{self.BASE_API_URL}/files"
|
||||
|
||||
async def upload_request():
|
||||
# Create multipart writer for streaming upload
|
||||
writer = aiohttp.MultipartWriter("form-data")
|
||||
|
||||
# Add purpose field
|
||||
purpose_part = writer.append("ocr")
|
||||
purpose_part.set_content_disposition("form-data", name="purpose")
|
||||
|
||||
# Add file part with streaming
|
||||
file_part = writer.append_payload(
|
||||
aiohttp.streams.FilePayload(
|
||||
self.file_path,
|
||||
filename=self.file_name,
|
||||
content_type="application/pdf",
|
||||
)
|
||||
)
|
||||
file_part.set_content_disposition(
|
||||
"form-data", name="file", filename=self.file_name
|
||||
)
|
||||
|
||||
self._debug_log(
|
||||
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
|
||||
)
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
data=writer,
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.upload_timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
response_data = await self._retry_request_async(upload_request)
|
||||
|
||||
file_id = response_data.get("id")
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
return file_id
|
||||
|
||||
def _get_signed_url(self, file_id: str) -> str:
|
||||
"""Retrieves a temporary signed URL for the uploaded file."""
|
||||
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
||||
log.info(f"Getting signed URL for file ID: {file_id}")
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
signed_url_headers = {**self.headers, "Accept": "application/json"}
|
||||
|
||||
def url_request():
|
||||
response = requests.get(
|
||||
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=signed_url_headers, params=params)
|
||||
response_data = self._handle_response(response)
|
||||
response_data = self._retry_request_sync(url_request)
|
||||
signed_url = response_data.get("url")
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
@@ -100,8 +342,36 @@ class MistralLoader:
|
||||
log.error(f"Failed to get signed URL: {e}")
|
||||
raise
|
||||
|
||||
async def _get_signed_url_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> str:
|
||||
"""Async signed URL retrieval."""
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
|
||||
headers = {**self.headers, "Accept": "application/json"}
|
||||
|
||||
async def url_request():
|
||||
self._debug_log(f"Getting signed URL for file ID: {file_id}")
|
||||
async with session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=self.url_timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
response_data = await self._retry_request_async(url_request)
|
||||
|
||||
signed_url = response_data.get("url")
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
|
||||
self._debug_log("Signed URL received successfully")
|
||||
return signed_url
|
||||
|
||||
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
||||
"""Sends the signed URL to the OCR endpoint for processing."""
|
||||
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
||||
log.info("Processing OCR via Mistral API")
|
||||
url = f"{self.BASE_API_URL}/ocr"
|
||||
ocr_headers = {
|
||||
@@ -118,43 +388,217 @@ class MistralLoader:
|
||||
"include_image_base64": False,
|
||||
}
|
||||
|
||||
def ocr_request():
|
||||
response = requests.post(
|
||||
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=ocr_headers, json=payload)
|
||||
ocr_response = self._handle_response(response)
|
||||
ocr_response = self._retry_request_sync(ocr_request)
|
||||
log.info("OCR processing done.")
|
||||
log.debug("OCR response: %s", ocr_response)
|
||||
self._debug_log("OCR response: %s", ocr_response)
|
||||
return ocr_response
|
||||
except Exception as e:
|
||||
log.error(f"Failed during OCR processing: {e}")
|
||||
raise
|
||||
|
||||
async def _process_ocr_async(
|
||||
self, session: aiohttp.ClientSession, signed_url: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Async OCR processing with timing metrics."""
|
||||
url = f"{self.BASE_API_URL}/ocr"
|
||||
|
||||
headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
}
|
||||
|
||||
async def ocr_request():
|
||||
log.info("Starting OCR processing via Mistral API")
|
||||
start_time = time.time()
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.ocr_timeout),
|
||||
) as response:
|
||||
ocr_response = await self._handle_response_async(response)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
log.info(f"OCR processing completed in {processing_time:.2f}s")
|
||||
|
||||
return ocr_response
|
||||
|
||||
return await self._retry_request_async(ocr_request)
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Deletes the file from Mistral storage."""
|
||||
"""Deletes the file from Mistral storage (sync version)."""
|
||||
log.info(f"Deleting uploaded file ID: {file_id}")
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}"
|
||||
# No specific Accept header needed, default or Authorization is usually sufficient
|
||||
|
||||
try:
|
||||
response = requests.delete(url, headers=self.headers)
|
||||
delete_response = self._handle_response(
|
||||
response
|
||||
) # Check status, ignore response body unless needed
|
||||
log.info(
|
||||
f"File deleted successfully: {delete_response}"
|
||||
) # Log the response if available
|
||||
response = requests.delete(
|
||||
url, headers=self.headers, timeout=self.cleanup_timeout
|
||||
)
|
||||
delete_response = self._handle_response(response)
|
||||
log.info(f"File deleted successfully: {delete_response}")
|
||||
except Exception as e:
|
||||
# Log error but don't necessarily halt execution if deletion fails
|
||||
log.error(f"Failed to delete file ID {file_id}: {e}")
|
||||
# Depending on requirements, you might choose to raise the error here
|
||||
|
||||
async def _delete_file_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> None:
|
||||
"""Async file deletion with error tolerance."""
|
||||
try:
|
||||
|
||||
async def delete_request():
|
||||
self._debug_log(f"Deleting file ID: {file_id}")
|
||||
async with session.delete(
|
||||
url=f"{self.BASE_API_URL}/files/{file_id}",
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.cleanup_timeout
|
||||
), # Shorter timeout for cleanup
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
await self._retry_request_async(delete_request)
|
||||
self._debug_log(f"File {file_id} deleted successfully")
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail the entire process if cleanup fails
|
||||
log.warning(f"Failed to delete file ID {file_id}: {e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session(self):
|
||||
"""Context manager for HTTP session with optimized settings."""
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=20, # Increased total connection limit for better throughput
|
||||
limit_per_host=10, # Increased per-host limit for API endpoints
|
||||
ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes)
|
||||
use_dns_cache=True,
|
||||
keepalive_timeout=60, # Increased keepalive for connection reuse
|
||||
enable_cleanup_closed=True,
|
||||
force_close=False, # Allow connection reuse
|
||||
resolver=aiohttp.AsyncResolver(), # Use async DNS resolver
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self.timeout,
|
||||
connect=30, # Connection timeout
|
||||
sock_read=60, # Socket read timeout
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
||||
raise_for_status=False, # We handle status codes manually
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
||||
pages_data = ocr_response.get("pages")
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found",
|
||||
metadata={"error": "no_pages", "file_name": self.file_name},
|
||||
)
|
||||
]
|
||||
|
||||
documents = []
|
||||
total_pages = len(pages_data)
|
||||
skipped_pages = 0
|
||||
|
||||
# Process pages in a memory-efficient way
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
|
||||
if page_content is None or page_index is None:
|
||||
skipped_pages += 1
|
||||
self._debug_log(
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Clean up content efficiently with early exit for empty content
|
||||
if isinstance(page_content, str):
|
||||
cleaned_content = page_content.strip()
|
||||
else:
|
||||
cleaned_content = str(page_content).strip()
|
||||
|
||||
if not cleaned_content:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
continue
|
||||
|
||||
# Create document with optimized metadata
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index + 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
"content_length": len(cleaned_content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if skipped_pages > 0:
|
||||
log.info(
|
||||
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
|
||||
)
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
page_content="No valid text content found in document",
|
||||
metadata={
|
||||
"error": "no_valid_pages",
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
return documents
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
|
||||
Synchronous version for backward compatibility.
|
||||
|
||||
Returns:
|
||||
A list of Document objects, one for each page processed.
|
||||
"""
|
||||
file_id = None
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. Upload file
|
||||
file_id = self._upload_file()
|
||||
@@ -166,53 +610,30 @@ class MistralLoader:
|
||||
ocr_response = self._process_ocr(signed_url)
|
||||
|
||||
# 4. Process results
|
||||
pages_data = ocr_response.get("pages")
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
return [Document(page_content="No text content found", metadata={})]
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
documents = []
|
||||
total_pages = len(pages_data)
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
|
||||
if page_content is not None and page_index is not None:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=page_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index
|
||||
+ 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
# Add other relevant metadata from page_data if available/needed
|
||||
# e.g., page_data.get('width'), page_data.get('height')
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
|
||||
)
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found in valid pages", metadata={}
|
||||
)
|
||||
]
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"An error occurred during the loading process: {e}")
|
||||
# Return an empty list or a specific error document on failure
|
||||
return [Document(page_content=f"Error during processing: {e}", metadata={})]
|
||||
total_time = time.time() - start_time
|
||||
log.error(
|
||||
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
|
||||
)
|
||||
# Return an error document on failure
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during processing: {e}",
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
finally:
|
||||
# 5. Delete file (attempt even if prior steps failed after upload)
|
||||
if file_id:
|
||||
@@ -223,3 +644,124 @@ class MistralLoader:
|
||||
log.error(
|
||||
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
|
||||
)
|
||||
|
||||
async def load_async(self) -> List[Document]:
|
||||
"""
|
||||
Asynchronous OCR workflow execution with optimized performance.
|
||||
|
||||
Returns:
|
||||
A list of Document objects, one for each page processed.
|
||||
"""
|
||||
file_id = None
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
async with self._get_session() as session:
|
||||
# 1. Upload file with streaming
|
||||
file_id = await self._upload_file_async(session)
|
||||
|
||||
# 2. Get signed URL
|
||||
signed_url = await self._get_signed_url_async(session, file_id)
|
||||
|
||||
# 3. Process OCR
|
||||
ocr_response = await self._process_ocr_async(session, signed_url)
|
||||
|
||||
# 4. Process results
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during OCR processing: {e}",
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
finally:
|
||||
# 5. Cleanup - always attempt file deletion
|
||||
if file_id:
|
||||
try:
|
||||
async with self._get_session() as session:
|
||||
await self._delete_file_async(session, file_id)
|
||||
except Exception as cleanup_error:
|
||||
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
|
||||
|
||||
@staticmethod
|
||||
async def load_multiple_async(
|
||||
loaders: List["MistralLoader"],
|
||||
max_concurrent: int = 5, # Limit concurrent requests
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
Process multiple files concurrently with controlled concurrency.
|
||||
|
||||
Args:
|
||||
loaders: List of MistralLoader instances
|
||||
max_concurrent: Maximum number of concurrent requests
|
||||
|
||||
Returns:
|
||||
List of document lists, one for each loader
|
||||
"""
|
||||
if not loaders:
|
||||
return []
|
||||
|
||||
log.info(
|
||||
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
# Use semaphore to control concurrency
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
|
||||
async with semaphore:
|
||||
return await loader.load_async()
|
||||
|
||||
# Process all files with controlled concurrency
|
||||
tasks = [process_with_semaphore(loader) for loader in loaders]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle any exceptions in results
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"File {i} failed: {result}")
|
||||
processed_results.append(
|
||||
[
|
||||
Document(
|
||||
page_content=f"Error processing file: {result}",
|
||||
metadata={
|
||||
"error": "batch_processing_failed",
|
||||
"file_index": i,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
# MONITORING: Log comprehensive batch processing statistics
|
||||
total_time = time.time() - start_time
|
||||
total_docs = sum(len(docs) for docs in processed_results)
|
||||
success_count = sum(
|
||||
1 for result in results if not isinstance(result, Exception)
|
||||
)
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
log.info(
|
||||
f"Batch processing completed in {total_time:.2f}s: "
|
||||
f"{success_count} files succeeded, {failure_count} files failed, "
|
||||
f"produced {total_docs} total documents"
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
8
backend/open_webui/retrieval/models/base_reranker.py
Normal file
8
backend/open_webui/retrieval/models/base_reranker.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
|
||||
class BaseReranker(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
|
||||
pass
|
||||
@@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ColBERT:
|
||||
class ColBERT(BaseReranker):
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -3,12 +3,14 @@ import requests
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ExternalReranker:
|
||||
class ExternalReranker(BaseReranker):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
import requests
|
||||
import hashlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
|
||||
@@ -116,6 +117,7 @@ def query_doc_with_hybrid_search(
|
||||
reranking_function,
|
||||
k_reranker: int,
|
||||
r: float,
|
||||
hybrid_bm25_weight: float,
|
||||
) -> dict:
|
||||
try:
|
||||
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
|
||||
@@ -131,9 +133,20 @@ def query_doc_with_hybrid_search(
|
||||
top_k=k,
|
||||
)
|
||||
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5]
|
||||
)
|
||||
if hybrid_bm25_weight <= 0:
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[vector_search_retriever], weights=[1.0]
|
||||
)
|
||||
elif hybrid_bm25_weight >= 1:
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever], weights=[1.0]
|
||||
)
|
||||
else:
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[bm25_retriever, vector_search_retriever],
|
||||
weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight],
|
||||
)
|
||||
|
||||
compressor = RerankCompressor(
|
||||
embedding_function=embedding_function,
|
||||
top_n=k_reranker,
|
||||
@@ -313,6 +326,7 @@ def query_collection_with_hybrid_search(
|
||||
reranking_function,
|
||||
k_reranker: int,
|
||||
r: float,
|
||||
hybrid_bm25_weight: float,
|
||||
) -> dict:
|
||||
results = []
|
||||
error = False
|
||||
@@ -346,6 +360,7 @@ def query_collection_with_hybrid_search(
|
||||
reranking_function=reranking_function,
|
||||
k_reranker=k_reranker,
|
||||
r=r,
|
||||
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||
)
|
||||
return result, None
|
||||
except Exception as e:
|
||||
@@ -386,12 +401,13 @@ def get_embedding_function(
|
||||
url,
|
||||
key,
|
||||
embedding_batch_size,
|
||||
azure_api_version=None,
|
||||
):
|
||||
if embedding_engine == "":
|
||||
return lambda query, prefix=None, user=None: embedding_function.encode(
|
||||
query, **({"prompt": prefix} if prefix else {})
|
||||
).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
|
||||
func = lambda query, prefix=None, user=None: generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
@@ -400,6 +416,7 @@ def get_embedding_function(
|
||||
url=url,
|
||||
key=key,
|
||||
user=user,
|
||||
azure_api_version=azure_api_version,
|
||||
)
|
||||
|
||||
def generate_multiple(query, prefix, user, func):
|
||||
@@ -433,6 +450,7 @@ def get_sources_from_files(
|
||||
reranking_function,
|
||||
k_reranker,
|
||||
r,
|
||||
hybrid_bm25_weight,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
):
|
||||
@@ -550,6 +568,7 @@ def get_sources_from_files(
|
||||
reranking_function=reranking_function,
|
||||
k_reranker=k_reranker,
|
||||
r=r,
|
||||
hybrid_bm25_weight=hybrid_bm25_weight,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
@@ -681,6 +700,60 @@ def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
def generate_azure_openai_batch_embeddings(
|
||||
model: str,
|
||||
texts: list[str],
|
||||
url: str,
|
||||
key: str = "",
|
||||
version: str = "",
|
||||
prefix: str = None,
|
||||
user: UserModel = None,
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
log.debug(
|
||||
f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}"
|
||||
)
|
||||
json_data = {"input": texts}
|
||||
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
|
||||
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
|
||||
|
||||
url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
|
||||
|
||||
for _ in range(5):
|
||||
r = requests.post(
|
||||
url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"api-key": key,
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
json=json_data,
|
||||
)
|
||||
if r.status_code == 429:
|
||||
retry = float(r.headers.get("Retry-After", "1"))
|
||||
time.sleep(retry)
|
||||
continue
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return [elem["embedding"] for elem in data["data"]]
|
||||
else:
|
||||
raise Exception("Something went wrong :/")
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error generating azure openai batch embeddings: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
model: str,
|
||||
texts: list[str],
|
||||
@@ -745,38 +818,33 @@ def generate_embeddings(
|
||||
text = f"{prefix}{text}"
|
||||
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{
|
||||
"model": model,
|
||||
"texts": text,
|
||||
"url": url,
|
||||
"key": key,
|
||||
"prefix": prefix,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{
|
||||
"model": model,
|
||||
"texts": [text],
|
||||
"url": url,
|
||||
"key": key,
|
||||
"prefix": prefix,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
**{
|
||||
"model": model,
|
||||
"texts": text if isinstance(text, list) else [text],
|
||||
"url": url,
|
||||
"key": key,
|
||||
"prefix": prefix,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "openai":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(
|
||||
model, text, url, key, prefix, user
|
||||
)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(
|
||||
model, [text], url, key, prefix, user
|
||||
)
|
||||
embeddings = generate_openai_batch_embeddings(
|
||||
model, text if isinstance(text, list) else [text], url, key, prefix, user
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
elif engine == "azure_openai":
|
||||
azure_api_version = kwargs.get("azure_api_version", "")
|
||||
embeddings = generate_azure_openai_batch_embeddings(
|
||||
model,
|
||||
text if isinstance(text, list) else [text],
|
||||
url,
|
||||
key,
|
||||
azure_api_version,
|
||||
prefix,
|
||||
user,
|
||||
)
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
import logging
|
||||
import time # for measuring elapsed time
|
||||
from pinecone import ServerlessSpec
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
# Add gRPC support for better performance (Pinecone best practice)
|
||||
try:
|
||||
from pinecone.grpc import PineconeGRPC
|
||||
|
||||
GRPC_AVAILABLE = True
|
||||
except ImportError:
|
||||
GRPC_AVAILABLE = False
|
||||
|
||||
import asyncio # for async upserts
|
||||
import functools # for partial binding in async tasks
|
||||
|
||||
import concurrent.futures # for parallel batch upserts
|
||||
from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts
|
||||
import random # for jitter in retry backoff
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
@@ -47,10 +55,25 @@ class PineconeClient(VectorDBBase):
|
||||
self.metric = PINECONE_METRIC
|
||||
self.cloud = PINECONE_CLOUD
|
||||
|
||||
# Initialize Pinecone gRPC client for improved performance
|
||||
self.client = PineconeGRPC(
|
||||
api_key=self.api_key, environment=self.environment, cloud=self.cloud
|
||||
)
|
||||
# Initialize Pinecone client for improved performance
|
||||
if GRPC_AVAILABLE:
|
||||
# Use gRPC client for better performance (Pinecone recommendation)
|
||||
self.client = PineconeGRPC(
|
||||
api_key=self.api_key,
|
||||
pool_threads=20, # Improved connection pool size
|
||||
timeout=30, # Reasonable timeout for operations
|
||||
)
|
||||
self.using_grpc = True
|
||||
log.info("Using Pinecone gRPC client for optimal performance")
|
||||
else:
|
||||
# Fallback to HTTP client with enhanced connection pooling
|
||||
self.client = Pinecone(
|
||||
api_key=self.api_key,
|
||||
pool_threads=20, # Improved connection pool size
|
||||
timeout=30, # Reasonable timeout for operations
|
||||
)
|
||||
self.using_grpc = False
|
||||
log.info("Using Pinecone HTTP client (gRPC not available)")
|
||||
|
||||
# Persistent executor for batch operations
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
|
||||
@@ -94,12 +117,53 @@ class PineconeClient(VectorDBBase):
|
||||
log.info(f"Using existing Pinecone index '{self.index_name}'")
|
||||
|
||||
# Connect to the index
|
||||
self.index = self.client.Index(self.index_name)
|
||||
self.index = self.client.Index(
|
||||
self.index_name,
|
||||
pool_threads=20, # Enhanced connection pool for index operations
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize Pinecone index: {e}")
|
||||
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
|
||||
|
||||
def _retry_pinecone_operation(self, operation_func, max_retries=3):
|
||||
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return operation_func()
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
# Check if it's a retryable error (rate limits, network issues, timeouts)
|
||||
is_retryable = any(
|
||||
keyword in error_str
|
||||
for keyword in [
|
||||
"rate limit",
|
||||
"quota",
|
||||
"timeout",
|
||||
"network",
|
||||
"connection",
|
||||
"unavailable",
|
||||
"internal error",
|
||||
"429",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
]
|
||||
)
|
||||
|
||||
if not is_retryable or attempt == max_retries - 1:
|
||||
# Don't retry for non-retryable errors or on final attempt
|
||||
raise
|
||||
|
||||
# Exponential backoff with jitter
|
||||
delay = (2**attempt) + random.uniform(0, 1)
|
||||
log.warning(
|
||||
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
|
||||
f"retrying in {delay:.2f}s: {e}"
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
def _create_points(
|
||||
self, items: List[VectorItem], collection_name_with_prefix: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -147,8 +211,8 @@ class PineconeClient(VectorDBBase):
|
||||
metadatas = []
|
||||
|
||||
for match in matches:
|
||||
metadata = match.get("metadata", {})
|
||||
ids.append(match["id"])
|
||||
metadata = getattr(match, "metadata", {}) or {}
|
||||
ids.append(match.id if hasattr(match, "id") else match["id"])
|
||||
documents.append(metadata.get("text", ""))
|
||||
metadatas.append(metadata)
|
||||
|
||||
@@ -174,7 +238,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
include_metadata=False,
|
||||
)
|
||||
return len(response.matches) > 0
|
||||
matches = getattr(response, "matches", []) or []
|
||||
return len(matches) > 0
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error checking collection '{collection_name_with_prefix}': {e}"
|
||||
@@ -225,7 +290,8 @@ class PineconeClient(VectorDBBase):
|
||||
elapsed = time.time() - start_time
|
||||
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
|
||||
log.info(
|
||||
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
|
||||
f"Successfully inserted {len(points)} vectors in parallel batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -256,7 +322,8 @@ class PineconeClient(VectorDBBase):
|
||||
elapsed = time.time() - start_time
|
||||
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
|
||||
log.info(
|
||||
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
|
||||
f"Successfully upserted {len(points)} vectors in parallel batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -287,7 +354,8 @@ class PineconeClient(VectorDBBase):
|
||||
log.error(f"Error in async insert batch: {result}")
|
||||
raise result
|
||||
log.info(
|
||||
f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
|
||||
f"Successfully async inserted {len(points)} vectors in batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -318,35 +386,10 @@ class PineconeClient(VectorDBBase):
|
||||
log.error(f"Error in async upsert batch: {result}")
|
||||
raise result
|
||||
log.info(
|
||||
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
|
||||
f"Successfully async upserted {len(points)} vectors in batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Perform a streaming upsert over gRPC for performance testing."""
|
||||
if not items:
|
||||
log.warning("No items to upsert via streaming")
|
||||
return
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Open a streaming upsert channel
|
||||
stream = self.index.streaming_upsert()
|
||||
try:
|
||||
for point in points:
|
||||
# send each point over the stream
|
||||
stream.send(point)
|
||||
# close the stream to finalize
|
||||
stream.close()
|
||||
log.info(
|
||||
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during streaming upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
@@ -374,7 +417,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
)
|
||||
|
||||
if not query_response.matches:
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
if not matches:
|
||||
# Return empty result if no matches
|
||||
return SearchResult(
|
||||
ids=[[]],
|
||||
@@ -384,13 +428,13 @@ class PineconeClient(VectorDBBase):
|
||||
)
|
||||
|
||||
# Convert to GetResult format
|
||||
get_result = self._result_to_get_result(query_response.matches)
|
||||
get_result = self._result_to_get_result(matches)
|
||||
|
||||
# Calculate normalized distances based on metric
|
||||
distances = [
|
||||
[
|
||||
self._normalize_distance(match.score)
|
||||
for match in query_response.matches
|
||||
self._normalize_distance(getattr(match, "score", 0.0))
|
||||
for match in matches
|
||||
]
|
||||
]
|
||||
|
||||
@@ -432,7 +476,8 @@ class PineconeClient(VectorDBBase):
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(query_response.matches)
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error querying collection '{collection_name}': {e}")
|
||||
@@ -456,7 +501,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
)
|
||||
|
||||
return self._result_to_get_result(query_response.matches)
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting collection '{collection_name}': {e}")
|
||||
@@ -482,10 +528,12 @@ class PineconeClient(VectorDBBase):
|
||||
# This is a limitation of Pinecone - be careful with ID uniqueness
|
||||
self.index.delete(ids=batch_ids)
|
||||
log.debug(
|
||||
f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
|
||||
f"Deleted batch of {len(batch_ids)} vectors by ID "
|
||||
f"from '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(
|
||||
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
|
||||
f"Successfully deleted {len(ids)} vectors by ID "
|
||||
f"from '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
elif filter:
|
||||
@@ -516,12 +564,12 @@ class PineconeClient(VectorDBBase):
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Shut down the gRPC channel and thread pool."""
|
||||
"""Shut down resources."""
|
||||
try:
|
||||
self.client.close()
|
||||
log.info("Pinecone gRPC channel closed.")
|
||||
# The new Pinecone client doesn't need explicit closing
|
||||
pass
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to close Pinecone gRPC channel: {e}")
|
||||
log.warning(f"Failed to clean up Pinecone resources: {e}")
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import Optional, Literal
|
||||
import requests
|
||||
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
MODELS = Literal[
|
||||
"sonar",
|
||||
"sonar-pro",
|
||||
"sonar-reasoning",
|
||||
"sonar-reasoning-pro",
|
||||
"sonar-deep-research",
|
||||
]
|
||||
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
@@ -14,6 +24,8 @@ def search_perplexity(
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
model: MODELS = "sonar",
|
||||
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -21,6 +33,9 @@ def search_perplexity(
|
||||
api_key (str): A Perplexity API key
|
||||
query (str): The query to search for
|
||||
count (int): Maximum number of results to return
|
||||
filter_list (Optional[list[str]]): List of domains to filter results
|
||||
model (str): The Perplexity model to use (sonar, sonar-pro)
|
||||
search_context_usage (str): Search context usage level (low, medium, high)
|
||||
|
||||
"""
|
||||
|
||||
@@ -33,7 +48,7 @@ def search_perplexity(
|
||||
|
||||
# Create payload for the API call
|
||||
payload = {
|
||||
"model": "sonar",
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -43,6 +58,9 @@ def search_perplexity(
|
||||
],
|
||||
"temperature": 0.2, # Lower temperature for more factual responses
|
||||
"stream": False,
|
||||
"web_search_options": {
|
||||
"search_context_usage": search_context_usage,
|
||||
},
|
||||
}
|
||||
|
||||
headers = {
|
||||
|
||||
@@ -42,7 +42,9 @@ def search_searchapi(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -42,7 +42,9 @@ def search_serpapi(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result.get("title"), snippet=result.get("snippet")
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -8,6 +8,8 @@ from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
@@ -18,6 +20,7 @@ from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -93,12 +96,9 @@ def is_audio_conversion_required(file_path):
|
||||
# File is AAC/mp4a audio, recommend mp3 conversion
|
||||
return True
|
||||
|
||||
# If the codec name or file extension is in the supported formats
|
||||
if (
|
||||
codec_name in SUPPORTED_FORMATS
|
||||
or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS
|
||||
):
|
||||
return False # Already supported
|
||||
# If the codec name is in the supported formats
|
||||
if codec_name in SUPPORTED_FORMATS:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -527,11 +527,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def transcription_handler(request, file_path):
|
||||
def transcription_handler(request, file_path, metadata):
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
if request.app.state.faster_whisper_model is None:
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -543,7 +545,7 @@ def transcription_handler(request, file_path):
|
||||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
language=WHISPER_LANGUAGE,
|
||||
language=metadata.get("language") or WHISPER_LANGUAGE,
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
@@ -569,7 +571,14 @@ def transcription_handler(request, file_path):
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data={"model": request.app.state.config.STT_MODEL},
|
||||
data={
|
||||
"model": request.app.state.config.STT_MODEL,
|
||||
**(
|
||||
{"language": metadata.get("language")}
|
||||
if metadata.get("language")
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
@@ -777,8 +786,8 @@ def transcription_handler(request, file_path):
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
log.info(f"transcribe: {file_path}")
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
file_path = convert_audio_to_mp3(file_path)
|
||||
@@ -804,7 +813,7 @@ def transcribe(request: Request, file_path):
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit tasks for each chunk_path
|
||||
futures = [
|
||||
executor.submit(transcription_handler, request, chunk_path)
|
||||
executor.submit(transcription_handler, request, chunk_path, metadata)
|
||||
for chunk_path in chunk_paths
|
||||
]
|
||||
# Gather results as they complete
|
||||
@@ -812,10 +821,9 @@ def transcribe(request: Request, file_path):
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as transcribe_exc:
|
||||
log.exception(f"Error transcribing chunk: {transcribe_exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error during transcription.",
|
||||
detail=f"Error transcribing chunk: {transcribe_exc}",
|
||||
)
|
||||
finally:
|
||||
# Clean up only the temporary chunks, never the original file
|
||||
@@ -897,6 +905,7 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
|
||||
def transcription(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
language: Optional[str] = Form(None),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
@@ -926,7 +935,12 @@ def transcription(
|
||||
f.write(contents)
|
||||
|
||||
try:
|
||||
result = transcribe(request, file_path)
|
||||
metadata = None
|
||||
|
||||
if language:
|
||||
metadata = {"language": language}
|
||||
|
||||
result = transcribe(request, file_path, metadata)
|
||||
|
||||
return {
|
||||
**result,
|
||||
|
||||
@@ -19,12 +19,14 @@ from open_webui.models.auths import (
|
||||
UserResponse,
|
||||
)
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
@@ -299,7 +301,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
500, detail="Internal error occurred during LDAP user creation."
|
||||
)
|
||||
|
||||
user = Auths.authenticate_user_by_trusted_header(email)
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
|
||||
if user:
|
||||
expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN)
|
||||
@@ -363,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
||||
|
||||
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
trusted_name = trusted_email
|
||||
email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
name = email
|
||||
|
||||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||
trusted_name = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
|
||||
)
|
||||
if not Users.get_user_by_email(trusted_email.lower()):
|
||||
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
|
||||
|
||||
if not Users.get_user_by_email(email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
SignupForm(
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
|
||||
),
|
||||
SignupForm(email=email, password=str(uuid.uuid4()), name=name),
|
||||
)
|
||||
user = Auths.authenticate_user_by_trusted_header(trusted_email)
|
||||
|
||||
user = Auths.authenticate_user_by_email(email)
|
||||
if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin":
|
||||
group_names = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_GROUPS_HEADER, ""
|
||||
).split(",")
|
||||
group_names = [name.strip() for name in group_names if name.strip()]
|
||||
|
||||
if group_names:
|
||||
Groups.sync_user_groups_by_group_names(user.id, group_names)
|
||||
|
||||
elif WEBUI_AUTH == False:
|
||||
admin_email = "admin@localhost"
|
||||
admin_password = "admin"
|
||||
|
||||
@@ -76,17 +76,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
|
||||
@router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_user_id(
|
||||
user_id: str,
|
||||
page: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_admin_user),
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
):
|
||||
if not ENABLE_ADMIN_CHAT_ACCESS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
return Chats.get_chat_list_by_user_id(
|
||||
user_id, include_archived=True, skip=skip, limit=limit
|
||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
|
||||
@@ -194,10 +211,10 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user)
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/pinned", response_model=list[ChatResponse])
|
||||
@router.get("/pinned", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
@@ -267,9 +284,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
|
||||
|
||||
@router.get("/archived", response_model=list[ChatTitleIdResponse])
|
||||
async def get_archived_session_user_chat_list(
|
||||
user=Depends(get_verified_user), skip: int = 0, limit: int = 50
|
||||
page: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
chat_list = [
|
||||
ChatTitleIdResponse(**chat.model_dump())
|
||||
for chat in Chats.get_archived_chat_list_by_user_id(
|
||||
user.id,
|
||||
filter=filter,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
]
|
||||
|
||||
return chat_list
|
||||
|
||||
|
||||
############################
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -10,6 +11,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -85,20 +87,33 @@ def has_access_to_file(
|
||||
def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
file_metadata: dict = None,
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
internal: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
knowledge_id: Optional[str] = Form(None)
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
file_metadata = file_metadata if file_metadata else {}
|
||||
if isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
|
||||
)
|
||||
file_metadata = metadata if metadata else {}
|
||||
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
|
||||
file_extension = os.path.splitext(filename)[1]
|
||||
if request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
# Remove the leading dot from the file extension
|
||||
file_extension = file_extension[1:] if file_extension else ""
|
||||
|
||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||
]
|
||||
@@ -146,21 +161,16 @@ def upload_file(
|
||||
"video/webm"
|
||||
}:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", ""), knowledge_id=knowledge_id),
|
||||
user=user,
|
||||
)
|
||||
elif file.content_type not in [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"video/mp4",
|
||||
"video/ogg",
|
||||
"video/quicktime",
|
||||
]:
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=id, knowledge_id=knowledge_id), user=user)
|
||||
else:
|
||||
log.info(
|
||||
@@ -191,7 +201,7 @@ def upload_file(
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import logging
|
||||
import aiohttp
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -9,12 +12,18 @@ from open_webui.models.functions import (
|
||||
FunctionResponse,
|
||||
Functions,
|
||||
)
|
||||
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
replace_imports,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
@@ -42,6 +51,97 @@ async def get_functions(user=Depends(get_admin_user)):
|
||||
return Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
# LoadFunctionFromLink
|
||||
############################
|
||||
|
||||
|
||||
class LoadUrlForm(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
def github_url_to_raw_url(url: str) -> str:
|
||||
# Handle 'tree' (folder) URLs (add main.py at the end)
|
||||
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
|
||||
if m1:
|
||||
org, repo, branch, path = m1.groups()
|
||||
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
|
||||
|
||||
# Handle 'blob' (file) URLs
|
||||
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
|
||||
if m2:
|
||||
org, repo, branch, path = m2.groups()
|
||||
return (
|
||||
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
|
||||
)
|
||||
|
||||
# No match; return as-is
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/load/url", response_model=Optional[dict])
|
||||
async def load_function_from_url(
|
||||
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# NOTE: This is NOT a SSRF vulnerability:
|
||||
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
||||
# and does NOT accept untrusted user input. Access is enforced by authentication.
|
||||
|
||||
url = str(form_data.url)
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
||||
|
||||
url = github_url_to_raw_url(url)
|
||||
url_parts = url.rstrip("/").split("/")
|
||||
|
||||
file_name = url_parts[-1]
|
||||
function_name = (
|
||||
file_name[:-3]
|
||||
if (
|
||||
file_name.endswith(".py")
|
||||
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
||||
)
|
||||
else url_parts[-2] if len(url_parts) > 1 else "function"
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=resp.status, detail="Failed to fetch the function"
|
||||
)
|
||||
data = await resp.text()
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No data received from the URL"
|
||||
)
|
||||
return {
|
||||
"name": function_name,
|
||||
"content": data,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error importing function: {e}")
|
||||
|
||||
|
||||
############################
|
||||
# SyncFunctions
|
||||
############################
|
||||
|
||||
|
||||
class SyncFunctionsForm(FunctionForm):
|
||||
functions: list[FunctionModel] = []
|
||||
|
||||
|
||||
@router.post("/sync", response_model=Optional[FunctionModel])
|
||||
async def sync_functions(
|
||||
request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewFunction
|
||||
############################
|
||||
@@ -262,8 +362,9 @@ async def get_function_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -287,8 +388,9 @@ async def update_function_valves_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -347,8 +449,9 @@ async def get_function_user_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
@@ -368,8 +471,9 @@ async def update_function_user_valves_by_id(
|
||||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
|
||||
@@ -333,10 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
return [
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
@@ -450,7 +451,7 @@ def load_url_image_data(url, headers=None):
|
||||
return None
|
||||
|
||||
|
||||
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
def upload_image(request, image_data, content_type, metadata, user):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
@@ -459,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
@@ -526,7 +527,7 @@ async def image_generations(
|
||||
else:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
@@ -560,7 +561,7 @@ async def image_generations(
|
||||
image_data, content_type = load_b64_image_data(
|
||||
image["bytesBase64Encoded"]
|
||||
)
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
|
||||
return images
|
||||
@@ -611,9 +612,9 @@ async def image_generations(
|
||||
image_data, content_type = load_url_image_data(image["url"], headers)
|
||||
url = upload_image(
|
||||
request,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
image_data,
|
||||
content_type,
|
||||
form_data.model_dump(exclude_none=True),
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
@@ -664,9 +665,9 @@ async def image_generations(
|
||||
image_data, content_type = load_b64_image_data(image)
|
||||
url = upload_image(
|
||||
request,
|
||||
{**data, "info": res["info"]},
|
||||
image_data,
|
||||
content_type,
|
||||
{**data, "info": res["info"]},
|
||||
user,
|
||||
)
|
||||
images.append({"url": url})
|
||||
|
||||
@@ -124,9 +124,8 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
if user.role != "admin" or (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="read", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -159,9 +158,8 @@ async def update_note_by_id(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
if user.role != "admin" or (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
@@ -199,9 +197,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
if user.role != "admin" or (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
||||
@@ -9,6 +9,8 @@ import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
@@ -300,6 +302,22 @@ async def update_config(
|
||||
}
|
||||
|
||||
|
||||
def merge_ollama_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
@@ -364,23 +382,8 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
|
||||
models = {
|
||||
"models": merge_models_lists(
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
@@ -388,6 +391,22 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
)
|
||||
}
|
||||
|
||||
try:
|
||||
loaded_models = await get_ollama_loaded_models(request, user=user)
|
||||
expires_map = {
|
||||
m["name"]: m["expires_at"]
|
||||
for m in loaded_models["models"]
|
||||
if "expires_at" in m
|
||||
}
|
||||
|
||||
for m in models["models"]:
|
||||
if m["name"] in expires_map:
|
||||
# Parse ISO8601 datetime with offset, get unix timestamp as int
|
||||
dt = datetime.fromisoformat(expires_map[m["name"]])
|
||||
m["expires_at"] = int(dt.timestamp())
|
||||
except Exception as e:
|
||||
log.debug(f"Failed to get loaded models: {e}")
|
||||
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
@@ -468,6 +487,68 @@ async def get_ollama_tags(
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
enable = api_config.get("enable", True)
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/ps", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
if response:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
|
||||
for model in response.get("models", []):
|
||||
if prefix_id:
|
||||
model["model"] = f"{prefix_id}.{model['model']}"
|
||||
|
||||
models = {
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
)
|
||||
)
|
||||
}
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/version")
|
||||
@router.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
@@ -541,36 +622,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
return {"version": False}
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = [
|
||||
send_get_request(
|
||||
f"{url}/api/ps",
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
user=user,
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class ModelNameForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.post("/api/unload")
|
||||
async def unload_model(
|
||||
request: Request,
|
||||
form_data: ModelNameForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
model_name = form_data.name
|
||||
if not model_name:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing 'name' of model to unload."
|
||||
)
|
||||
|
||||
# Refresh/load models if needed, get mapping from name to URLs
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
# Canonicalize model name (if not supplied with version)
|
||||
if ":" not in model_name:
|
||||
model_name = f"{model_name}:latest"
|
||||
|
||||
if model_name not in models:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
|
||||
)
|
||||
url_indices = models[model_name]["urls"]
|
||||
|
||||
# Send unload to ALL url_indices
|
||||
results = []
|
||||
errors = []
|
||||
for idx in url_indices:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
)
|
||||
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id and model_name.startswith(f"{prefix_id}."):
|
||||
model_name = model_name[len(f"{prefix_id}.") :]
|
||||
|
||||
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
|
||||
|
||||
try:
|
||||
res = await send_post_request(
|
||||
url=f"{url}/api/generate",
|
||||
payload=json.dumps(payload),
|
||||
stream=False,
|
||||
key=key,
|
||||
user=user,
|
||||
)
|
||||
results.append({"url_idx": idx, "success": True, "response": res})
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to unload model on node {idx}: {e}")
|
||||
errors.append({"url_idx": idx, "success": False, "error": str(e)})
|
||||
|
||||
if len(errors) > 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
|
||||
|
||||
@router.post("/api/pull")
|
||||
@router.post("/api/pull/{url_idx}")
|
||||
async def pull_model(
|
||||
@@ -1164,13 +1283,14 @@ async def generate_chat_completion(
|
||||
params = model_info.params.model_dump()
|
||||
|
||||
if params:
|
||||
if payload.get("options") is None:
|
||||
payload["options"] = {}
|
||||
system = params.pop("system", None)
|
||||
|
||||
# Unlike OpenAI, Ollama does not support params directly in the body
|
||||
payload["options"] = apply_model_params_to_body_ollama(
|
||||
params, payload["options"]
|
||||
params, (payload.get("options", {}) or {})
|
||||
)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -1352,8 +1472,10 @@ async def generate_openai_chat_completion(
|
||||
params = model_info.params.model_dump()
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
|
||||
@@ -715,8 +715,12 @@ async def generate_chat_completion(
|
||||
model_id = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, metadata, user)
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(system, payload, metadata, user)
|
||||
|
||||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
@@ -883,6 +887,88 @@ async def generate_chat_completion(
|
||||
await session.close()
|
||||
|
||||
|
||||
async def embeddings(request: Request, form_data: dict, user):
|
||||
"""
|
||||
Calls the embeddings endpoint for OpenAI-compatible providers.
|
||||
|
||||
Args:
|
||||
request (Request): The FastAPI request context.
|
||||
form_data (dict): OpenAI-compatible embeddings payload.
|
||||
user (UserModel): The authenticated user.
|
||||
|
||||
Returns:
|
||||
dict: OpenAI-compatible embeddings response.
|
||||
"""
|
||||
idx = 0
|
||||
# Prepare payload/body
|
||||
body = json.dumps(form_data)
|
||||
# Find correct backend url/key based on model
|
||||
await get_all_models(request, user=user)
|
||||
model_id = form_data.get("model")
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if model_id in models:
|
||||
idx = models[model_id]["urlIdx"]
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
try:
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method="POST",
|
||||
url=f"{url}/embeddings",
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
r.raise_for_status()
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
r.content,
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
background=BackgroundTask(
|
||||
cleanup_response, response=r, session=session
|
||||
),
|
||||
)
|
||||
else:
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
if r is not None:
|
||||
try:
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
raise HTTPException(
|
||||
status_code=r.status if r else 500,
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not streaming and session:
|
||||
if r:
|
||||
r.close()
|
||||
await session.close()
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
|
||||
@@ -254,6 +254,11 @@ async def get_embedding_config(request: Request, collectionForm: Optional[Collec
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
}),
|
||||
"azure_openai_config": rag_config.get("azure_openai_config", {
|
||||
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
||||
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
||||
"version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
@@ -267,9 +272,16 @@ class OllamaConfigForm(BaseModel):
|
||||
key: str
|
||||
|
||||
|
||||
class AzureOpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
version: str
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
ollama_config: Optional[OllamaConfigForm] = None
|
||||
azure_openai_config: Optional[AzureOpenAIConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
embedding_batch_size: Optional[int] = 1
|
||||
@@ -405,14 +417,18 @@ async def update_embedding_config(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||
form_data.openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||
form_data.openai_config.key
|
||||
)
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE in [
|
||||
"ollama",
|
||||
"openai",
|
||||
"azure_openai",
|
||||
]:
|
||||
if form_data.openai_config is not None:
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL = (
|
||||
form_data.openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_OPENAI_API_KEY = (
|
||||
form_data.openai_config.key
|
||||
)
|
||||
|
||||
if form_data.ollama_config is not None:
|
||||
request.app.state.config.RAG_OLLAMA_BASE_URL = (
|
||||
@@ -422,6 +438,61 @@ async def update_embedding_config(
|
||||
form_data.ollama_config.key
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
if form_data.azure_openai_config is not None:
|
||||
request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = (
|
||||
form_data.azure_openai_config.url
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_KEY = (
|
||||
form_data.azure_openai_config.key
|
||||
)
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = (
|
||||
form_data.azure_openai_config.version
|
||||
)
|
||||
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
|
||||
form_data.embedding_batch_size
|
||||
)
|
||||
@@ -440,14 +511,27 @@ async def update_embedding_config(
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||
else (
|
||||
request.app.state.config.RAG_OLLAMA_BASE_URL
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||
else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL
|
||||
)
|
||||
),
|
||||
(
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
else (
|
||||
request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama"
|
||||
else request.app.state.config.RAG_AZURE_OPENAI_API_KEY
|
||||
)
|
||||
),
|
||||
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||
azure_api_version=(
|
||||
request.app.state.config.RAG_AZURE_OPENAI_API_VERSION
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai"
|
||||
else None
|
||||
),
|
||||
)
|
||||
# add model to state for reloading on startup
|
||||
request.app.state.config.LOADED_EMBEDDING_MODELS[request.app.state.config.RAG_EMBEDDING_ENGINE].append(request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
@@ -470,9 +554,13 @@ async def update_embedding_config(
|
||||
"url": request.app.state.config.RAG_OLLAMA_BASE_URL,
|
||||
"key": request.app.state.config.RAG_OLLAMA_API_KEY,
|
||||
},
|
||||
"azure_openai_config": {
|
||||
"url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL,
|
||||
"key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY,
|
||||
"version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION,
|
||||
},
|
||||
"LOADED_EMBEDDING_MODELS": request.app.state.config.LOADED_EMBEDDING_MODELS,
|
||||
"DOWNLOADED_EMBEDDING_MODELS": request.app.state.config.DOWNLOADED_EMBEDDING_MODELS,
|
||||
"message": "Embedding configuration updated globally.",
|
||||
}
|
||||
except Exception as e:
|
||||
log.exception(f"Problem updating embedding model: {e}")
|
||||
@@ -508,9 +596,19 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
|
||||
"ENABLE_RAG_HYBRID_SEARCH": rag_config.get("ENABLE_RAG_HYBRID_SEARCH", request.app.state.config.ENABLE_RAG_HYBRID_SEARCH),
|
||||
"TOP_K_RERANKER": rag_config.get("TOP_K_RERANKER", request.app.state.config.TOP_K_RERANKER),
|
||||
"RELEVANCE_THRESHOLD": rag_config.get("RELEVANCE_THRESHOLD", request.app.state.config.RELEVANCE_THRESHOLD),
|
||||
"HYBRID_BM25_WEIGHT": rag_config.get("HYBRID_BM25_WEIGHT", request.app.state.config.HYBRID_BM25_WEIGHT),
|
||||
# Content extraction settings
|
||||
"CONTENT_EXTRACTION_ENGINE": rag_config.get("CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE),
|
||||
"PDF_EXTRACT_IMAGES": rag_config.get("PDF_EXTRACT_IMAGES", request.app.state.config.PDF_EXTRACT_IMAGES),
|
||||
"DATALAB_MARKER_API_KEY": rag_config.get("DATALAB_MARKER_API_KEY", request.app.state.config.DATALAB_MARKER_API_KEY),
|
||||
"DATALAB_MARKER_LANGS": rag_config.get("DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS),
|
||||
"DATALAB_MARKER_SKIP_CACHE": rag_config.get("DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE),
|
||||
"DATALAB_MARKER_FORCE_OCR": rag_config.get("DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR),
|
||||
"DATALAB_MARKER_PAGINATE": rag_config.get("DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE),
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR": rag_config.get("DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR),
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": rag_config.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION),
|
||||
"DATALAB_MARKER_USE_LLM": rag_config.get("DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM),
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT": rag_config.get("DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT),
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": rag_config.get("EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL),
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": rag_config.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY),
|
||||
"TIKA_SERVER_URL": rag_config.get("TIKA_SERVER_URL", request.app.state.config.TIKA_SERVER_URL),
|
||||
@@ -546,6 +644,7 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": web_config.get("WEB_SEARCH_CONCURRENT_REQUESTS", request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS),
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": web_config.get("WEB_SEARCH_DOMAIN_FILTER_LIST", request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST),
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": web_config.get("BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL", request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL),
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": web_config.get("BYPASS_WEB_SEARCH_WEB_LOADER", request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER),
|
||||
"SEARXNG_QUERY_URL": web_config.get("SEARXNG_QUERY_URL", request.app.state.config.SEARXNG_QUERY_URL),
|
||||
"YACY_QUERY_URL": web_config.get("YACY_QUERY_URL", request.app.state.config.YACY_QUERY_URL),
|
||||
"YACY_USERNAME": web_config.get("YACY_QUERY_USERNAME",request.app.state.config.YACY_USERNAME),
|
||||
@@ -570,6 +669,8 @@ async def get_rag_config(request: Request, collectionForm: CollectionForm, user=
|
||||
"BING_SEARCH_V7_SUBSCRIPTION_KEY": web_config.get("BING_SEARCH_V7_SUBSCRIPTION_KEY", request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY),
|
||||
"EXA_API_KEY": web_config.get("EXA_API_KEY", request.app.state.config.EXA_API_KEY),
|
||||
"PERPLEXITY_API_KEY": web_config.get("PERPLEXITY_API_KEY", request.app.state.config.PERPLEXITY_API_KEY),
|
||||
"PERPLEXITY_MODEL": web_config.get("PERPLEXITY_MODEL", request.app.state.config.PERPLEXITY_MODEL),
|
||||
"PERPLEXITY_SEARCH_CONTEXT_USAGE": web_config.get("PERPLEXITY_SEARCH_CONTEXT_USAGE", request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE),
|
||||
"SOUGOU_API_SID": web_config.get("SOUGOU_API_SID", request.app.state.config.SOUGOU_API_SID),
|
||||
"SOUGOU_API_SK": web_config.get("SOUGOU_API_SK", request.app.state.config.SOUGOU_API_SK),
|
||||
"WEB_LOADER_ENGINE": web_config.get("WEB_LOADER_ENGINE", request.app.state.config.WEB_LOADER_ENGINE),
|
||||
@@ -603,6 +704,7 @@ class WebConfig(BaseModel):
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
|
||||
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
|
||||
SEARXNG_QUERY_URL: Optional[str] = None
|
||||
YACY_QUERY_URL: Optional[str] = None
|
||||
YACY_USERNAME: Optional[str] = None
|
||||
@@ -627,6 +729,8 @@ class WebConfig(BaseModel):
|
||||
BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None
|
||||
EXA_API_KEY: Optional[str] = None
|
||||
PERPLEXITY_API_KEY: Optional[str] = None
|
||||
PERPLEXITY_MODEL: Optional[str] = None
|
||||
PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None
|
||||
SOUGOU_API_SID: Optional[str] = None
|
||||
SOUGOU_API_SK: Optional[str] = None
|
||||
WEB_LOADER_ENGINE: Optional[str] = None
|
||||
@@ -656,10 +760,20 @@ class ConfigForm(BaseModel):
|
||||
ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None
|
||||
TOP_K_RERANKER: Optional[int] = None
|
||||
RELEVANCE_THRESHOLD: Optional[float] = None
|
||||
HYBRID_BM25_WEIGHT: Optional[float] = None
|
||||
|
||||
# Content extraction settings
|
||||
CONTENT_EXTRACTION_ENGINE: Optional[str] = None
|
||||
PDF_EXTRACT_IMAGES: Optional[bool] = None
|
||||
DATALAB_MARKER_API_KEY: Optional[str] = None
|
||||
DATALAB_MARKER_LANGS: Optional[str] = None
|
||||
DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None
|
||||
DATALAB_MARKER_FORCE_OCR: Optional[bool] = None
|
||||
DATALAB_MARKER_PAGINATE: Optional[bool] = None
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None
|
||||
DATALAB_MARKER_USE_LLM: Optional[bool] = None
|
||||
DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None
|
||||
EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
|
||||
|
||||
@@ -853,6 +967,11 @@ async def update_rag_config(
|
||||
if form_data.RELEVANCE_THRESHOLD is not None
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
)
|
||||
request.app.state.config.HYBRID_BM25_WEIGHT = (
|
||||
form_data.HYBRID_BM25_WEIGHT
|
||||
if form_data.HYBRID_BM25_WEIGHT is not None
|
||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||
)
|
||||
|
||||
# Content extraction settings
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
|
||||
@@ -865,6 +984,51 @@ async def update_rag_config(
|
||||
if form_data.PDF_EXTRACT_IMAGES is not None
|
||||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_API_KEY = (
|
||||
form_data.DATALAB_MARKER_API_KEY
|
||||
if form_data.DATALAB_MARKER_API_KEY is not None
|
||||
else request.app.state.config.DATALAB_MARKER_API_KEY
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_LANGS = (
|
||||
form_data.DATALAB_MARKER_LANGS
|
||||
if form_data.DATALAB_MARKER_LANGS is not None
|
||||
else request.app.state.config.DATALAB_MARKER_LANGS
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_SKIP_CACHE = (
|
||||
form_data.DATALAB_MARKER_SKIP_CACHE
|
||||
if form_data.DATALAB_MARKER_SKIP_CACHE is not None
|
||||
else request.app.state.config.DATALAB_MARKER_SKIP_CACHE
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_FORCE_OCR = (
|
||||
form_data.DATALAB_MARKER_FORCE_OCR
|
||||
if form_data.DATALAB_MARKER_FORCE_OCR is not None
|
||||
else request.app.state.config.DATALAB_MARKER_FORCE_OCR
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_PAGINATE = (
|
||||
form_data.DATALAB_MARKER_PAGINATE
|
||||
if form_data.DATALAB_MARKER_PAGINATE is not None
|
||||
else request.app.state.config.DATALAB_MARKER_PAGINATE
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = (
|
||||
form_data.DATALAB_MARKER_STRIP_EXISTING_OCR
|
||||
if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR is not None
|
||||
else request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = (
|
||||
form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None
|
||||
else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = (
|
||||
form_data.DATALAB_MARKER_OUTPUT_FORMAT
|
||||
if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None
|
||||
else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT
|
||||
)
|
||||
request.app.state.config.DATALAB_MARKER_USE_LLM = (
|
||||
form_data.DATALAB_MARKER_USE_LLM
|
||||
if form_data.DATALAB_MARKER_USE_LLM is not None
|
||||
else request.app.state.config.DATALAB_MARKER_USE_LLM
|
||||
)
|
||||
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = (
|
||||
form_data.EXTERNAL_DOCUMENT_LOADER_URL
|
||||
if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None
|
||||
@@ -1046,6 +1210,9 @@ async def update_rag_config(
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
|
||||
)
|
||||
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
|
||||
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
|
||||
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
|
||||
@@ -1082,6 +1249,10 @@ async def update_rag_config(
|
||||
)
|
||||
request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY
|
||||
request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY
|
||||
request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL
|
||||
request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = (
|
||||
form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE
|
||||
)
|
||||
request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID
|
||||
request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK
|
||||
|
||||
@@ -1132,9 +1303,19 @@ async def update_rag_config(
|
||||
"ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
"TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER,
|
||||
"RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD,
|
||||
# Content extraction settings
|
||||
"HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT,
|
||||
# Content extraction settings
|
||||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY,
|
||||
"DATALAB_MARKER_LANGS": request.app.state.config.DATALAB_MARKER_LANGS,
|
||||
"DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE,
|
||||
"DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR,
|
||||
"DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE,
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR,
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION,
|
||||
"DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM,
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT,
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
|
||||
@@ -1170,6 +1351,7 @@ async def update_rag_config(
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
|
||||
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
|
||||
@@ -1194,6 +1376,8 @@ async def update_rag_config(
|
||||
"BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"EXA_API_KEY": request.app.state.config.EXA_API_KEY,
|
||||
"PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY,
|
||||
"PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL,
|
||||
"PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
|
||||
"SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID,
|
||||
"SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK,
|
||||
"WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE,
|
||||
@@ -1268,11 +1452,14 @@ def save_docs_to_vector_db(
|
||||
embedding_engine = rag_config.get("embedding_engine", request.app.state.config.RAG_EMBEDDING_ENGINE)
|
||||
embedding_model = rag_config.get("embedding_model", request.app.state.config.RAG_EMBEDDING_MODEL)
|
||||
embedding_batch_size = rag_config.get("embedding_batch_size", request.app.state.config.RAG_EMBEDDING_BATCH_SIZE)
|
||||
openai_api_base_url = rag_config.get("openai_api_base_url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
|
||||
openai_api_key = rag_config.get("openai_api_key", request.app.state.config.RAG_OPENAI_API_KEY)
|
||||
ollama_base_url = rag_config.get("ollama", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
|
||||
ollama_api_key = rag_config.get("ollama", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
|
||||
|
||||
openai_api_base_url = rag_config.get("openai_config", {}).get("url", request.app.state.config.RAG_OPENAI_API_BASE_URL)
|
||||
openai_api_key = rag_config.get("openai_config", {}).get("url", request.app.state.config.RAG_OPENAI_API_KEY)
|
||||
ollama_base_url = rag_config.get("ollama_config", {}).get("url", request.app.state.config.RAG_OLLAMA_BASE_URL)
|
||||
ollama_api_key = rag_config.get("ollama_config", {}).get("key", request.app.state.config.RAG_OLLAMA_API_KEY)
|
||||
azure_openai_url = rag_config.get("azure_openai", {}).get("url", request.app.state.config.RAG_AZURE_OPENAI_BASE_URL)
|
||||
azure_openai_key = rag_config.get("azure_openai", {}).get("key", request.app.state.config.RAG_AZURE_OPENAI_BASE_URL)
|
||||
azure_openai_version = rag_config.get("azure_openai", {}).get("version", request.app.state.config.RAG_AZURE_OPENAI_BASE_URL)
|
||||
|
||||
# Check if entries with the same hash (metadata.hash) already exist
|
||||
if metadata and "hash" in metadata:
|
||||
result = VECTOR_DB_CLIENT.query(
|
||||
@@ -1354,20 +1541,29 @@ def save_docs_to_vector_db(
|
||||
|
||||
log.info(f"adding to collection {collection_name}")
|
||||
embedding_function = get_embedding_function(
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
request.app.state.ef.get(embedding_model, request.app.state.config.RAG_EMBEDDING_MODEL),
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||
request.app.state.ef,
|
||||
(
|
||||
openai_api_base_url
|
||||
if embedding_engine == "openai"
|
||||
else ollama_base_url
|
||||
else (
|
||||
ollama_base_url
|
||||
if embedding_engine == "ollama"
|
||||
else azure_openai_url
|
||||
)
|
||||
),
|
||||
(
|
||||
openai_api_key
|
||||
if embedding_engine == "openai"
|
||||
else ollama_api_key
|
||||
request.app.state.config.RAG_OPENAI_API_KEY
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||
else request.app.state.config.RAG_OLLAMA_API_KEY
|
||||
),
|
||||
embedding_batch_size,
|
||||
azure_api_version=(
|
||||
azure_openai_version
|
||||
if embedding_engine == "azure_openai"
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
embeddings = embedding_function(
|
||||
@@ -1439,6 +1635,33 @@ def process_file(
|
||||
content_extraction_engine = rag_config.get(
|
||||
"CONTENT_EXTRACTION_ENGINE", request.app.state.config.CONTENT_EXTRACTION_ENGINE
|
||||
)
|
||||
datalab_marker_api_key=rag_config.get(
|
||||
"DATALAB_MARKER_API_KEY", request.app.state.config.DATALAB_MARKER_API_KEY
|
||||
)
|
||||
datalab_marker_langs=rag_config.get(
|
||||
"DATALAB_MARKER_LANGS", request.app.state.config.DATALAB_MARKER_LANGS
|
||||
)
|
||||
datalab_marker_skip_cache=rag_config.get(
|
||||
"DATALAB_MARKER_SKIP_CACHE", request.app.state.config.DATALAB_MARKER_SKIP_CACHE
|
||||
)
|
||||
datalab_marker_force_ocr=rag_config.get(
|
||||
"DATALAB_MARKER_FORCE_OCR", request.app.state.config.DATALAB_MARKER_FORCE_OCR
|
||||
)
|
||||
datalab_marker_paginate=rag_config.get(
|
||||
"DATALAB_MARKER_PAGINATE", request.app.state.config.DATALAB_MARKER_PAGINATE
|
||||
)
|
||||
datalab_marker_strip_existing_ocr=rag_config.get(
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR", request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR
|
||||
)
|
||||
datalab_marker_disable_image_extraction=rag_config.get(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION
|
||||
)
|
||||
datalab_marker_use_llm=rag_config.get(
|
||||
"DATALAB_MARKER_USE_LLM", request.app.state.config.DATALAB_MARKER_USE_LLM
|
||||
)
|
||||
datalab_marker_output_format=rag_config.get(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT", request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT
|
||||
)
|
||||
external_document_loader_url = rag_config.get(
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL", request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
|
||||
)
|
||||
@@ -1537,6 +1760,15 @@ def process_file(
|
||||
file_path = Storage.get_file(file_path)
|
||||
loader = Loader(
|
||||
engine=content_extraction_engine,
|
||||
DATALAB_MARKER_API_KEY=datalab_marker_api_key,
|
||||
DATALAB_MARKER_LANGS=datalab_marker_langs,
|
||||
DATALAB_MARKER_SKIP_CACHE=datalab_marker_skip_cache,
|
||||
DATALAB_MARKER_FORCE_OCR=datalab_marker_force_ocr,
|
||||
DATALAB_MARKER_PAGINATE=datalab_marker_paginate,
|
||||
DATALAB_MARKER_STRIP_EXISTING_OCR=datalab_marker_strip_existing_ocr,
|
||||
DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=datalab_marker_disable_image_extraction,
|
||||
DATALAB_MARKER_USE_LLM=datalab_marker_use_llm,
|
||||
DATALAB_MARKER_OUTPUT_FORMAT=datalab_marker_output_format,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL=external_document_loader_url,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY=external_document_loader_api_key,
|
||||
TIKA_SERVER_URL=tika_server_url,
|
||||
@@ -1961,19 +2193,14 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "exa":
|
||||
return search_exa(
|
||||
request.app.state.config.EXA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "perplexity":
|
||||
return search_perplexity(
|
||||
request.app.state.config.PERPLEXITY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
model=request.app.state.config.PERPLEXITY_MODEL,
|
||||
search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE,
|
||||
)
|
||||
elif engine == "sougou":
|
||||
if (
|
||||
@@ -2052,13 +2279,29 @@ async def process_web_search(
|
||||
)
|
||||
|
||||
try:
|
||||
loader = get_web_loader(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=result.snippet,
|
||||
metadata={
|
||||
"source": result.link,
|
||||
"title": result.title,
|
||||
"snippet": result.snippet,
|
||||
"link": result.link,
|
||||
},
|
||||
)
|
||||
for result in search_results
|
||||
if hasattr(result, "snippet")
|
||||
]
|
||||
else:
|
||||
loader = get_web_loader(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
|
||||
urls = [
|
||||
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
|
||||
] # only keep the urls returned by the loader
|
||||
@@ -2148,6 +2391,11 @@ def query_doc_handler(
|
||||
if form_data.r
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
),
|
||||
hybrid_bm25_weight=(
|
||||
form_data.hybrid_bm25_weight
|
||||
if form_data.hybrid_bm25_weight
|
||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||
),
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
@@ -2174,6 +2422,7 @@ class QueryCollectionsForm(BaseModel):
|
||||
k_reranker: Optional[int] = None
|
||||
r: Optional[float] = None
|
||||
hybrid: Optional[bool] = None
|
||||
hybrid_bm25_weight: Optional[float] = None
|
||||
|
||||
|
||||
@router.post("/query/collection")
|
||||
@@ -2199,6 +2448,11 @@ def query_collection_handler(
|
||||
if form_data.r
|
||||
else request.app.state.config.RELEVANCE_THRESHOLD
|
||||
),
|
||||
hybrid_bm25_weight=(
|
||||
form_data.hybrid_bm25_weight
|
||||
if form_data.hybrid_bm25_weight
|
||||
else request.app.state.config.HYBRID_BM25_WEIGHT
|
||||
),
|
||||
)
|
||||
else:
|
||||
return query_collection(
|
||||
|
||||
@@ -9,6 +9,7 @@ import re
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.task import (
|
||||
title_generation_template,
|
||||
follow_up_generation_template,
|
||||
query_generation_template,
|
||||
image_prompt_generation_template,
|
||||
autocomplete_generation_template,
|
||||
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
||||
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
||||
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
|
||||
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
||||
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
|
||||
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
||||
FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
|
||||
ENABLE_FOLLOW_UP_GENERATION: bool
|
||||
ENABLE_TAGS_GENERATION: bool
|
||||
ENABLE_SEARCH_QUERY_GENERATION: bool
|
||||
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
||||
@@ -94,6 +100,13 @@ async def update_task_config(
|
||||
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
|
||||
form_data.ENABLE_FOLLOW_UP_GENERATION
|
||||
)
|
||||
request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
||||
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
@@ -133,6 +146,8 @@ async def update_task_config(
|
||||
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
||||
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
||||
"ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
||||
"FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
||||
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
||||
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
||||
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
@@ -231,6 +246,86 @@ async def generate_title(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/follow_up/completions")
|
||||
async def generate_follow_ups(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
|
||||
if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"detail": "Follow-up generation is disabled"},
|
||||
)
|
||||
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
request.app.state.config.TASK_MODEL,
|
||||
request.app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
f"generating chat title using model {task_model_id} for user {user.email} "
|
||||
)
|
||||
|
||||
if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
else:
|
||||
template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = follow_up_generation_template(
|
||||
template,
|
||||
form_data["messages"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
"metadata": {
|
||||
**(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
||||
"task": str(TASKS.FOLLOW_UP_GENERATION),
|
||||
"task_body": form_data,
|
||||
"chat_id": form_data.get("chat_id", None),
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
log.error("Exception occurred", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": "An internal error has occurred."},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tags/completions")
|
||||
async def generate_chat_tags(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
|
||||
@@ -2,6 +2,9 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import time
|
||||
import re
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from open_webui.models.tools import (
|
||||
ToolForm,
|
||||
@@ -21,6 +24,7 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
@@ -51,11 +55,11 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
**{
|
||||
"id": f"server:{server['idx']}",
|
||||
"user_id": f"server:{server['idx']}",
|
||||
"name": server["openapi"]
|
||||
"name": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("title", "Tool Server"),
|
||||
"meta": {
|
||||
"description": server["openapi"]
|
||||
"description": server.get("openapi", {})
|
||||
.get("info", {})
|
||||
.get("description", ""),
|
||||
},
|
||||
@@ -95,6 +99,81 @@ async def get_tool_list(user=Depends(get_verified_user)):
|
||||
return tools
|
||||
|
||||
|
||||
############################
|
||||
# LoadFunctionFromLink
|
||||
############################
|
||||
|
||||
|
||||
class LoadUrlForm(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
def github_url_to_raw_url(url: str) -> str:
|
||||
# Handle 'tree' (folder) URLs (add main.py at the end)
|
||||
m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url)
|
||||
if m1:
|
||||
org, repo, branch, path = m1.groups()
|
||||
return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py"
|
||||
|
||||
# Handle 'blob' (file) URLs
|
||||
m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url)
|
||||
if m2:
|
||||
org, repo, branch, path = m2.groups()
|
||||
return (
|
||||
f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
|
||||
)
|
||||
|
||||
# No match; return as-is
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/load/url", response_model=Optional[dict])
|
||||
async def load_tool_from_url(
|
||||
request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
|
||||
):
|
||||
# NOTE: This is NOT a SSRF vulnerability:
|
||||
# This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
|
||||
# and does NOT accept untrusted user input. Access is enforced by authentication.
|
||||
|
||||
url = str(form_data.url)
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Please enter a valid URL")
|
||||
|
||||
url = github_url_to_raw_url(url)
|
||||
url_parts = url.rstrip("/").split("/")
|
||||
|
||||
file_name = url_parts[-1]
|
||||
tool_name = (
|
||||
file_name[:-3]
|
||||
if (
|
||||
file_name.endswith(".py")
|
||||
and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
|
||||
)
|
||||
else url_parts[-2] if len(url_parts) > 1 else "function"
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=resp.status, detail="Failed to fetch the tool"
|
||||
)
|
||||
data = await resp.text()
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No data received from the URL"
|
||||
)
|
||||
return {
|
||||
"name": tool_name,
|
||||
"content": data,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
|
||||
|
||||
|
||||
############################
|
||||
# ExportTools
|
||||
############################
|
||||
|
||||
@@ -165,22 +165,6 @@ async def update_default_user_permissions(
|
||||
return request.app.state.config.USER_PERMISSIONS
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserRole
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/update/role", response_model=Optional[UserModel])
|
||||
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
||||
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserSettingsBySessionUser
|
||||
############################
|
||||
@@ -333,11 +317,22 @@ async def update_user_by_id(
|
||||
# Prevent modification of the primary admin user by other admins
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
if first_user and user_id == first_user.id and session_user.id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
if first_user:
|
||||
if user_id == first_user.id:
|
||||
if session_user.id != user_id:
|
||||
# If the user trying to update is the primary admin, and they are not the primary admin themselves
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
|
||||
if form_data.role != "admin":
|
||||
# If the primary admin is trying to change their own role, prevent it
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error checking primary admin status: {e}")
|
||||
raise HTTPException(
|
||||
@@ -365,6 +360,7 @@ async def update_user_by_id(
|
||||
updated_user = Users.update_user_by_id(
|
||||
user_id,
|
||||
{
|
||||
"role": form_data.role,
|
||||
"name": form_data.name,
|
||||
"email": form_data.email.lower(),
|
||||
"profile_image_url": form_data.profile_image_url,
|
||||
|
||||
@@ -269,11 +269,6 @@ tbody + tbody {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Add a rule to reset margin-bottom for <p> not followed by <ul> */
|
||||
.markdown-section p + ul {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
/* List item styles */
|
||||
.markdown-section li {
|
||||
padding: 2px;
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import shutil
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO, Tuple, Dict
|
||||
|
||||
@@ -136,6 +137,11 @@ class S3StorageProvider(StorageProvider):
|
||||
self.bucket_name = S3_BUCKET_NAME
|
||||
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
|
||||
|
||||
@staticmethod
|
||||
def sanitize_tag_value(s: str) -> str:
|
||||
"""Only include S3 allowed characters."""
|
||||
return re.sub(r"[^a-zA-Z0-9 äöüÄÖÜß\+\-=\._:/@]", "", s)
|
||||
|
||||
def upload_file(
|
||||
self, file: BinaryIO, filename: str, tags: Dict[str, str]
|
||||
) -> Tuple[bytes, str]:
|
||||
@@ -145,7 +151,15 @@ class S3StorageProvider(StorageProvider):
|
||||
try:
|
||||
self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
|
||||
if S3_ENABLE_TAGGING and tags:
|
||||
tagging = {"TagSet": [{"Key": k, "Value": v} for k, v in tags.items()]}
|
||||
sanitized_tags = {
|
||||
self.sanitize_tag_value(k): self.sanitize_tag_value(v)
|
||||
for k, v in tags.items()
|
||||
}
|
||||
tagging = {
|
||||
"TagSet": [
|
||||
{"Key": k, "Value": v} for k, v in sanitized_tags.items()
|
||||
]
|
||||
}
|
||||
self.s3_client.put_object_tagging(
|
||||
Bucket=self.bucket_name,
|
||||
Key=s3_key,
|
||||
|
||||
@@ -40,7 +40,10 @@ from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.models import get_all_models, check_model_access
|
||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||
from open_webui.utils.response import (
|
||||
@@ -317,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
@@ -392,8 +390,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
}
|
||||
)
|
||||
|
||||
function_module, _, _ = load_function_module_by_id(action_id)
|
||||
request.app.state.FUNCTIONS[action_id] = function_module
|
||||
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
@@ -422,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
__user__ = (user.model_dump() if isinstance(user, UserModel) else {},)
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
|
||||
90
backend/open_webui/utils/embeddings.py
Normal file
90
backend/open_webui/utils/embeddings.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import random
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from fastapi import Request
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.utils.models import check_model_access
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
|
||||
from open_webui.routers.openai import embeddings as openai_embeddings
|
||||
from open_webui.routers.ollama import (
|
||||
embeddings as ollama_embeddings,
|
||||
GenerateEmbeddingsForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
|
||||
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def generate_embeddings(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
user: UserModel,
|
||||
bypass_filter: bool = False,
|
||||
):
|
||||
"""
|
||||
Dispatch and handle embeddings generation based on the model type (OpenAI, Ollama).
|
||||
|
||||
Args:
|
||||
request (Request): The FastAPI request context.
|
||||
form_data (dict): The input data sent to the endpoint.
|
||||
user (UserModel): The authenticated user.
|
||||
bypass_filter (bool): If True, disables access filtering (default False).
|
||||
|
||||
Returns:
|
||||
dict: The embeddings response, following OpenAI API compatibility.
|
||||
"""
|
||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||
bypass_filter = True
|
||||
|
||||
# Attach extra metadata from request.state if present
|
||||
if hasattr(request.state, "metadata"):
|
||||
if "metadata" not in form_data:
|
||||
form_data["metadata"] = request.state.metadata
|
||||
else:
|
||||
form_data["metadata"] = {
|
||||
**form_data["metadata"],
|
||||
**request.state.metadata,
|
||||
}
|
||||
|
||||
# If "direct" flag present, use only that model
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
}
|
||||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
model_id = form_data.get("model")
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
model = models[model_id]
|
||||
|
||||
# Access filtering
|
||||
if not getattr(request.state, "direct", False):
|
||||
if not bypass_filter and user.role == "user":
|
||||
check_model_access(user, model)
|
||||
|
||||
# Ollama backend
|
||||
if model.get("owned_by") == "ollama":
|
||||
ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
|
||||
response = await ollama_embeddings(
|
||||
request=request,
|
||||
form_data=GenerateEmbeddingsForm(**ollama_payload),
|
||||
user=user,
|
||||
)
|
||||
return convert_embedding_response_ollama_to_openai(response)
|
||||
|
||||
# Default: OpenAI or compatible backend
|
||||
return await openai_embeddings(
|
||||
request=request,
|
||||
form_data=form_data,
|
||||
user=user,
|
||||
)
|
||||
@@ -1,7 +1,10 @@
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
@@ -9,14 +12,13 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_function_module(request, function_id):
|
||||
def get_function_module(request, function_id, load_from_db=True):
|
||||
"""
|
||||
Get the function module by its ID.
|
||||
"""
|
||||
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
|
||||
function_module, _, _ = get_function_module_from_cache(
|
||||
request, function_id, load_from_db
|
||||
)
|
||||
return function_module
|
||||
|
||||
|
||||
@@ -37,14 +39,17 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
for filter_id in active_filter_ids:
|
||||
def get_active_status(filter_id):
|
||||
function_module = get_function_module(request, filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None) and (
|
||||
filter_id not in enabled_filter_ids
|
||||
):
|
||||
active_filter_ids.remove(filter_id)
|
||||
continue
|
||||
if getattr(function_module, "toggle", None):
|
||||
return filter_id in (enabled_filter_ids or [])
|
||||
|
||||
return True
|
||||
|
||||
active_filter_ids = [
|
||||
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
@@ -63,7 +68,9 @@ async def process_filter_functions(
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
function_module = get_function_module(request, filter_id)
|
||||
function_module = get_function_module(
|
||||
request, filter_id, load_from_db=(filter_type != "stream")
|
||||
)
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
|
||||
@@ -32,6 +32,7 @@ from open_webui.socket.main import (
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_follow_ups,
|
||||
generate_image_prompt,
|
||||
generate_chat_tags,
|
||||
)
|
||||
@@ -41,6 +42,7 @@ from open_webui.routers.pipelines import (
|
||||
process_pipeline_inlet_filter,
|
||||
process_pipeline_outlet_filter,
|
||||
)
|
||||
from open_webui.routers.memories import query_memory, QueryMemoryForm
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
@@ -251,7 +253,12 @@ async def chat_completion_tools_handler(
|
||||
"name": (f"TOOL:{tool_name}"),
|
||||
},
|
||||
"document": [tool_result],
|
||||
"metadata": [{"source": (f"TOOL:{tool_name}")}],
|
||||
"metadata": [
|
||||
{
|
||||
"source": (f"TOOL:{tool_name}"),
|
||||
"parameters": tool_function_params,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -290,6 +297,45 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_memory_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
try:
|
||||
results = await query_memory(
|
||||
request,
|
||||
QueryMemoryForm(
|
||||
**{
|
||||
"content": get_last_user_message(form_data["messages"]) or "",
|
||||
"k": 3,
|
||||
}
|
||||
),
|
||||
user,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
results = None
|
||||
|
||||
user_context = ""
|
||||
if results and hasattr(results, "documents"):
|
||||
if results.documents and len(results.documents) > 0:
|
||||
for doc_idx, doc in enumerate(results.documents[0]):
|
||||
created_at_date = "Unknown Date"
|
||||
|
||||
if results.metadatas[0][doc_idx].get("created_at"):
|
||||
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
|
||||
created_at_date = time.strftime(
|
||||
"%Y-%m-%d", time.localtime(created_at_timestamp)
|
||||
)
|
||||
|
||||
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
@@ -389,6 +435,7 @@ async def chat_web_search_handler(
|
||||
"name": ", ".join(queries),
|
||||
"type": "web_search",
|
||||
"urls": results["filenames"],
|
||||
"queries": queries,
|
||||
}
|
||||
)
|
||||
elif results.get("docs"):
|
||||
@@ -400,6 +447,7 @@ async def chat_web_search_handler(
|
||||
"name": ", ".join(queries),
|
||||
"type": "web_search",
|
||||
"urls": results["filenames"],
|
||||
"queries": queries,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -617,6 +665,7 @@ async def chat_completion_files_handler(
|
||||
reranking_function=reranking_function,
|
||||
k_reranker=k_reranker,
|
||||
r=r,
|
||||
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
|
||||
hybrid_search=hybrid_search,
|
||||
full_context=full_context,
|
||||
),
|
||||
@@ -631,6 +680,32 @@ async def chat_completion_files_handler(
|
||||
|
||||
def apply_params_to_form_data(form_data, model):
|
||||
params = form_data.pop("params", {})
|
||||
custom_params = params.pop("custom_params", {})
|
||||
|
||||
open_webui_params = {
|
||||
"stream_response": bool,
|
||||
"function_calling": str,
|
||||
"system": str,
|
||||
}
|
||||
|
||||
for key in list(params.keys()):
|
||||
if key in open_webui_params:
|
||||
del params[key]
|
||||
|
||||
if custom_params:
|
||||
# Attempt to parse custom_params if they are strings
|
||||
for key, value in custom_params.items():
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
# Attempt to parse the string as JSON
|
||||
custom_params[key] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
# If it fails, keep the original string
|
||||
pass
|
||||
|
||||
# If custom_params are provided, merge them into params
|
||||
params = deep_update(params, custom_params)
|
||||
|
||||
if model.get("ollama"):
|
||||
form_data["options"] = params
|
||||
|
||||
@@ -640,29 +715,10 @@ def apply_params_to_form_data(form_data, model):
|
||||
if "keep_alive" in params:
|
||||
form_data["keep_alive"] = params["keep_alive"]
|
||||
else:
|
||||
if "seed" in params and params["seed"] is not None:
|
||||
form_data["seed"] = params["seed"]
|
||||
|
||||
if "stop" in params and params["stop"] is not None:
|
||||
form_data["stop"] = params["stop"]
|
||||
|
||||
if "temperature" in params and params["temperature"] is not None:
|
||||
form_data["temperature"] = params["temperature"]
|
||||
|
||||
if "max_tokens" in params and params["max_tokens"] is not None:
|
||||
form_data["max_tokens"] = params["max_tokens"]
|
||||
|
||||
if "top_p" in params and params["top_p"] is not None:
|
||||
form_data["top_p"] = params["top_p"]
|
||||
|
||||
if "frequency_penalty" in params and params["frequency_penalty"] is not None:
|
||||
form_data["frequency_penalty"] = params["frequency_penalty"]
|
||||
|
||||
if "presence_penalty" in params and params["presence_penalty"] is not None:
|
||||
form_data["presence_penalty"] = params["presence_penalty"]
|
||||
|
||||
if "reasoning_effort" in params and params["reasoning_effort"] is not None:
|
||||
form_data["reasoning_effort"] = params["reasoning_effort"]
|
||||
if isinstance(params, dict):
|
||||
for key, value in params.items():
|
||||
if value is not None:
|
||||
form_data[key] = value
|
||||
|
||||
if "logit_bias" in params and params["logit_bias"] is not None:
|
||||
try:
|
||||
@@ -676,7 +732,6 @@ def apply_params_to_form_data(form_data, model):
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
@@ -686,12 +741,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
@@ -788,6 +838,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "memory" in features and features["memory"]:
|
||||
form_data = await chat_memory_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
@@ -890,6 +945,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
for doc_context, doc_meta in zip(
|
||||
source["document"], source["metadata"]
|
||||
):
|
||||
source_name = source.get("source", {}).get("name", None)
|
||||
citation_id = (
|
||||
doc_meta.get("source", None)
|
||||
or source.get("source", {}).get("id", None)
|
||||
@@ -897,7 +953,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
)
|
||||
if citation_id not in citation_idx:
|
||||
citation_idx[citation_id] = len(citation_idx) + 1
|
||||
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
|
||||
context_string += (
|
||||
f'<source id="{citation_idx[citation_id]}"'
|
||||
+ (f' name="{source_name}"' if source_name else "")
|
||||
+ f">{doc_context}</source>\n"
|
||||
)
|
||||
|
||||
context_string = context_string.strip()
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
@@ -968,7 +1028,7 @@ async def process_chat_response(
|
||||
message = message_map.get(metadata["message_id"]) if message_map else None
|
||||
|
||||
if message:
|
||||
message_list = get_message_list(message_map, message.get("id"))
|
||||
message_list = get_message_list(message_map, metadata["message_id"])
|
||||
|
||||
# Remove details tags and files from the messages.
|
||||
# as get_message_list creates a new list, it does not affect
|
||||
@@ -985,7 +1045,7 @@ async def process_chat_response(
|
||||
|
||||
if isinstance(content, str):
|
||||
content = re.sub(
|
||||
r"<details\b[^>]*>.*?<\/details>",
|
||||
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
|
||||
"",
|
||||
content,
|
||||
flags=re.S | re.I,
|
||||
@@ -994,12 +1054,67 @@ async def process_chat_response(
|
||||
messages.append(
|
||||
{
|
||||
**message,
|
||||
"role": message["role"],
|
||||
"role": message.get(
|
||||
"role", "assistant"
|
||||
), # Safe fallback for missing role
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
if tasks and messages:
|
||||
if (
|
||||
TASKS.FOLLOW_UP_GENERATION in tasks
|
||||
and tasks[TASKS.FOLLOW_UP_GENERATION]
|
||||
):
|
||||
res = await generate_follow_ups(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"message_id": metadata["message_id"],
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res and isinstance(res, dict):
|
||||
if len(res.get("choices", [])) == 1:
|
||||
follow_ups_string = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
)
|
||||
else:
|
||||
follow_ups_string = ""
|
||||
|
||||
follow_ups_string = follow_ups_string[
|
||||
follow_ups_string.find("{") : follow_ups_string.rfind("}")
|
||||
+ 1
|
||||
]
|
||||
|
||||
try:
|
||||
follow_ups = json.loads(follow_ups_string).get(
|
||||
"follow_ups", []
|
||||
)
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"followUps": follow_ups,
|
||||
},
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message:follow_ups",
|
||||
"data": {
|
||||
"follow_ups": follow_ups,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
if tasks[TASKS.TITLE_GENERATION]:
|
||||
res = await generate_title(
|
||||
@@ -1162,6 +1277,7 @@ async def process_chat_response(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
@@ -1184,8 +1300,34 @@ async def process_chat_response(
|
||||
|
||||
await background_tasks_handler()
|
||||
|
||||
if events and isinstance(events, list) and isinstance(response, dict):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
response = {
|
||||
**extra_response,
|
||||
**response,
|
||||
}
|
||||
|
||||
return response
|
||||
else:
|
||||
if events and isinstance(events, list) and isinstance(response, dict):
|
||||
extra_response = {}
|
||||
for event in events:
|
||||
if isinstance(event, dict):
|
||||
extra_response.update(event)
|
||||
else:
|
||||
extra_response[event] = True
|
||||
|
||||
response = {
|
||||
**extra_response,
|
||||
**response,
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
# Non standard response
|
||||
@@ -1198,12 +1340,7 @@ async def process_chat_response(
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
|
||||
@@ -34,11 +34,15 @@ def get_message_list(messages, message_id):
|
||||
:return: List of ordered messages starting from the root to the given message
|
||||
"""
|
||||
|
||||
# Handle case where messages is None
|
||||
if not messages:
|
||||
return [] # Return empty list instead of None to prevent iteration errors
|
||||
|
||||
# Find the message by its id
|
||||
current_message = messages.get(message_id)
|
||||
|
||||
if not current_message:
|
||||
return None
|
||||
return [] # Return empty list instead of None to prevent iteration errors
|
||||
|
||||
# Reconstruct the chain by following the parentId links
|
||||
message_list = []
|
||||
@@ -47,7 +51,7 @@ def get_message_list(messages, message_id):
|
||||
message_list.insert(
|
||||
0, current_message
|
||||
) # Insert the message at the beginning of the list
|
||||
parent_id = current_message["parentId"]
|
||||
parent_id = current_message.get("parentId") # Use .get() for safety
|
||||
current_message = messages.get(parent_id) if parent_id else None
|
||||
|
||||
return message_list
|
||||
@@ -70,12 +74,12 @@ def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
|
||||
|
||||
|
||||
def get_content_from_message(message: dict) -> Optional[str]:
|
||||
if isinstance(message["content"], list):
|
||||
if isinstance(message.get("content"), list):
|
||||
for item in message["content"]:
|
||||
if item["type"] == "text":
|
||||
return item["text"]
|
||||
else:
|
||||
return message["content"]
|
||||
return message.get("content")
|
||||
return None
|
||||
|
||||
|
||||
@@ -130,7 +134,9 @@ def prepend_to_first_user_message_content(
|
||||
return messages
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
def add_or_update_system_message(
|
||||
content: str, messages: list[dict], append: bool = False
|
||||
):
|
||||
"""
|
||||
Adds a new system message at the beginning of the messages list
|
||||
or updates the existing system message at the beginning.
|
||||
@@ -141,7 +147,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
"""
|
||||
|
||||
if messages and messages[0].get("role") == "system":
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
if append:
|
||||
messages[0]["content"] = f"{messages[0]['content']}\n{content}"
|
||||
else:
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
else:
|
||||
# Insert at the beginning
|
||||
messages.insert(0, {"role": "system", "content": content})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import logging
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from aiocache import cached
|
||||
@@ -13,7 +14,10 @@ from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
||||
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
@@ -30,35 +34,46 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def fetch_ollama_models(request: Request, user: UserModel = None):
|
||||
raw_ollama_models = await ollama.get_all_models(request, user=user)
|
||||
return [
|
||||
{
|
||||
"id": model["model"],
|
||||
"name": model["name"],
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "ollama",
|
||||
"ollama": model,
|
||||
"connection_type": model.get("connection_type", "local"),
|
||||
"tags": model.get("tags", []),
|
||||
}
|
||||
for model in raw_ollama_models["models"]
|
||||
]
|
||||
|
||||
|
||||
async def fetch_openai_models(request: Request, user: UserModel = None):
|
||||
openai_response = await openai.get_all_models(request, user=user)
|
||||
return openai_response["data"]
|
||||
|
||||
|
||||
async def get_all_base_models(request: Request, user: UserModel = None):
|
||||
function_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
openai_task = (
|
||||
fetch_openai_models(request, user)
|
||||
if request.app.state.config.ENABLE_OPENAI_API
|
||||
else asyncio.sleep(0, result=[])
|
||||
)
|
||||
ollama_task = (
|
||||
fetch_ollama_models(request, user)
|
||||
if request.app.state.config.ENABLE_OLLAMA_API
|
||||
else asyncio.sleep(0, result=[])
|
||||
)
|
||||
function_task = get_function_models(request)
|
||||
|
||||
if request.app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await openai.get_all_models(request, user=user)
|
||||
openai_models = openai_models["data"]
|
||||
openai_models, ollama_models, function_models = await asyncio.gather(
|
||||
openai_task, ollama_task, function_task
|
||||
)
|
||||
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await ollama.get_all_models(request, user=user)
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
"name": model["name"],
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "ollama",
|
||||
"ollama": model,
|
||||
"connection_type": model.get("connection_type", "local"),
|
||||
"tags": model.get("tags", []),
|
||||
}
|
||||
for model in ollama_models["models"]
|
||||
]
|
||||
|
||||
function_models = await get_function_models(request)
|
||||
models = function_models + openai_models + ollama_models
|
||||
|
||||
return models
|
||||
return function_models + openai_models + ollama_models
|
||||
|
||||
|
||||
async def get_all_models(request, user: UserModel = None):
|
||||
@@ -239,8 +254,7 @@ async def get_all_models(request, user: UserModel = None):
|
||||
]
|
||||
|
||||
def get_function_module_by_id(function_id):
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
function_module, _, _ = get_function_module_from_cache(request, function_id)
|
||||
return function_module
|
||||
|
||||
for model in models:
|
||||
|
||||
@@ -536,5 +536,10 @@ class OAuthManager:
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
|
||||
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
|
||||
if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from open_webui.utils.task import prompt_template, prompt_variables_template
|
||||
from open_webui.utils.misc import (
|
||||
deep_update,
|
||||
add_or_update_system_message,
|
||||
)
|
||||
|
||||
@@ -9,9 +10,8 @@ import json
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(
|
||||
params: dict, form_data: dict, metadata: Optional[dict] = None, user=None
|
||||
system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None
|
||||
) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
@@ -45,18 +45,64 @@ def apply_model_params_to_body(
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
for key, value in params.items():
|
||||
if value is not None:
|
||||
if key in mappings:
|
||||
cast_func = mappings[key]
|
||||
if isinstance(cast_func, Callable):
|
||||
form_data[key] = cast_func(value)
|
||||
else:
|
||||
form_data[key] = value
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
def remove_open_webui_params(params: dict) -> dict:
|
||||
"""
|
||||
Removes OpenWebUI specific parameters from the provided dictionary.
|
||||
|
||||
Args:
|
||||
params (dict): The dictionary containing parameters.
|
||||
|
||||
Returns:
|
||||
dict: The modified dictionary with OpenWebUI parameters removed.
|
||||
"""
|
||||
open_webui_params = {
|
||||
"stream_response": bool,
|
||||
"function_calling": str,
|
||||
"system": str,
|
||||
}
|
||||
|
||||
for key in list(params.keys()):
|
||||
if key in open_webui_params:
|
||||
del params[key]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
params = remove_open_webui_params(params)
|
||||
|
||||
custom_params = params.pop("custom_params", {})
|
||||
if custom_params:
|
||||
# Attempt to parse custom_params if they are strings
|
||||
for key, value in custom_params.items():
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
# Attempt to parse the string as JSON
|
||||
custom_params[key] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
# If it fails, keep the original string
|
||||
pass
|
||||
|
||||
# If there are custom parameters, we need to apply them first
|
||||
params = deep_update(params, custom_params)
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": float,
|
||||
"min_p": float,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": float,
|
||||
"presence_penalty": float,
|
||||
@@ -70,6 +116,23 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
|
||||
|
||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
params = remove_open_webui_params(params)
|
||||
|
||||
custom_params = params.pop("custom_params", {})
|
||||
if custom_params:
|
||||
# Attempt to parse custom_params if they are strings
|
||||
for key, value in custom_params.items():
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
# Attempt to parse the string as JSON
|
||||
custom_params[key] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
# If it fails, keep the original string
|
||||
pass
|
||||
|
||||
# If there are custom parameters, we need to apply them first
|
||||
params = deep_update(params, custom_params)
|
||||
|
||||
# Convert OpenAI parameter names to Ollama parameter names if needed.
|
||||
name_differences = {
|
||||
"max_tokens": "num_predict",
|
||||
@@ -266,3 +329,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
ollama_payload["format"] = format
|
||||
|
||||
return ollama_payload
|
||||
|
||||
|
||||
def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
"""
|
||||
Convert an embeddings request payload from OpenAI format to Ollama format.
|
||||
|
||||
Args:
|
||||
openai_payload (dict): The original payload designed for OpenAI API usage.
|
||||
|
||||
Returns:
|
||||
dict: A payload compatible with the Ollama API embeddings endpoint.
|
||||
"""
|
||||
ollama_payload = {"model": openai_payload.get("model")}
|
||||
input_value = openai_payload.get("input")
|
||||
|
||||
# Ollama expects 'input' as a list, and 'prompt' as a single string.
|
||||
if isinstance(input_value, list):
|
||||
ollama_payload["input"] = input_value
|
||||
ollama_payload["prompt"] = "\n".join(str(x) for x in input_value)
|
||||
else:
|
||||
ollama_payload["input"] = [input_value]
|
||||
ollama_payload["prompt"] = str(input_value)
|
||||
|
||||
# Optionally forward other fields if present
|
||||
for optional_key in ("options", "truncate", "keep_alive"):
|
||||
if optional_key in openai_payload:
|
||||
ollama_payload[optional_key] = openai_payload[optional_key]
|
||||
|
||||
return ollama_payload
|
||||
|
||||
@@ -115,7 +115,7 @@ def load_tool_module_by_id(tool_id, content=None):
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
|
||||
def load_function_module_by_id(function_id, content=None):
|
||||
def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||
if content is None:
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if not function:
|
||||
@@ -166,6 +166,62 @@ def load_function_module_by_id(function_id, content=None):
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
|
||||
def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||
if load_from_db:
|
||||
# Always load from the database by default
|
||||
# This is useful for hooks like "inlet" or "outlet" where the content might change
|
||||
# and we want to ensure the latest content is used.
|
||||
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if not function:
|
||||
raise Exception(f"Function not found: {function_id}")
|
||||
content = function.content
|
||||
|
||||
new_content = replace_imports(content)
|
||||
if new_content != content:
|
||||
content = new_content
|
||||
# Update the function content in the database
|
||||
Functions.update_function_by_id(function_id, {"content": content})
|
||||
|
||||
if (
|
||||
hasattr(request.app.state, "FUNCTION_CONTENTS")
|
||||
and function_id in request.app.state.FUNCTION_CONTENTS
|
||||
) and (
|
||||
hasattr(request.app.state, "FUNCTIONS")
|
||||
and function_id in request.app.state.FUNCTIONS
|
||||
):
|
||||
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
|
||||
return request.app.state.FUNCTIONS[function_id], None, None
|
||||
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function_id, content
|
||||
)
|
||||
else:
|
||||
# Load from cache (e.g. "stream" hook)
|
||||
# This is useful for performance reasons
|
||||
|
||||
if (
|
||||
hasattr(request.app.state, "FUNCTIONS")
|
||||
and function_id in request.app.state.FUNCTIONS
|
||||
):
|
||||
return request.app.state.FUNCTIONS[function_id], None, None
|
||||
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function_id
|
||||
)
|
||||
|
||||
if not hasattr(request.app.state, "FUNCTIONS"):
|
||||
request.app.state.FUNCTIONS = {}
|
||||
|
||||
if not hasattr(request.app.state, "FUNCTION_CONTENTS"):
|
||||
request.app.state.FUNCTION_CONTENTS = {}
|
||||
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
request.app.state.FUNCTION_CONTENTS[function_id] = content
|
||||
|
||||
return function_module, function_type, frontmatter
|
||||
|
||||
|
||||
def install_frontmatter_requirements(requirements: str):
|
||||
if requirements:
|
||||
try:
|
||||
|
||||
@@ -125,3 +125,64 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
yield line
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def convert_embedding_response_ollama_to_openai(response) -> dict:
|
||||
"""
|
||||
Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.
|
||||
|
||||
Args:
|
||||
response (dict): The response from the Ollama API,
|
||||
e.g. {"embedding": [...], "model": "..."}
|
||||
or {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}
|
||||
|
||||
Returns:
|
||||
dict: Response adapted to OpenAI's embeddings API format.
|
||||
e.g. {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{"object": "embedding", "embedding": [...], "index": 0},
|
||||
...
|
||||
],
|
||||
"model": "...",
|
||||
}
|
||||
"""
|
||||
# Ollama batch-style output
|
||||
if isinstance(response, dict) and "embeddings" in response:
|
||||
openai_data = []
|
||||
for i, emb in enumerate(response["embeddings"]):
|
||||
openai_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": emb.get("embedding"),
|
||||
"index": emb.get("index", i),
|
||||
}
|
||||
)
|
||||
return {
|
||||
"object": "list",
|
||||
"data": openai_data,
|
||||
"model": response.get("model"),
|
||||
}
|
||||
# Ollama single output
|
||||
elif isinstance(response, dict) and "embedding" in response:
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": response["embedding"],
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"model": response.get("model"),
|
||||
}
|
||||
# Already OpenAI-compatible?
|
||||
elif (
|
||||
isinstance(response, dict)
|
||||
and "data" in response
|
||||
and isinstance(response["data"], list)
|
||||
):
|
||||
return response
|
||||
|
||||
# Fallback: return as is if unrecognized
|
||||
return response
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_task_model_id(
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if models[task_model_id].get("connection_type") == "local":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
@@ -207,6 +207,24 @@ def title_generation_template(
|
||||
return template
|
||||
|
||||
|
||||
def follow_up_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
prompt = get_last_user_message(messages)
|
||||
template = replace_prompt_variable(template, prompt)
|
||||
template = replace_messages_variable(template, messages)
|
||||
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
@@ -160,7 +160,7 @@ def get_tools(
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
if val.get("type") == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal reserved parameters (e.g. __id__, __user__)
|
||||
@@ -490,8 +490,19 @@ async def get_tool_servers_data(
|
||||
server_entries = []
|
||||
for idx, server in enumerate(servers):
|
||||
if server.get("config", {}).get("enable"):
|
||||
url_path = server.get("path", "openapi.json")
|
||||
full_url = f"{server.get('url')}/{url_path}"
|
||||
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
|
||||
openapi_path = server.get("path", "openapi.json")
|
||||
if "://" in openapi_path:
|
||||
# If it contains "://", it's a full URL
|
||||
full_url = openapi_path
|
||||
else:
|
||||
if not openapi_path.startswith("/"):
|
||||
# Ensure the path starts with a slash
|
||||
openapi_path = f"/{openapi_path}"
|
||||
|
||||
full_url = f"{server.get('url')}{openapi_path}"
|
||||
|
||||
info = server.get("info", {})
|
||||
|
||||
auth_type = server.get("auth_type", "bearer")
|
||||
token = None
|
||||
@@ -500,26 +511,37 @@ async def get_tool_servers_data(
|
||||
token = server.get("key", "")
|
||||
elif auth_type == "session":
|
||||
token = session_token
|
||||
server_entries.append((idx, server, full_url, token))
|
||||
server_entries.append((idx, server, full_url, info, token))
|
||||
|
||||
# Create async tasks to fetch data
|
||||
tasks = [get_tool_server_data(token, url) for (_, _, url, token) in server_entries]
|
||||
tasks = [
|
||||
get_tool_server_data(token, url) for (_, _, url, _, token) in server_entries
|
||||
]
|
||||
|
||||
# Execute tasks concurrently
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Build final results with index and server metadata
|
||||
results = []
|
||||
for (idx, server, url, _), response in zip(server_entries, responses):
|
||||
for (idx, server, url, info, _), response in zip(server_entries, responses):
|
||||
if isinstance(response, Exception):
|
||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||
continue
|
||||
|
||||
openapi_data = response.get("openapi", {})
|
||||
|
||||
if info and isinstance(openapi_data, dict):
|
||||
if "name" in info:
|
||||
openapi_data["info"]["title"] = info.get("name", "Tool Server")
|
||||
|
||||
if "description" in info:
|
||||
openapi_data["info"]["description"] = info.get("description", "")
|
||||
|
||||
results.append(
|
||||
{
|
||||
"idx": idx,
|
||||
"url": server.get("url"),
|
||||
"openapi": response.get("openapi"),
|
||||
"openapi": openapi_data,
|
||||
"info": response.get("info"),
|
||||
"specs": response.get("specs"),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user