mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'dev' into pyodide-files
This commit is contained in:
@@ -989,6 +989,26 @@ DEFAULT_USER_ROLE = PersistentConfig(
|
||||
os.getenv("DEFAULT_USER_ROLE", "pending"),
|
||||
)
|
||||
|
||||
PENDING_USER_OVERLAY_TITLE = PersistentConfig(
|
||||
"PENDING_USER_OVERLAY_TITLE",
|
||||
"ui.pending_user_overlay_title",
|
||||
os.environ.get("PENDING_USER_OVERLAY_TITLE", ""),
|
||||
)
|
||||
|
||||
PENDING_USER_OVERLAY_CONTENT = PersistentConfig(
|
||||
"PENDING_USER_OVERLAY_CONTENT",
|
||||
"ui.pending_user_overlay_content",
|
||||
os.environ.get("PENDING_USER_OVERLAY_CONTENT", ""),
|
||||
)
|
||||
|
||||
|
||||
RESPONSE_WATERMARK = PersistentConfig(
|
||||
"RESPONSE_WATERMARK",
|
||||
"ui.watermark",
|
||||
os.environ.get("RESPONSE_WATERMARK", ""),
|
||||
)
|
||||
|
||||
|
||||
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
|
||||
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
|
||||
== "true"
|
||||
@@ -1731,6 +1751,9 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
|
||||
QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true"
|
||||
QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true"
|
||||
QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334"))
|
||||
ENABLE_QDRANT_MULTITENANCY_MODE = (
|
||||
os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
|
||||
@@ -1825,6 +1848,18 @@ CONTENT_EXTRACTION_ENGINE = PersistentConfig(
|
||||
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
|
||||
)
|
||||
|
||||
EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig(
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL",
|
||||
"rag.external_document_loader_url",
|
||||
os.environ.get("EXTERNAL_DOCUMENT_LOADER_URL", ""),
|
||||
)
|
||||
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY = PersistentConfig(
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY",
|
||||
"rag.external_document_loader_api_key",
|
||||
os.environ.get("EXTERNAL_DOCUMENT_LOADER_API_KEY", ""),
|
||||
)
|
||||
|
||||
TIKA_SERVER_URL = PersistentConfig(
|
||||
"TIKA_SERVER_URL",
|
||||
"rag.tika_server_url",
|
||||
@@ -1849,6 +1884,12 @@ DOCLING_OCR_LANG = PersistentConfig(
|
||||
os.getenv("DOCLING_OCR_LANG", "eng,fra,deu,spa"),
|
||||
)
|
||||
|
||||
DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig(
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION",
|
||||
"rag.docling_do_picture_description",
|
||||
os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true",
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||
"rag.document_intelligence_endpoint",
|
||||
@@ -1920,6 +1961,16 @@ RAG_FILE_MAX_SIZE = PersistentConfig(
|
||||
),
|
||||
)
|
||||
|
||||
RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig(
|
||||
"RAG_ALLOWED_FILE_EXTENSIONS",
|
||||
"rag.file.allowed_extensions",
|
||||
[
|
||||
ext.strip()
|
||||
for ext in os.environ.get("RAG_ALLOWED_FILE_EXTENSIONS", "").split(",")
|
||||
if ext.strip()
|
||||
],
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_ENGINE = PersistentConfig(
|
||||
"RAG_EMBEDDING_ENGINE",
|
||||
"rag.embedding_engine",
|
||||
@@ -2126,6 +2177,12 @@ BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
|
||||
)
|
||||
|
||||
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig(
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER",
|
||||
"rag.web.search.bypass_web_loader",
|
||||
os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true",
|
||||
)
|
||||
|
||||
WEB_SEARCH_RESULT_COUNT = PersistentConfig(
|
||||
"WEB_SEARCH_RESULT_COUNT",
|
||||
"rag.web.search.result_count",
|
||||
@@ -2151,6 +2208,7 @@ WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
|
||||
int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
|
||||
)
|
||||
|
||||
|
||||
WEB_LOADER_ENGINE = PersistentConfig(
|
||||
"WEB_LOADER_ENGINE",
|
||||
"rag.web.loader.engine",
|
||||
@@ -2839,6 +2897,12 @@ LDAP_CA_CERT_FILE = PersistentConfig(
|
||||
os.environ.get("LDAP_CA_CERT_FILE", ""),
|
||||
)
|
||||
|
||||
LDAP_VALIDATE_CERT = PersistentConfig(
|
||||
"LDAP_VALIDATE_CERT",
|
||||
"ldap.server.validate_cert",
|
||||
os.environ.get("LDAP_VALIDATE_CERT", "True").lower() == "true",
|
||||
)
|
||||
|
||||
LDAP_CIPHERS = PersistentConfig(
|
||||
"LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL")
|
||||
)
|
||||
|
||||
@@ -54,11 +54,8 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
# Check if function is already loaded
|
||||
if pipe_id not in request.app.state.FUNCTIONS:
|
||||
function_module, _, _ = load_function_module_by_id(pipe_id)
|
||||
request.app.state.FUNCTIONS[pipe_id] = function_module
|
||||
else:
|
||||
function_module = request.app.state.FUNCTIONS[pipe_id]
|
||||
function_module, _, _ = load_function_module_by_id(pipe_id)
|
||||
request.app.state.FUNCTIONS[pipe_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
|
||||
@@ -43,7 +43,7 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||
|
||||
|
||||
def register_connection(db_url):
|
||||
db = connect(db_url, unquote_password=True)
|
||||
db = connect(db_url, unquote_user=True, unquote_password=True)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
@@ -51,7 +51,7 @@ def register_connection(db_url):
|
||||
log.info("Connected to PostgreSQL database")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url, unquote_password=True)
|
||||
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(**connection)
|
||||
|
||||
@@ -197,6 +197,7 @@ from open_webui.config import (
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_BATCH_SIZE,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_ALLOWED_FILE_EXTENSIONS,
|
||||
RAG_FILE_MAX_COUNT,
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
@@ -206,10 +207,13 @@ from open_webui.config import (
|
||||
CHUNK_OVERLAP,
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL,
|
||||
DOCLING_OCR_ENGINE,
|
||||
DOCLING_OCR_LANG,
|
||||
DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
MISTRAL_OCR_API_KEY,
|
||||
@@ -224,6 +228,7 @@ from open_webui.config import (
|
||||
ENABLE_WEB_SEARCH,
|
||||
WEB_SEARCH_ENGINE,
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
WEB_SEARCH_RESULT_COUNT,
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
WEB_SEARCH_TRUST_ENV,
|
||||
@@ -291,6 +296,8 @@ from open_webui.config import (
|
||||
ENABLE_EVALUATION_ARENA_MODELS,
|
||||
USER_PERMISSIONS,
|
||||
DEFAULT_USER_ROLE,
|
||||
PENDING_USER_OVERLAY_CONTENT,
|
||||
PENDING_USER_OVERLAY_TITLE,
|
||||
DEFAULT_PROMPT_SUGGESTIONS,
|
||||
DEFAULT_MODELS,
|
||||
DEFAULT_ARENA_MODEL,
|
||||
@@ -317,6 +324,7 @@ from open_webui.config import (
|
||||
LDAP_APP_PASSWORD,
|
||||
LDAP_USE_TLS,
|
||||
LDAP_CA_CERT_FILE,
|
||||
LDAP_VALIDATE_CERT,
|
||||
LDAP_CIPHERS,
|
||||
# Misc
|
||||
ENV,
|
||||
@@ -327,6 +335,7 @@ from open_webui.config import (
|
||||
DEFAULT_LOCALE,
|
||||
OAUTH_PROVIDERS,
|
||||
WEBUI_URL,
|
||||
RESPONSE_WATERMARK,
|
||||
# Admin
|
||||
ENABLE_ADMIN_CHAT_ACCESS,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
@@ -373,6 +382,7 @@ from open_webui.env import (
|
||||
OFFLINE_MODE,
|
||||
ENABLE_OTEL,
|
||||
EXTERNAL_PWA_MANIFEST_URL,
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
|
||||
|
||||
@@ -573,6 +583,11 @@ app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
||||
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
|
||||
app.state.config.PENDING_USER_OVERLAY_CONTENT = PENDING_USER_OVERLAY_CONTENT
|
||||
app.state.config.PENDING_USER_OVERLAY_TITLE = PENDING_USER_OVERLAY_TITLE
|
||||
|
||||
app.state.config.RESPONSE_WATERMARK = RESPONSE_WATERMARK
|
||||
|
||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
app.state.config.BANNERS = WEBUI_BANNERS
|
||||
@@ -609,6 +624,7 @@ app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
|
||||
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
|
||||
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
|
||||
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
|
||||
app.state.config.LDAP_VALIDATE_CERT = LDAP_VALIDATE_CERT
|
||||
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
|
||||
|
||||
|
||||
@@ -631,6 +647,7 @@ app.state.FUNCTIONS = {}
|
||||
app.state.config.TOP_K = RAG_TOP_K
|
||||
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
|
||||
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS
|
||||
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
||||
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
|
||||
@@ -641,10 +658,13 @@ app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL
|
||||
app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL
|
||||
app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE
|
||||
app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG
|
||||
app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||
app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY
|
||||
@@ -688,6 +708,7 @@ app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV
|
||||
app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER
|
||||
|
||||
app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION
|
||||
app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION
|
||||
@@ -1167,11 +1188,12 @@ async def chat_completion(
|
||||
"chat_id": form_data.pop("chat_id", None),
|
||||
"message_id": form_data.pop("id", None),
|
||||
"session_id": form_data.pop("session_id", None),
|
||||
"filter_ids": form_data.pop("filter_ids", []),
|
||||
"tool_ids": form_data.get("tool_ids", None),
|
||||
"tool_servers": form_data.pop("tool_servers", None),
|
||||
"files": form_data.get("files", None),
|
||||
"features": form_data.get("features", None),
|
||||
"variables": form_data.get("variables", None),
|
||||
"features": form_data.get("features", {}),
|
||||
"variables": form_data.get("variables", {}),
|
||||
"model": model,
|
||||
"direct": model_item.get("direct", False),
|
||||
**(
|
||||
@@ -1395,6 +1417,11 @@ async def get_app_config(request: Request):
|
||||
"sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value,
|
||||
"sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value,
|
||||
},
|
||||
"ui": {
|
||||
"pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
"response_watermark": app.state.config.RESPONSE_WATERMARK,
|
||||
},
|
||||
"license_metadata": app.state.LICENSE_METADATA,
|
||||
**(
|
||||
{
|
||||
@@ -1446,7 +1473,8 @@ async def get_app_latest_release_version(user=Depends(get_verified_user)):
|
||||
timeout = aiohttp.ClientTimeout(total=1)
|
||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
||||
"https://api.github.com/repos/open-webui/open-webui/releases/latest",
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
@@ -129,12 +129,16 @@ class AuthsTable:
|
||||
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_db() as db:
|
||||
auth = db.query(Auth).filter_by(email=email, active=True).first()
|
||||
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
|
||||
if auth:
|
||||
if verify_password(password, auth.password):
|
||||
user = Users.get_user_by_id(auth.id)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
58
backend/open_webui/retrieval/loaders/external_document.py
Normal file
58
backend/open_webui/retrieval/loaders/external_document.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import requests
|
||||
import logging
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ExternalDocumentLoader(BaseLoader):
|
||||
def __init__(
|
||||
self,
|
||||
file_path,
|
||||
url: str,
|
||||
api_key: str,
|
||||
mime_type=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
|
||||
self.file_path = file_path
|
||||
self.mime_type = mime_type
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
data = f.read()
|
||||
|
||||
headers = {}
|
||||
if self.mime_type is not None:
|
||||
headers["Content-Type"] = self.mime_type
|
||||
|
||||
if self.api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
url = self.url
|
||||
if url.endswith("/"):
|
||||
url = url[:-1]
|
||||
|
||||
r = requests.put(f"{url}/process", data=data, headers=headers)
|
||||
|
||||
if r.ok:
|
||||
res = r.json()
|
||||
|
||||
if res:
|
||||
return [
|
||||
Document(
|
||||
page_content=res.get("page_content"),
|
||||
metadata=res.get("metadata"),
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise Exception("Error loading document: No content returned")
|
||||
else:
|
||||
raise Exception(f"Error loading document: {r.status_code} {r.text}")
|
||||
@@ -10,7 +10,7 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ExternalLoader(BaseLoader):
|
||||
class ExternalWebLoader(BaseLoader):
|
||||
def __init__(
|
||||
self,
|
||||
web_paths: Union[str, List[str]],
|
||||
@@ -32,7 +32,7 @@ class ExternalLoader(BaseLoader):
|
||||
response = requests.post(
|
||||
self.external_url,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
|
||||
"Authorization": f"Bearer {self.external_api_key}",
|
||||
},
|
||||
json={
|
||||
@@ -21,6 +21,8 @@ from langchain_community.document_loaders import (
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
|
||||
from open_webui.retrieval.loaders.mistral import MistralLoader
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
@@ -126,14 +128,12 @@ class TikaLoader:
|
||||
|
||||
|
||||
class DoclingLoader:
|
||||
def __init__(
|
||||
self, url, file_path=None, mime_type=None, ocr_engine=None, ocr_lang=None
|
||||
):
|
||||
def __init__(self, url, file_path=None, mime_type=None, params=None):
|
||||
self.url = url.rstrip("/")
|
||||
self.file_path = file_path
|
||||
self.mime_type = mime_type
|
||||
self.ocr_engine = ocr_engine
|
||||
self.ocr_lang = ocr_lang
|
||||
|
||||
self.params = params or {}
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
@@ -150,11 +150,19 @@ class DoclingLoader:
|
||||
"table_mode": "accurate",
|
||||
}
|
||||
|
||||
if self.ocr_engine and self.ocr_lang:
|
||||
params["ocr_engine"] = self.ocr_engine
|
||||
params["ocr_lang"] = [
|
||||
lang.strip() for lang in self.ocr_lang.split(",") if lang.strip()
|
||||
]
|
||||
if self.params:
|
||||
if self.params.get("do_picture_classification"):
|
||||
params["do_picture_classification"] = self.params.get(
|
||||
"do_picture_classification"
|
||||
)
|
||||
|
||||
if self.params.get("ocr_engine") and self.params.get("ocr_lang"):
|
||||
params["ocr_engine"] = self.params.get("ocr_engine")
|
||||
params["ocr_lang"] = [
|
||||
lang.strip()
|
||||
for lang in self.params.get("ocr_lang").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
endpoint = f"{self.url}/v1alpha/convert/file"
|
||||
r = requests.post(endpoint, files=files, data=params)
|
||||
@@ -207,7 +215,18 @@ class Loader:
|
||||
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
||||
file_ext = filename.split(".")[-1].lower()
|
||||
|
||||
if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
||||
if (
|
||||
self.engine == "external"
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
|
||||
):
|
||||
loader = ExternalDocumentLoader(
|
||||
file_path=file_path,
|
||||
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
|
||||
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
@@ -225,8 +244,13 @@ class Loader:
|
||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
ocr_engine=self.kwargs.get("DOCLING_OCR_ENGINE"),
|
||||
ocr_lang=self.kwargs.get("DOCLING_OCR_LANG"),
|
||||
params={
|
||||
"ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"),
|
||||
"ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"),
|
||||
"do_picture_classification": self.kwargs.get(
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION"
|
||||
),
|
||||
},
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
@@ -258,6 +282,15 @@ class Loader:
|
||||
loader = MistralLoader(
|
||||
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
|
||||
)
|
||||
elif (
|
||||
self.engine == "external"
|
||||
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
|
||||
and file_ext
|
||||
in ["pdf"] # Mistral OCR currently only supports PDF and images
|
||||
):
|
||||
loader = MistralLoader(
|
||||
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), file_path=file_path
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
loader = PyPDFLoader(
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL
|
||||
@@ -14,18 +18,29 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
class MistralLoader:
|
||||
"""
|
||||
Enhanced Mistral OCR loader with both sync and async support.
|
||||
Loads documents by processing them through the Mistral OCR API.
|
||||
"""
|
||||
|
||||
BASE_API_URL = "https://api.mistral.ai/v1"
|
||||
|
||||
def __init__(self, api_key: str, file_path: str):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
file_path: str,
|
||||
timeout: int = 300, # 5 minutes default
|
||||
max_retries: int = 3,
|
||||
enable_debug_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the loader.
|
||||
Initializes the loader with enhanced features.
|
||||
|
||||
Args:
|
||||
api_key: Your Mistral API key.
|
||||
file_path: The local path to the PDF file to process.
|
||||
timeout: Request timeout in seconds.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
enable_debug_logging: Enable detailed debug logs.
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("API key cannot be empty.")
|
||||
@@ -34,7 +49,23 @@ class MistralLoader:
|
||||
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.debug = enable_debug_logging
|
||||
|
||||
# Pre-compute file info for performance
|
||||
self.file_name = os.path.basename(file_path)
|
||||
self.file_size = os.path.getsize(file_path)
|
||||
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": "OpenWebUI-MistralLoader/2.0",
|
||||
}
|
||||
|
||||
def _debug_log(self, message: str, *args) -> None:
|
||||
"""Conditional debug logging for performance."""
|
||||
if self.debug:
|
||||
log.debug(message, *args)
|
||||
|
||||
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
|
||||
"""Checks response status and returns JSON content."""
|
||||
@@ -54,24 +85,89 @@ class MistralLoader:
|
||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
||||
raise # Re-raise after logging
|
||||
|
||||
async def _handle_response_async(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> Dict[str, Any]:
|
||||
"""Async version of response handling with better error info."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
if response.status == 204:
|
||||
return {}
|
||||
text = await response.text()
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
||||
)
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
error_text = await response.text() if response else "No response"
|
||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"Client error: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing response: {e}")
|
||||
raise
|
||||
|
||||
def _retry_request_sync(self, request_func, *args, **kwargs):
|
||||
"""Synchronous retry logic with exponential backoff."""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return request_func(*args, **kwargs)
|
||||
except (requests.exceptions.RequestException, Exception) as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
raise
|
||||
|
||||
wait_time = (2**attempt) + 0.5
|
||||
log.warning(
|
||||
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
async def _retry_request_async(self, request_func, *args, **kwargs):
|
||||
"""Async retry logic with exponential backoff."""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
return await request_func(*args, **kwargs)
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
if attempt == self.max_retries - 1:
|
||||
raise
|
||||
|
||||
wait_time = (2**attempt) + 0.5
|
||||
log.warning(
|
||||
f"Request failed (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s..."
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
def _upload_file(self) -> str:
|
||||
"""Uploads the file to Mistral for OCR processing."""
|
||||
"""Uploads the file to Mistral for OCR processing (sync version)."""
|
||||
log.info("Uploading file to Mistral API")
|
||||
url = f"{self.BASE_API_URL}/files"
|
||||
file_name = os.path.basename(self.file_path)
|
||||
|
||||
try:
|
||||
def upload_request():
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (file_name, f, "application/pdf")}
|
||||
data = {"purpose": "ocr"}
|
||||
|
||||
upload_headers = self.headers.copy() # Avoid modifying self.headers
|
||||
|
||||
response = requests.post(
|
||||
url, headers=upload_headers, files=files, data=data
|
||||
url,
|
||||
headers=self.headers,
|
||||
files=files,
|
||||
data=data,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
response_data = self._handle_response(response)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(upload_request)
|
||||
file_id = response_data.get("id")
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
@@ -81,16 +177,66 @@ class MistralLoader:
|
||||
log.error(f"Failed to upload file: {e}")
|
||||
raise
|
||||
|
||||
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
||||
"""Async file upload with streaming for better memory efficiency."""
|
||||
url = f"{self.BASE_API_URL}/files"
|
||||
|
||||
async def upload_request():
|
||||
# Create multipart writer for streaming upload
|
||||
writer = aiohttp.MultipartWriter("form-data")
|
||||
|
||||
# Add purpose field
|
||||
purpose_part = writer.append("ocr")
|
||||
purpose_part.set_content_disposition("form-data", name="purpose")
|
||||
|
||||
# Add file part with streaming
|
||||
file_part = writer.append_payload(
|
||||
aiohttp.streams.FilePayload(
|
||||
self.file_path,
|
||||
filename=self.file_name,
|
||||
content_type="application/pdf",
|
||||
)
|
||||
)
|
||||
file_part.set_content_disposition(
|
||||
"form-data", name="file", filename=self.file_name
|
||||
)
|
||||
|
||||
self._debug_log(
|
||||
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
|
||||
)
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
data=writer,
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
response_data = await self._retry_request_async(upload_request)
|
||||
|
||||
file_id = response_data.get("id")
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
return file_id
|
||||
|
||||
def _get_signed_url(self, file_id: str) -> str:
|
||||
"""Retrieves a temporary signed URL for the uploaded file."""
|
||||
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
||||
log.info(f"Getting signed URL for file ID: {file_id}")
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
signed_url_headers = {**self.headers, "Accept": "application/json"}
|
||||
|
||||
def url_request():
|
||||
response = requests.get(
|
||||
url, headers=signed_url_headers, params=params, timeout=self.timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=signed_url_headers, params=params)
|
||||
response_data = self._handle_response(response)
|
||||
response_data = self._retry_request_sync(url_request)
|
||||
signed_url = response_data.get("url")
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
@@ -100,8 +246,36 @@ class MistralLoader:
|
||||
log.error(f"Failed to get signed URL: {e}")
|
||||
raise
|
||||
|
||||
async def _get_signed_url_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> str:
|
||||
"""Async signed URL retrieval."""
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
|
||||
headers = {**self.headers, "Accept": "application/json"}
|
||||
|
||||
async def url_request():
|
||||
self._debug_log(f"Getting signed URL for file ID: {file_id}")
|
||||
async with session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
response_data = await self._retry_request_async(url_request)
|
||||
|
||||
signed_url = response_data.get("url")
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
|
||||
self._debug_log("Signed URL received successfully")
|
||||
return signed_url
|
||||
|
||||
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
||||
"""Sends the signed URL to the OCR endpoint for processing."""
|
||||
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
||||
log.info("Processing OCR via Mistral API")
|
||||
url = f"{self.BASE_API_URL}/ocr"
|
||||
ocr_headers = {
|
||||
@@ -118,43 +292,198 @@ class MistralLoader:
|
||||
"include_image_base64": False,
|
||||
}
|
||||
|
||||
def ocr_request():
|
||||
response = requests.post(
|
||||
url, headers=ocr_headers, json=payload, timeout=self.timeout
|
||||
)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=ocr_headers, json=payload)
|
||||
ocr_response = self._handle_response(response)
|
||||
ocr_response = self._retry_request_sync(ocr_request)
|
||||
log.info("OCR processing done.")
|
||||
log.debug("OCR response: %s", ocr_response)
|
||||
self._debug_log("OCR response: %s", ocr_response)
|
||||
return ocr_response
|
||||
except Exception as e:
|
||||
log.error(f"Failed during OCR processing: {e}")
|
||||
raise
|
||||
|
||||
async def _process_ocr_async(
|
||||
self, session: aiohttp.ClientSession, signed_url: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Async OCR processing with timing metrics."""
|
||||
url = f"{self.BASE_API_URL}/ocr"
|
||||
|
||||
headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
}
|
||||
|
||||
async def ocr_request():
|
||||
log.info("Starting OCR processing via Mistral API")
|
||||
start_time = time.time()
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
ocr_response = await self._handle_response_async(response)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
log.info(f"OCR processing completed in {processing_time:.2f}s")
|
||||
|
||||
return ocr_response
|
||||
|
||||
return await self._retry_request_async(ocr_request)
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Deletes the file from Mistral storage."""
|
||||
"""Deletes the file from Mistral storage (sync version)."""
|
||||
log.info(f"Deleting uploaded file ID: {file_id}")
|
||||
url = f"{self.BASE_API_URL}/files/{file_id}"
|
||||
# No specific Accept header needed, default or Authorization is usually sufficient
|
||||
|
||||
try:
|
||||
response = requests.delete(url, headers=self.headers)
|
||||
delete_response = self._handle_response(
|
||||
response
|
||||
) # Check status, ignore response body unless needed
|
||||
log.info(
|
||||
f"File deleted successfully: {delete_response}"
|
||||
) # Log the response if available
|
||||
response = requests.delete(url, headers=self.headers, timeout=30)
|
||||
delete_response = self._handle_response(response)
|
||||
log.info(f"File deleted successfully: {delete_response}")
|
||||
except Exception as e:
|
||||
# Log error but don't necessarily halt execution if deletion fails
|
||||
log.error(f"Failed to delete file ID {file_id}: {e}")
|
||||
# Depending on requirements, you might choose to raise the error here
|
||||
|
||||
async def _delete_file_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> None:
|
||||
"""Async file deletion with error tolerance."""
|
||||
try:
|
||||
|
||||
async def delete_request():
|
||||
self._debug_log(f"Deleting file ID: {file_id}")
|
||||
async with session.delete(
|
||||
url=f"{self.BASE_API_URL}/files/{file_id}",
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=30
|
||||
), # Shorter timeout for cleanup
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
await self._retry_request_async(delete_request)
|
||||
self._debug_log(f"File {file_id} deleted successfully")
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail the entire process if cleanup fails
|
||||
log.warning(f"Failed to delete file ID {file_id}: {e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session(self):
|
||||
"""Context manager for HTTP session with optimized settings."""
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=10, # Total connection limit
|
||||
limit_per_host=5, # Per-host connection limit
|
||||
ttl_dns_cache=300, # DNS cache TTL
|
||||
use_dns_cache=True,
|
||||
keepalive_timeout=30,
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
||||
) as session:
|
||||
yield session
|
||||
|
||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||
"""Process OCR results into Document objects with enhanced metadata."""
|
||||
pages_data = ocr_response.get("pages")
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found", metadata={"error": "no_pages"}
|
||||
)
|
||||
]
|
||||
|
||||
documents = []
|
||||
total_pages = len(pages_data)
|
||||
skipped_pages = 0
|
||||
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
|
||||
if page_content is not None and page_index is not None:
|
||||
# Clean up content efficiently
|
||||
cleaned_content = (
|
||||
page_content.strip()
|
||||
if isinstance(page_content, str)
|
||||
else str(page_content)
|
||||
)
|
||||
|
||||
if cleaned_content: # Only add non-empty pages
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index
|
||||
+ 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
else:
|
||||
skipped_pages += 1
|
||||
self._debug_log(
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
|
||||
)
|
||||
|
||||
if skipped_pages > 0:
|
||||
log.info(
|
||||
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
|
||||
)
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
page_content="No valid text content found in document",
|
||||
metadata={"error": "no_valid_pages", "total_pages": total_pages},
|
||||
)
|
||||
]
|
||||
|
||||
return documents
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
Executes the full OCR workflow: upload, get URL, process OCR, delete file.
|
||||
Synchronous version for backward compatibility.
|
||||
|
||||
Returns:
|
||||
A list of Document objects, one for each page processed.
|
||||
"""
|
||||
file_id = None
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. Upload file
|
||||
file_id = self._upload_file()
|
||||
@@ -166,53 +495,30 @@ class MistralLoader:
|
||||
ocr_response = self._process_ocr(signed_url)
|
||||
|
||||
# 4. Process results
|
||||
pages_data = ocr_response.get("pages")
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
return [Document(page_content="No text content found", metadata={})]
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
documents = []
|
||||
total_pages = len(pages_data)
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
|
||||
if page_content is not None and page_index is not None:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=page_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index
|
||||
+ 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
# Add other relevant metadata from page_data if available/needed
|
||||
# e.g., page_data.get('width'), page_data.get('height')
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}"
|
||||
)
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found in valid pages", metadata={}
|
||||
)
|
||||
]
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"An error occurred during the loading process: {e}")
|
||||
# Return an empty list or a specific error document on failure
|
||||
return [Document(page_content=f"Error during processing: {e}", metadata={})]
|
||||
total_time = time.time() - start_time
|
||||
log.error(
|
||||
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
|
||||
)
|
||||
# Return an error document on failure
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during processing: {e}",
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
finally:
|
||||
# 5. Delete file (attempt even if prior steps failed after upload)
|
||||
if file_id:
|
||||
@@ -223,3 +529,105 @@ class MistralLoader:
|
||||
log.error(
|
||||
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
|
||||
)
|
||||
|
||||
async def load_async(self) -> List[Document]:
|
||||
"""
|
||||
Asynchronous OCR workflow execution with optimized performance.
|
||||
|
||||
Returns:
|
||||
A list of Document objects, one for each page processed.
|
||||
"""
|
||||
file_id = None
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
async with self._get_session() as session:
|
||||
# 1. Upload file with streaming
|
||||
file_id = await self._upload_file_async(session)
|
||||
|
||||
# 2. Get signed URL
|
||||
signed_url = await self._get_signed_url_async(session, file_id)
|
||||
|
||||
# 3. Process OCR
|
||||
ocr_response = await self._process_ocr_async(session, signed_url)
|
||||
|
||||
# 4. Process results
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during OCR processing: {e}",
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
finally:
|
||||
# 5. Cleanup - always attempt file deletion
|
||||
if file_id:
|
||||
try:
|
||||
async with self._get_session() as session:
|
||||
await self._delete_file_async(session, file_id)
|
||||
except Exception as cleanup_error:
|
||||
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
|
||||
|
||||
@staticmethod
|
||||
async def load_multiple_async(
|
||||
loaders: List["MistralLoader"],
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
Process multiple files concurrently for maximum performance.
|
||||
|
||||
Args:
|
||||
loaders: List of MistralLoader instances
|
||||
|
||||
Returns:
|
||||
List of document lists, one for each loader
|
||||
"""
|
||||
if not loaders:
|
||||
return []
|
||||
|
||||
log.info(f"Starting concurrent processing of {len(loaders)} files")
|
||||
start_time = time.time()
|
||||
|
||||
# Process all files concurrently
|
||||
tasks = [loader.load_async() for loader in loaders]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle any exceptions in results
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"File {i} failed: {result}")
|
||||
processed_results.append(
|
||||
[
|
||||
Document(
|
||||
page_content=f"Error processing file: {result}",
|
||||
metadata={
|
||||
"error": "batch_processing_failed",
|
||||
"file_index": i,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_docs = sum(len(docs) for docs in processed_results)
|
||||
log.info(
|
||||
f"Batch processing completed in {total_time:.2f}s, produced {total_docs} total documents"
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
8
backend/open_webui/retrieval/models/base_reranker.py
Normal file
8
backend/open_webui/retrieval/models/base_reranker.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
|
||||
class BaseReranker(ABC):
|
||||
@abstractmethod
|
||||
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
|
||||
pass
|
||||
@@ -7,11 +7,13 @@ from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ColBERT:
|
||||
class ColBERT(BaseReranker):
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -3,12 +3,14 @@ import requests
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class ExternalReranker:
|
||||
class ExternalReranker(BaseReranker):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain_community.retrievers import BM25Retriever
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.files import Files
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
from open_webui.config import VECTOR_DB
|
||||
|
||||
if VECTOR_DB == "milvus":
|
||||
from open_webui.retrieval.vector.dbs.milvus import MilvusClient
|
||||
|
||||
VECTOR_DB_CLIENT = MilvusClient()
|
||||
elif VECTOR_DB == "qdrant":
|
||||
from open_webui.retrieval.vector.dbs.qdrant import QdrantClient
|
||||
|
||||
VECTOR_DB_CLIENT = QdrantClient()
|
||||
elif VECTOR_DB == "opensearch":
|
||||
from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = OpenSearchClient()
|
||||
elif VECTOR_DB == "pgvector":
|
||||
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||
|
||||
VECTOR_DB_CLIENT = PgvectorClient()
|
||||
elif VECTOR_DB == "elasticsearch":
|
||||
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
|
||||
|
||||
VECTOR_DB_CLIENT = ElasticsearchClient()
|
||||
elif VECTOR_DB == "pinecone":
|
||||
from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
|
||||
|
||||
VECTOR_DB_CLIENT = PineconeClient()
|
||||
else:
|
||||
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
||||
|
||||
VECTOR_DB_CLIENT = ChromaClient()
|
||||
@@ -1,13 +1,12 @@
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
import logging
|
||||
import time # for measuring elapsed time
|
||||
from pinecone import ServerlessSpec
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
import asyncio # for async upserts
|
||||
import functools # for partial binding in async tasks
|
||||
|
||||
import concurrent.futures # for parallel batch upserts
|
||||
from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
@@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase):
|
||||
self.metric = PINECONE_METRIC
|
||||
self.cloud = PINECONE_CLOUD
|
||||
|
||||
# Initialize Pinecone gRPC client for improved performance
|
||||
self.client = PineconeGRPC(
|
||||
api_key=self.api_key, environment=self.environment, cloud=self.cloud
|
||||
)
|
||||
# Initialize Pinecone client for improved performance
|
||||
self.client = Pinecone(api_key=self.api_key)
|
||||
|
||||
# Persistent executor for batch operations
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
|
||||
@@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase):
|
||||
metadatas = []
|
||||
|
||||
for match in matches:
|
||||
metadata = match.get("metadata", {})
|
||||
ids.append(match["id"])
|
||||
metadata = getattr(match, "metadata", {}) or {}
|
||||
ids.append(match.id if hasattr(match, "id") else match["id"])
|
||||
documents.append(metadata.get("text", ""))
|
||||
metadatas.append(metadata)
|
||||
|
||||
@@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
include_metadata=False,
|
||||
)
|
||||
return len(response.matches) > 0
|
||||
matches = getattr(response, "matches", []) or []
|
||||
return len(matches) > 0
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error checking collection '{collection_name_with_prefix}': {e}"
|
||||
@@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase):
|
||||
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Perform a streaming upsert over gRPC for performance testing."""
|
||||
if not items:
|
||||
log.warning("No items to upsert via streaming")
|
||||
return
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Open a streaming upsert channel
|
||||
stream = self.index.streaming_upsert()
|
||||
try:
|
||||
for point in points:
|
||||
# send each point over the stream
|
||||
stream.send(point)
|
||||
# close the stream to finalize
|
||||
stream.close()
|
||||
log.info(
|
||||
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error during streaming upsert: {e}")
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
) -> Optional[SearchResult]:
|
||||
@@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
)
|
||||
|
||||
if not query_response.matches:
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
if not matches:
|
||||
# Return empty result if no matches
|
||||
return SearchResult(
|
||||
ids=[[]],
|
||||
@@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase):
|
||||
)
|
||||
|
||||
# Convert to GetResult format
|
||||
get_result = self._result_to_get_result(query_response.matches)
|
||||
get_result = self._result_to_get_result(matches)
|
||||
|
||||
# Calculate normalized distances based on metric
|
||||
distances = [
|
||||
[
|
||||
self._normalize_distance(match.score)
|
||||
for match in query_response.matches
|
||||
self._normalize_distance(getattr(match, "score", 0.0))
|
||||
for match in matches
|
||||
]
|
||||
]
|
||||
|
||||
@@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase):
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(query_response.matches)
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error querying collection '{collection_name}': {e}")
|
||||
@@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase):
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
)
|
||||
|
||||
return self._result_to_get_result(query_response.matches)
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting collection '{collection_name}': {e}")
|
||||
@@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase):
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
"""Shut down the gRPC channel and thread pool."""
|
||||
"""Shut down resources."""
|
||||
try:
|
||||
self.client.close()
|
||||
log.info("Pinecone gRPC channel closed.")
|
||||
# The new Pinecone client doesn't need explicit closing
|
||||
pass
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to close Pinecone gRPC channel: {e}")
|
||||
log.warning(f"Failed to clean up Pinecone resources: {e}")
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
712
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
Normal file
712
backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py
Normal file
@@ -0,0 +1,712 @@
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import grpc
|
||||
from open_webui.config import (
|
||||
QDRANT_API_KEY,
|
||||
QDRANT_GRPC_PORT,
|
||||
QDRANT_ON_DISK,
|
||||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_URI,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
class QdrantClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.QDRANT_URI = QDRANT_URI
|
||||
self.QDRANT_API_KEY = QDRANT_API_KEY
|
||||
self.QDRANT_ON_DISK = QDRANT_ON_DISK
|
||||
self.PREFER_GRPC = QDRANT_PREFER_GRPC
|
||||
self.GRPC_PORT = QDRANT_GRPC_PORT
|
||||
|
||||
if not self.QDRANT_URI:
|
||||
self.client = None
|
||||
return
|
||||
|
||||
# Unified handling for either scheme
|
||||
parsed = urlparse(self.QDRANT_URI)
|
||||
host = parsed.hostname or self.QDRANT_URI
|
||||
http_port = parsed.port or 6333 # default REST port
|
||||
|
||||
if self.PREFER_GRPC:
|
||||
self.client = Qclient(
|
||||
host=host,
|
||||
port=http_port,
|
||||
grpc_port=self.GRPC_PORT,
|
||||
prefer_grpc=self.PREFER_GRPC,
|
||||
api_key=self.QDRANT_API_KEY,
|
||||
)
|
||||
else:
|
||||
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
|
||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Maps the traditional collection name to multi-tenant collection and tenant ID.
|
||||
|
||||
Returns:
|
||||
tuple: (collection_name, tenant_id)
|
||||
"""
|
||||
# Check for user memory collections
|
||||
tenant_id = collection_name
|
||||
|
||||
if collection_name.startswith("user-memory-"):
|
||||
return self.MEMORY_COLLECTION, tenant_id
|
||||
|
||||
# Check for file collections
|
||||
elif collection_name.startswith("file-"):
|
||||
return self.FILE_COLLECTION, tenant_id
|
||||
|
||||
# Check for web search collections
|
||||
elif collection_name.startswith("web-search-"):
|
||||
return self.WEB_SEARCH_COLLECTION, tenant_id
|
||||
|
||||
# Handle hash-based collections (YouTube and web URLs)
|
||||
elif len(collection_name) == 63 and all(
|
||||
c in "0123456789abcdef" for c in collection_name
|
||||
):
|
||||
return self.HASH_BASED_COLLECTION, tenant_id
|
||||
|
||||
else:
|
||||
return self.KNOWLEDGE_COLLECTION, tenant_id
|
||||
|
||||
def _extract_error_message(self, exception):
|
||||
"""
|
||||
Extract error message from either HTTP or gRPC exceptions
|
||||
|
||||
Returns:
|
||||
tuple: (status_code, error_message)
|
||||
"""
|
||||
# Check if it's an HTTP exception
|
||||
if isinstance(exception, UnexpectedResponse):
|
||||
try:
|
||||
error_data = exception.structured()
|
||||
error_msg = error_data.get("status", {}).get("error", "")
|
||||
return exception.status_code, error_msg
|
||||
except Exception as inner_e:
|
||||
log.error(f"Failed to parse HTTP error: {inner_e}")
|
||||
return exception.status_code, str(exception)
|
||||
|
||||
# Check if it's a gRPC exception
|
||||
elif isinstance(exception, grpc.RpcError):
|
||||
# Extract status code from gRPC error
|
||||
status_code = None
|
||||
if hasattr(exception, "code") and callable(exception.code):
|
||||
status_code = exception.code().value[0]
|
||||
|
||||
# Extract error message
|
||||
error_msg = str(exception)
|
||||
if "details =" in error_msg:
|
||||
# Parse the details line which contains the actual error message
|
||||
try:
|
||||
details_line = [
|
||||
line.strip()
|
||||
for line in error_msg.split("\n")
|
||||
if "details =" in line
|
||||
][0]
|
||||
error_msg = details_line.split("details =")[1].strip(' "')
|
||||
except (IndexError, AttributeError):
|
||||
# Fall back to full message if parsing fails
|
||||
pass
|
||||
|
||||
return status_code, error_msg
|
||||
|
||||
# For any other type of exception
|
||||
return None, str(exception)
|
||||
|
||||
def _is_collection_not_found_error(self, exception):
|
||||
"""
|
||||
Check if the exception is due to collection not found, supporting both HTTP and gRPC
|
||||
"""
|
||||
status_code, error_msg = self._extract_error_message(exception)
|
||||
|
||||
# HTTP error (404)
|
||||
if (
|
||||
status_code == 404
|
||||
and "Collection" in error_msg
|
||||
and "doesn't exist" in error_msg
|
||||
):
|
||||
return True
|
||||
|
||||
# gRPC error (NOT_FOUND status)
|
||||
if (
|
||||
isinstance(exception, grpc.RpcError)
|
||||
and exception.code() == grpc.StatusCode.NOT_FOUND
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_dimension_mismatch_error(self, exception):
|
||||
"""
|
||||
Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC
|
||||
"""
|
||||
status_code, error_msg = self._extract_error_message(exception)
|
||||
|
||||
# Common patterns in both HTTP and gRPC
|
||||
return (
|
||||
"Vector dimension error" in error_msg
|
||||
or "dimensions mismatch" in error_msg
|
||||
or "invalid vector size" in error_msg
|
||||
)
|
||||
|
||||
def _create_multi_tenant_collection_if_not_exists(
|
||||
self, mt_collection_name: str, dimension: int = 384
|
||||
):
|
||||
"""
|
||||
Creates a collection with multi-tenancy configuration if it doesn't exist.
|
||||
Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'.
|
||||
When creating collections dynamically (insert/upsert), the actual vector dimensions will be used.
|
||||
"""
|
||||
try:
|
||||
# Try to create the collection directly - will fail if it already exists
|
||||
self.client.create_collection(
|
||||
collection_name=mt_collection_name,
|
||||
vectors_config=models.VectorParams(
|
||||
size=dimension,
|
||||
distance=models.Distance.COSINE,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
hnsw_config=models.HnswConfigDiff(
|
||||
payload_m=16, # Enable per-tenant indexing
|
||||
m=0,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
)
|
||||
|
||||
# Create tenant ID payload index
|
||||
self.client.create_payload_index(
|
||||
collection_name=mt_collection_name,
|
||||
field_name="tenant_id",
|
||||
field_schema=models.KeywordIndexParams(
|
||||
type=models.KeywordIndexType.KEYWORD,
|
||||
is_tenant=True,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
wait=True,
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
|
||||
)
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
# Check for the specific error indicating collection already exists
|
||||
status_code, error_msg = self._extract_error_message(e)
|
||||
|
||||
# HTTP status code 409 or gRPC ALREADY_EXISTS
|
||||
if (isinstance(e, UnexpectedResponse) and status_code == 409) or (
|
||||
isinstance(e, grpc.RpcError)
|
||||
and e.code() == grpc.StatusCode.ALREADY_EXISTS
|
||||
):
|
||||
if "already exists" in error_msg:
|
||||
log.debug(f"Collection {mt_collection_name} already exists")
|
||||
return
|
||||
# If it's not an already exists error, re-raise
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _create_points(self, items: list[VectorItem], tenant_id: str):
|
||||
"""
|
||||
Create point structs from vector items with tenant ID.
|
||||
"""
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
Check if a logical collection exists by checking for any points with the tenant ID.
|
||||
"""
|
||||
if not self.client:
|
||||
return False
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Create tenant filter
|
||||
tenant_filter = models.FieldCondition(
|
||||
key="tenant_id", match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
|
||||
try:
|
||||
# Try directly querying - most of the time collection should exist
|
||||
response = self.client.query_points(
|
||||
collection_name=mt_collection,
|
||||
query_filter=models.Filter(must=[tenant_filter]),
|
||||
limit=1,
|
||||
)
|
||||
|
||||
# Collection exists with this tenant ID if there are points
|
||||
return len(response.points) > 0
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
if self._is_collection_not_found_error(e):
|
||||
log.debug(f"Collection {mt_collection} doesn't exist")
|
||||
return False
|
||||
else:
|
||||
# For other API errors, log and return False
|
||||
_, error_msg = self._extract_error_message(e)
|
||||
log.warning(f"Unexpected Qdrant error: {error_msg}")
|
||||
return False
|
||||
except Exception as e:
|
||||
# For any other errors, log and return False
|
||||
log.debug(f"Error checking collection {mt_collection}: {e}")
|
||||
return False
|
||||
|
||||
def delete(
|
||||
self,
|
||||
collection_name: str,
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Delete vectors by ID or filter from a collection with tenant isolation.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Create tenant filter
|
||||
tenant_filter = models.FieldCondition(
|
||||
key="tenant_id", match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
|
||||
must_conditions = [tenant_filter]
|
||||
should_conditions = []
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
should_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
)
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
must_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to delete directly - most of the time collection should exist
|
||||
update_result = self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=must_conditions, should=should_conditions)
|
||||
),
|
||||
)
|
||||
|
||||
return update_result
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
if self._is_collection_not_found_error(e):
|
||||
log.debug(
|
||||
f"Collection {mt_collection} doesn't exist, nothing to delete"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
# For other API errors, log and re-raise
|
||||
_, error_msg = self._extract_error_message(e)
|
||||
log.warning(f"Unexpected Qdrant error: {error_msg}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# For non-Qdrant exceptions, re-raise
|
||||
raise
|
||||
|
||||
    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for the nearest neighbor items based on the vectors with tenant isolation.

        Only the first query vector (``vectors[0]``) is sent to Qdrant; the
        query vector is truncated or zero-padded to the collection's configured
        size if the dimensions differ. Returns None when there is no client,
        the collection does not exist, or a non-Qdrant error occurs.
        """
        if not self.client:
            return None

        # Map to multi-tenant collection and tenant ID
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        # Get the vector dimension from the query vector
        # (None if vectors is empty; the except Exception below then catches
        # the resulting comparison error and returns None)
        dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None

        try:
            # Try the search operation directly - most of the time collection should exist

            # Create tenant filter
            tenant_filter = models.FieldCondition(
                key="tenant_id", match=models.MatchValue(value=tenant_id)
            )

            # Ensure vector dimensions match the collection
            collection_dim = self.client.get_collection(
                mt_collection
            ).config.params.vectors.size

            if collection_dim != dimension:
                if collection_dim < dimension:
                    # Collection is narrower: truncate the query vectors.
                    vectors = [vector[:collection_dim] for vector in vectors]
                else:
                    # Collection is wider: zero-pad the query vectors.
                    vectors = [
                        vector + [0] * (collection_dim - dimension)
                        for vector in vectors
                    ]

            # Search with tenant filter
            # Prefetch narrows the candidate set to this tenant's points
            # before scoring against the query vector.
            prefetch_query = models.Prefetch(
                filter=models.Filter(must=[tenant_filter]),
                limit=NO_LIMIT,
            )
            query_response = self.client.query_points(
                collection_name=mt_collection,
                query=vectors[0],
                prefetch=prefetch_query,
                limit=limit,
            )

            get_result = self._result_to_get_result(query_response.points)
            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                # qdrant distance is [-1, 1], normalize to [0, 1]
                distances=[
                    [(point.score + 1.0) / 2.0 for point in query_response.points]
                ],
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, search returns None"
                )
                return None
            else:
                # For other API errors, log and re-raise
                _, error_msg = self._extract_error_message(e)
                log.warning(f"Unexpected Qdrant error during search: {error_msg}")
                raise
        except Exception as e:
            # For non-Qdrant exceptions, log and return None
            log.exception(f"Error searching collection '{collection_name}': {e}")
            return None
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
"""
|
||||
Query points with filters and tenant isolation.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Set default limit if not provided
|
||||
if limit is None:
|
||||
limit = NO_LIMIT
|
||||
|
||||
# Create tenant filter
|
||||
tenant_filter = models.FieldCondition(
|
||||
key="tenant_id", match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
|
||||
# Create metadata filters
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
)
|
||||
|
||||
# Combine tenant filter with metadata filters
|
||||
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
|
||||
|
||||
try:
|
||||
# Try the query directly - most of the time collection should exist
|
||||
points = self.client.query_points(
|
||||
collection_name=mt_collection,
|
||||
query_filter=combined_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(points.points)
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
if self._is_collection_not_found_error(e):
|
||||
log.debug(
|
||||
f"Collection {mt_collection} doesn't exist, query returns None"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
# For other API errors, log and re-raise
|
||||
_, error_msg = self._extract_error_message(e)
|
||||
log.warning(f"Unexpected Qdrant error during query: {error_msg}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# For non-Qdrant exceptions, log and re-raise
|
||||
log.exception(f"Error querying collection '{collection_name}': {e}")
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
"""
|
||||
Get all items in a collection with tenant isolation.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Create tenant filter
|
||||
tenant_filter = models.FieldCondition(
|
||||
key="tenant_id", match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to get points directly - most of the time collection should exist
|
||||
points = self.client.query_points(
|
||||
collection_name=mt_collection,
|
||||
query_filter=models.Filter(must=[tenant_filter]),
|
||||
limit=NO_LIMIT,
|
||||
)
|
||||
|
||||
return self._result_to_get_result(points.points)
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
if self._is_collection_not_found_error(e):
|
||||
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
|
||||
return None
|
||||
else:
|
||||
# For other API errors, log and re-raise
|
||||
_, error_msg = self._extract_error_message(e)
|
||||
log.warning(f"Unexpected Qdrant error during get: {error_msg}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# For non-Qdrant exceptions, log and return None
|
||||
log.exception(f"Error getting collection '{collection_name}': {e}")
|
||||
return None
|
||||
|
||||
def _handle_operation_with_error_retry(
|
||||
self, operation_name, mt_collection, points, dimension
|
||||
):
|
||||
"""
|
||||
Private helper to handle common error cases for insert and upsert operations.
|
||||
|
||||
Args:
|
||||
operation_name: 'insert' or 'upsert'
|
||||
mt_collection: The multi-tenant collection name
|
||||
points: The vector points to insert/upsert
|
||||
dimension: The dimension of the vectors
|
||||
|
||||
Returns:
|
||||
The operation result (for upsert) or None (for insert)
|
||||
"""
|
||||
try:
|
||||
if operation_name == "insert":
|
||||
self.client.upload_points(mt_collection, points)
|
||||
return None
|
||||
else: # upsert
|
||||
return self.client.upsert(mt_collection, points)
|
||||
except (UnexpectedResponse, grpc.RpcError) as e:
|
||||
# Handle collection not found
|
||||
if self._is_collection_not_found_error(e):
|
||||
log.info(
|
||||
f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
|
||||
)
|
||||
# Create collection with correct dimensions from our vectors
|
||||
self._create_multi_tenant_collection_if_not_exists(
|
||||
mt_collection_name=mt_collection, dimension=dimension
|
||||
)
|
||||
# Try operation again - no need for dimension adjustment since we just created with correct dimensions
|
||||
if operation_name == "insert":
|
||||
self.client.upload_points(mt_collection, points)
|
||||
return None
|
||||
else: # upsert
|
||||
return self.client.upsert(mt_collection, points)
|
||||
|
||||
# Handle dimension mismatch
|
||||
elif self._is_dimension_mismatch_error(e):
|
||||
# For dimension errors, the collection must exist, so get its configuration
|
||||
mt_collection_info = self.client.get_collection(mt_collection)
|
||||
existing_size = mt_collection_info.config.params.vectors.size
|
||||
|
||||
log.info(
|
||||
f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
|
||||
)
|
||||
|
||||
if existing_size < dimension:
|
||||
# Truncate vectors to fit
|
||||
log.info(
|
||||
f"Truncating vectors from {dimension} to {existing_size} dimensions"
|
||||
)
|
||||
points = [
|
||||
PointStruct(
|
||||
id=point.id,
|
||||
vector=point.vector[:existing_size],
|
||||
payload=point.payload,
|
||||
)
|
||||
for point in points
|
||||
]
|
||||
elif existing_size > dimension:
|
||||
# Pad vectors with zeros
|
||||
log.info(
|
||||
f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
|
||||
)
|
||||
points = [
|
||||
PointStruct(
|
||||
id=point.id,
|
||||
vector=point.vector
|
||||
+ [0] * (existing_size - len(point.vector)),
|
||||
payload=point.payload,
|
||||
)
|
||||
for point in points
|
||||
]
|
||||
# Try operation again with adjusted dimensions
|
||||
if operation_name == "insert":
|
||||
self.client.upload_points(mt_collection, points)
|
||||
return None
|
||||
else: # upsert
|
||||
return self.client.upsert(mt_collection, points)
|
||||
else:
|
||||
# Not a known error we can handle, log and re-raise
|
||||
_, error_msg = self._extract_error_message(e)
|
||||
log.warning(f"Unhandled Qdrant error: {error_msg}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# For non-Qdrant exceptions, re-raise
|
||||
raise
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
"""
|
||||
Insert items with tenant ID.
|
||||
"""
|
||||
if not self.client or not items:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Get dimensions from the actual vectors
|
||||
dimension = len(items[0]["vector"]) if items else None
|
||||
|
||||
# Create points with tenant ID
|
||||
points = self._create_points(items, tenant_id)
|
||||
|
||||
# Handle the operation with error retry
|
||||
return self._handle_operation_with_error_retry(
|
||||
"insert", mt_collection, points, dimension
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
"""
|
||||
Upsert items with tenant ID.
|
||||
"""
|
||||
if not self.client or not items:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
# Get dimensions from the actual vectors
|
||||
dimension = len(items[0]["vector"]) if items else None
|
||||
|
||||
# Create points with tenant ID
|
||||
points = self._create_points(items, tenant_id)
|
||||
|
||||
# Handle the operation with error retry
|
||||
return self._handle_operation_with_error_retry(
|
||||
"upsert", mt_collection, points, dimension
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the database by deleting all collections.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
collection_names = self.client.get_collections().collections
|
||||
for collection_name in collection_names:
|
||||
if collection_name.name.startswith(self.collection_prefix):
|
||||
self.client.delete_collection(collection_name=collection_name.name)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
"""
|
||||
Delete a collection.
|
||||
"""
|
||||
if not self.client:
|
||||
return None
|
||||
|
||||
# Map to multi-tenant collection and tenant ID
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
|
||||
tenant_filter = models.FieldCondition(
|
||||
key="tenant_id", match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
|
||||
field_conditions = [tenant_filter]
|
||||
|
||||
update_result = self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=field_conditions)
|
||||
),
|
||||
)
|
||||
|
||||
if self.client.get_collection(mt_collection).points_count == 0:
|
||||
self.client.delete_collection(mt_collection)
|
||||
|
||||
return update_result
|
||||
55
backend/open_webui/retrieval/vector/factory.py
Normal file
55
backend/open_webui/retrieval/vector/factory.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from open_webui.retrieval.vector.main import VectorDBBase
|
||||
from open_webui.retrieval.vector.type import VectorType
|
||||
from open_webui.config import VECTOR_DB, ENABLE_QDRANT_MULTITENANCY_MODE
|
||||
|
||||
|
||||
class Vector:
    """Factory that lazily constructs the configured vector-database client."""

    @staticmethod
    def get_vector(vector_type: str) -> VectorDBBase:
        """
        get vector db instance by vector type
        """
        # Each backend is imported inside its own branch so only the
        # selected backend's dependencies are ever loaded.
        if vector_type == VectorType.MILVUS:
            from open_webui.retrieval.vector.dbs.milvus import MilvusClient

            return MilvusClient()

        if vector_type == VectorType.QDRANT:
            if ENABLE_QDRANT_MULTITENANCY_MODE:
                from open_webui.retrieval.vector.dbs.qdrant_multitenancy import (
                    QdrantClient,
                )

                return QdrantClient()

            from open_webui.retrieval.vector.dbs.qdrant import QdrantClient

            return QdrantClient()

        if vector_type == VectorType.PINECONE:
            from open_webui.retrieval.vector.dbs.pinecone import PineconeClient

            return PineconeClient()

        if vector_type == VectorType.OPENSEARCH:
            from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

            return OpenSearchClient()

        if vector_type == VectorType.PGVECTOR:
            from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

            return PgvectorClient()

        if vector_type == VectorType.ELASTICSEARCH:
            from open_webui.retrieval.vector.dbs.elasticsearch import (
                ElasticsearchClient,
            )

            return ElasticsearchClient()

        if vector_type == VectorType.CHROMA:
            from open_webui.retrieval.vector.dbs.chroma import ChromaClient

            return ChromaClient()

        raise ValueError(f"Unsupported vector type: {vector_type}")
|
||||
|
||||
|
||||
# Module-level singleton: the vector-DB client is built once at import time
# from the configured VECTOR_DB backend.
VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)
|
||||
11
backend/open_webui/retrieval/vector/type.py
Normal file
11
backend/open_webui/retrieval/vector/type.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class VectorType(StrEnum):
    """
    Supported vector-database backends.

    Values are the lowercase identifiers used by the VECTOR_DB configuration
    setting and compared against in the Vector factory.
    """

    MILVUS = "milvus"
    QDRANT = "qdrant"
    CHROMA = "chroma"
    PINECONE = "pinecone"
    ELASTICSEARCH = "elasticsearch"
    OPENSEARCH = "opensearch"
    PGVECTOR = "pgvector"
|
||||
@@ -42,7 +42,9 @@ def search_searchapi(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -42,7 +42,9 @@ def search_serpapi(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -25,7 +25,7 @@ from langchain_community.document_loaders.firecrawl import FireCrawlLoader
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.retrieval.loaders.tavily import TavilyLoader
|
||||
from open_webui.retrieval.loaders.external import ExternalLoader
|
||||
from open_webui.retrieval.loaders.external_web import ExternalWebLoader
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.config import (
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
@@ -39,7 +39,7 @@ from open_webui.config import (
|
||||
EXTERNAL_WEB_LOADER_URL,
|
||||
EXTERNAL_WEB_LOADER_API_KEY,
|
||||
)
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
@@ -515,7 +515,8 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
kwargs["ssl"] = False
|
||||
|
||||
async with session.get(
|
||||
url, **(self.requests_kwargs | kwargs)
|
||||
url,
|
||||
**(self.requests_kwargs | kwargs),
|
||||
) as response:
|
||||
if self.raise_for_status:
|
||||
response.raise_for_status()
|
||||
@@ -628,7 +629,7 @@ def get_web_loader(
|
||||
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "external":
|
||||
WebLoaderClass = ExternalLoader
|
||||
WebLoaderClass = ExternalWebLoader
|
||||
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
|
||||
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
|
||||
|
||||
|
||||
@@ -7,6 +7,9 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
@@ -17,6 +20,7 @@ from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -38,6 +42,7 @@ from open_webui.config import (
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
@@ -49,7 +54,7 @@ from open_webui.env import (
|
||||
router = APIRouter()
|
||||
|
||||
# Constants
|
||||
MAX_FILE_SIZE_MB = 25
|
||||
MAX_FILE_SIZE_MB = 20
|
||||
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
AZURE_MAX_FILE_SIZE_MB = 200
|
||||
AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
@@ -71,35 +76,50 @@ from pydub import AudioSegment
|
||||
from pydub.utils import mediainfo
|
||||
|
||||
|
||||
def get_audio_convert_format(file_path):
|
||||
"""Check if the given file needs to be converted to a different format."""
|
||||
def is_audio_conversion_required(file_path):
|
||||
"""
|
||||
Check if the given audio file needs conversion to mp3.
|
||||
"""
|
||||
SUPPORTED_FORMATS = {"flac", "m4a", "mp3", "mp4", "mpeg", "wav", "webm"}
|
||||
|
||||
if not os.path.isfile(file_path):
|
||||
log.error(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
info = mediainfo(file_path)
|
||||
codec_name = info.get("codec_name", "").lower()
|
||||
codec_type = info.get("codec_type", "").lower()
|
||||
codec_tag_string = info.get("codec_tag_string", "").lower()
|
||||
|
||||
if codec_name == "aac" and codec_type == "audio" and codec_tag_string == "mp4a":
|
||||
# File is AAC/mp4a audio, recommend mp3 conversion
|
||||
return True
|
||||
|
||||
# If the codec name or file extension is in the supported formats
|
||||
if (
|
||||
info.get("codec_name") == "aac"
|
||||
and info.get("codec_type") == "audio"
|
||||
and info.get("codec_tag_string") == "mp4a"
|
||||
codec_name in SUPPORTED_FORMATS
|
||||
or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS
|
||||
):
|
||||
return "mp4"
|
||||
elif info.get("format_name") == "ogg":
|
||||
return "ogg"
|
||||
return False # Already supported
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error getting audio format: {e}")
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def convert_audio_to_wav(file_path, output_path, conversion_type):
|
||||
"""Convert MP4/OGG audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format=conversion_type)
|
||||
audio.export(output_path, format="wav")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
def convert_audio_to_mp3(file_path):
|
||||
"""Convert audio file to mp3 format."""
|
||||
try:
|
||||
output_path = os.path.splitext(file_path)[0] + ".mp3"
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
audio.export(output_path, format="mp3")
|
||||
log.info(f"Converted {file_path} to {output_path}")
|
||||
return output_path
|
||||
except Exception as e:
|
||||
log.error(f"Error converting audio file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
@@ -326,6 +346,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -381,6 +402,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": request.app.state.config.TTS_API_KEY,
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -439,6 +461,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
"X-Microsoft-OutputFormat": output_format,
|
||||
},
|
||||
data=data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -507,12 +530,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
return FileResponse(file_path)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path):
|
||||
log.info(f"transcribe: {file_path}")
|
||||
def transcription_handler(request, file_path, metadata):
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
||||
metadata = metadata or {}
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
if request.app.state.faster_whisper_model is None:
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
@@ -524,7 +548,7 @@ def transcribe(request: Request, file_path):
|
||||
file_path,
|
||||
beam_size=5,
|
||||
vad_filter=request.app.state.config.WHISPER_VAD_FILTER,
|
||||
language=WHISPER_LANGUAGE,
|
||||
language=metadata.get("language") or WHISPER_LANGUAGE,
|
||||
)
|
||||
log.info(
|
||||
"Detected language '%s' with probability %f"
|
||||
@@ -542,19 +566,6 @@ def transcribe(request: Request, file_path):
|
||||
log.debug(data)
|
||||
return data
|
||||
elif request.app.state.config.STT_ENGINE == "openai":
|
||||
convert_format = get_audio_convert_format(file_path)
|
||||
|
||||
if convert_format:
|
||||
ext = convert_format.split(".")[-1]
|
||||
|
||||
os.rename(file_path, file_path.replace(".{ext}", f".{convert_format}"))
|
||||
# Convert unsupported audio file to WAV format
|
||||
convert_audio_to_wav(
|
||||
file_path.replace(".{ext}", f".{convert_format}"),
|
||||
file_path,
|
||||
convert_format,
|
||||
)
|
||||
|
||||
r = None
|
||||
try:
|
||||
r = requests.post(
|
||||
@@ -563,7 +574,14 @@ def transcribe(request: Request, file_path):
|
||||
"Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
|
||||
},
|
||||
files={"file": (filename, open(file_path, "rb"))},
|
||||
data={"model": request.app.state.config.STT_MODEL},
|
||||
data={
|
||||
"model": request.app.state.config.STT_MODEL,
|
||||
**(
|
||||
{"language": metadata.get("language")}
|
||||
if metadata.get("language")
|
||||
else {}
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
@@ -771,41 +789,135 @@ def transcribe(request: Request, file_path):
|
||||
)
|
||||
|
||||
|
||||
def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None):
|
||||
log.info(f"transcribe: {file_path} {metadata}")
|
||||
|
||||
if is_audio_conversion_required(file_path):
|
||||
file_path = convert_audio_to_mp3(file_path)
|
||||
|
||||
try:
|
||||
file_path = compress_audio(file_path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
# Always produce a list of chunk paths (could be one entry if small)
|
||||
try:
|
||||
chunk_paths = split_audio(file_path, MAX_FILE_SIZE)
|
||||
print(f"Chunk paths: {chunk_paths}")
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
results = []
|
||||
try:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit tasks for each chunk_path
|
||||
futures = [
|
||||
executor.submit(transcription_handler, request, chunk_path, metadata)
|
||||
for chunk_path in chunk_paths
|
||||
]
|
||||
# Gather results as they complete
|
||||
for future in futures:
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as transcribe_exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error transcribing chunk: {transcribe_exc}",
|
||||
)
|
||||
finally:
|
||||
# Clean up only the temporary chunks, never the original file
|
||||
for chunk_path in chunk_paths:
|
||||
if chunk_path != file_path and os.path.isfile(chunk_path):
|
||||
try:
|
||||
os.remove(chunk_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"text": " ".join([result["text"] for result in results]),
|
||||
}
|
||||
|
||||
|
||||
def compress_audio(file_path):
|
||||
if os.path.getsize(file_path) > MAX_FILE_SIZE:
|
||||
id = os.path.splitext(os.path.basename(file_path))[
|
||||
0
|
||||
] # Handles names with multiple dots
|
||||
file_dir = os.path.dirname(file_path)
|
||||
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
audio = audio.set_frame_rate(16000).set_channels(1) # Compress audio
|
||||
compressed_path = f"{file_dir}/{id}_compressed.opus"
|
||||
audio.export(compressed_path, format="opus", bitrate="32k")
|
||||
log.debug(f"Compressed audio to {compressed_path}")
|
||||
|
||||
if (
|
||||
os.path.getsize(compressed_path) > MAX_FILE_SIZE
|
||||
): # Still larger than MAX_FILE_SIZE after compression
|
||||
raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
|
||||
compressed_path = os.path.join(file_dir, f"{id}_compressed.mp3")
|
||||
audio.export(compressed_path, format="mp3", bitrate="32k")
|
||||
# log.debug(f"Compressed audio to {compressed_path}") # Uncomment if log is defined
|
||||
|
||||
return compressed_path
|
||||
else:
|
||||
return file_path
|
||||
|
||||
|
||||
def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"):
|
||||
"""
|
||||
Splits audio into chunks not exceeding max_bytes.
|
||||
Returns a list of chunk file paths. If audio fits, returns list with original path.
|
||||
"""
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size <= max_bytes:
|
||||
return [file_path] # Nothing to split
|
||||
|
||||
audio = AudioSegment.from_file(file_path)
|
||||
duration_ms = len(audio)
|
||||
orig_size = file_size
|
||||
|
||||
approx_chunk_ms = max(int(duration_ms * (max_bytes / orig_size)) - 1000, 1000)
|
||||
chunks = []
|
||||
start = 0
|
||||
i = 0
|
||||
|
||||
base, _ = os.path.splitext(file_path)
|
||||
|
||||
while start < duration_ms:
|
||||
end = min(start + approx_chunk_ms, duration_ms)
|
||||
chunk = audio[start:end]
|
||||
chunk_path = f"{base}_chunk_{i}.{format}"
|
||||
chunk.export(chunk_path, format=format, bitrate=bitrate)
|
||||
|
||||
# Reduce chunk duration if still too large
|
||||
while os.path.getsize(chunk_path) > max_bytes and (end - start) > 5000:
|
||||
end = start + ((end - start) // 2)
|
||||
chunk = audio[start:end]
|
||||
chunk.export(chunk_path, format=format, bitrate=bitrate)
|
||||
|
||||
if os.path.getsize(chunk_path) > max_bytes:
|
||||
os.remove(chunk_path)
|
||||
raise Exception("Audio chunk cannot be reduced below max file size.")
|
||||
|
||||
chunks.append(chunk_path)
|
||||
start = end
|
||||
i += 1
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@router.post("/transcriptions")
|
||||
def transcription(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
language: Optional[str] = Form(None),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
supported_filetypes = (
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/x-m4a",
|
||||
"audio/webm",
|
||||
)
|
||||
|
||||
if not file.content_type.startswith(supported_filetypes):
|
||||
SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types!
|
||||
if not (
|
||||
file.content_type.startswith("audio/")
|
||||
or file.content_type in SUPPORTED_CONTENT_TYPES
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
||||
@@ -826,19 +938,18 @@ def transcription(
|
||||
f.write(contents)
|
||||
|
||||
try:
|
||||
try:
|
||||
file_path = compress_audio(file_path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
metadata = None
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
if language:
|
||||
metadata = {"language": language}
|
||||
|
||||
result = transcribe(request, file_path, metadata)
|
||||
|
||||
return {
|
||||
**result,
|
||||
"filename": os.path.basename(file_path),
|
||||
}
|
||||
|
||||
data = transcribe(request, file_path)
|
||||
file_path = file_path.split("/")[-1]
|
||||
return {**data, "filename": file_path}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
)
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse, Response
|
||||
from fastapi.responses import RedirectResponse, Response, JSONResponse
|
||||
from open_webui.config import OPENID_PROVIDER_URL, ENABLE_OAUTH_SIGNUP, ENABLE_LDAP
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -51,7 +51,7 @@ from open_webui.utils.access_control import get_permissions
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||
from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS
|
||||
|
||||
if ENABLE_LDAP.value:
|
||||
from ldap3 import Server, Connection, NONE, Tls
|
||||
@@ -186,6 +186,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
|
||||
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
|
||||
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
|
||||
LDAP_VALIDATE_CERT = (
|
||||
CERT_REQUIRED if request.app.state.config.LDAP_VALIDATE_CERT else CERT_NONE
|
||||
)
|
||||
LDAP_CIPHERS = (
|
||||
request.app.state.config.LDAP_CIPHERS
|
||||
if request.app.state.config.LDAP_CIPHERS
|
||||
@@ -197,7 +200,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
|
||||
try:
|
||||
tls = Tls(
|
||||
validate=CERT_REQUIRED,
|
||||
validate=LDAP_VALIDATE_CERT,
|
||||
version=PROTOCOL_TLS,
|
||||
ca_certs_file=LDAP_CA_CERT_FILE,
|
||||
ciphers=LDAP_CIPHERS,
|
||||
@@ -478,10 +481,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
if user_count == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||
if len(form_data.password.encode("utf-8")) > 72:
|
||||
raise HTTPException(
|
||||
@@ -541,6 +540,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
if user_count == 0:
|
||||
# Disable signup after the first user is created
|
||||
request.app.state.config.ENABLE_SIGNUP = False
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
@@ -574,9 +577,14 @@ async def signout(request: Request, response: Response):
|
||||
logout_url = openid_data.get("end_session_endpoint")
|
||||
if logout_url:
|
||||
response.delete_cookie("oauth_id_token")
|
||||
return RedirectResponse(
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": True,
|
||||
"redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
},
|
||||
headers=response.headers,
|
||||
url=f"{logout_url}?id_token_hint={oauth_id_token}",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -591,12 +599,18 @@ async def signout(request: Request, response: Response):
|
||||
)
|
||||
|
||||
if WEBUI_AUTH_SIGNOUT_REDIRECT_URL:
|
||||
return RedirectResponse(
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"status": True,
|
||||
"redirect_url": WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
},
|
||||
headers=response.headers,
|
||||
url=WEBUI_AUTH_SIGNOUT_REDIRECT_URL,
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
return JSONResponse(
|
||||
status_code=200, content={"status": True}, headers=response.headers
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
@@ -696,6 +710,9 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
|
||||
}
|
||||
|
||||
|
||||
@@ -713,6 +730,9 @@ class AdminConfig(BaseModel):
|
||||
ENABLE_CHANNELS: bool
|
||||
ENABLE_NOTES: bool
|
||||
ENABLE_USER_WEBHOOKS: bool
|
||||
PENDING_USER_OVERLAY_TITLE: Optional[str] = None
|
||||
PENDING_USER_OVERLAY_CONTENT: Optional[str] = None
|
||||
RESPONSE_WATERMARK: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/admin/config")
|
||||
@@ -750,6 +770,15 @@ async def update_admin_config(
|
||||
|
||||
request.app.state.config.ENABLE_USER_WEBHOOKS = form_data.ENABLE_USER_WEBHOOKS
|
||||
|
||||
request.app.state.config.PENDING_USER_OVERLAY_TITLE = (
|
||||
form_data.PENDING_USER_OVERLAY_TITLE
|
||||
)
|
||||
request.app.state.config.PENDING_USER_OVERLAY_CONTENT = (
|
||||
form_data.PENDING_USER_OVERLAY_CONTENT
|
||||
)
|
||||
|
||||
request.app.state.config.RESPONSE_WATERMARK = form_data.RESPONSE_WATERMARK
|
||||
|
||||
return {
|
||||
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
|
||||
"WEBUI_URL": request.app.state.config.WEBUI_URL,
|
||||
@@ -764,6 +793,9 @@ async def update_admin_config(
|
||||
"ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
|
||||
"ENABLE_NOTES": request.app.state.config.ENABLE_NOTES,
|
||||
"ENABLE_USER_WEBHOOKS": request.app.state.config.ENABLE_USER_WEBHOOKS,
|
||||
"PENDING_USER_OVERLAY_TITLE": request.app.state.config.PENDING_USER_OVERLAY_TITLE,
|
||||
"PENDING_USER_OVERLAY_CONTENT": request.app.state.config.PENDING_USER_OVERLAY_CONTENT,
|
||||
"RESPONSE_WATERMARK": request.app.state.config.RESPONSE_WATERMARK,
|
||||
}
|
||||
|
||||
|
||||
@@ -779,6 +811,7 @@ class LdapServerConfig(BaseModel):
|
||||
search_filters: str = ""
|
||||
use_tls: bool = True
|
||||
certificate_path: Optional[str] = None
|
||||
validate_cert: bool = True
|
||||
ciphers: Optional[str] = "ALL"
|
||||
|
||||
|
||||
@@ -796,6 +829,7 @@ async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
|
||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
|
||||
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||
}
|
||||
|
||||
@@ -831,6 +865,7 @@ async def update_ldap_server(
|
||||
request.app.state.config.LDAP_SEARCH_FILTERS = form_data.search_filters
|
||||
request.app.state.config.LDAP_USE_TLS = form_data.use_tls
|
||||
request.app.state.config.LDAP_CA_CERT_FILE = form_data.certificate_path
|
||||
request.app.state.config.LDAP_VALIDATE_CERT = form_data.validate_cert
|
||||
request.app.state.config.LDAP_CIPHERS = form_data.ciphers
|
||||
|
||||
return {
|
||||
@@ -845,6 +880,7 @@ async def update_ldap_server(
|
||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||
"validate_cert": request.app.state.config.LDAP_VALIDATE_CERT,
|
||||
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||
}
|
||||
|
||||
|
||||
@@ -74,13 +74,17 @@ class FeedbackUserResponse(FeedbackResponse):
|
||||
@router.get("/feedbacks/all", response_model=list[FeedbackUserResponse])
|
||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
return [
|
||||
FeedbackUserResponse(
|
||||
**feedback.model_dump(),
|
||||
user=UserResponse(**Users.get_user_by_id(feedback.user_id).model_dump()),
|
||||
|
||||
feedback_list = []
|
||||
for feedback in feedbacks:
|
||||
user = Users.get_user_by_id(feedback.user_id)
|
||||
feedback_list.append(
|
||||
FeedbackUserResponse(
|
||||
**feedback.model_dump(),
|
||||
user=UserResponse(**user.model_dump()) if user else None,
|
||||
)
|
||||
)
|
||||
for feedback in feedbacks
|
||||
]
|
||||
return feedback_list
|
||||
|
||||
|
||||
@router.delete("/feedbacks/all")
|
||||
@@ -92,12 +96,7 @@ async def delete_all_feedbacks(user=Depends(get_admin_user)):
|
||||
@router.get("/feedbacks/all/export", response_model=list[FeedbackModel])
|
||||
async def get_all_feedbacks(user=Depends(get_admin_user)):
|
||||
feedbacks = Feedbacks.get_all_feedbacks()
|
||||
return [
|
||||
FeedbackModel(
|
||||
**feedback.model_dump(), user=Users.get_user_by_id(feedback.user_id)
|
||||
)
|
||||
for feedback in feedbacks
|
||||
]
|
||||
return feedbacks
|
||||
|
||||
|
||||
@router.get("/feedbacks/user", response_model=list[FeedbackUserResponse])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
@@ -10,6 +11,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Form,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
@@ -84,17 +86,44 @@ def has_access_to_file(
|
||||
def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
file_metadata: dict = None,
|
||||
metadata: Optional[dict | str] = Form(None),
|
||||
process: bool = Query(True),
|
||||
internal: bool = False,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
file_metadata = file_metadata if file_metadata else {}
|
||||
if isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"),
|
||||
)
|
||||
file_metadata = metadata if metadata else {}
|
||||
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
|
||||
file_extension = os.path.splitext(filename)[1]
|
||||
# Remove the leading dot from the file extension
|
||||
file_extension = file_extension[1:] if file_extension else ""
|
||||
|
||||
if (not internal) and request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = [
|
||||
ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext
|
||||
]
|
||||
|
||||
if file_extension not in request.app.state.config.ALLOWED_FILE_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(
|
||||
f"File type {file_extension} is not allowed"
|
||||
),
|
||||
)
|
||||
|
||||
# replace filename with uuid
|
||||
id = str(uuid.uuid4())
|
||||
name = filename
|
||||
@@ -125,33 +154,26 @@ def upload_file(
|
||||
)
|
||||
if process:
|
||||
try:
|
||||
if file.content_type:
|
||||
if file.content_type.startswith("audio/") or file.content_type in {
|
||||
"video/webm"
|
||||
}:
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path, file_metadata)
|
||||
|
||||
if file.content_type.startswith(
|
||||
(
|
||||
"audio/mpeg",
|
||||
"audio/wav",
|
||||
"audio/ogg",
|
||||
"audio/x-m4a",
|
||||
"audio/webm",
|
||||
"video/webm",
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
elif (not file.content_type.startswith(("image/", "video/"))) or (
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external"
|
||||
):
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
else:
|
||||
log.info(
|
||||
f"File type {file.content_type} is not provided, but trying to process anyway"
|
||||
)
|
||||
):
|
||||
file_path = Storage.get_file(file_path)
|
||||
result = transcribe(request, file_path)
|
||||
|
||||
process_file(
|
||||
request,
|
||||
ProcessFileForm(file_id=id, content=result.get("text", "")),
|
||||
user=user,
|
||||
)
|
||||
elif file.content_type not in [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"video/mp4",
|
||||
"video/ogg",
|
||||
"video/quicktime",
|
||||
]:
|
||||
process_file(request, ProcessFileForm(file_id=id), user=user)
|
||||
|
||||
file_item = Files.get_file_by_id(id=id)
|
||||
|
||||
@@ -262,11 +262,8 @@ async def get_function_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -290,11 +287,8 @@ async def update_function_valves_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
Valves = function_module.Valves
|
||||
@@ -353,11 +347,8 @@ async def get_function_user_valves_spec_by_id(
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
@@ -377,11 +368,8 @@ async def update_function_user_valves_by_id(
|
||||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
if id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[id]
|
||||
else:
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(id)
|
||||
request.app.state.FUNCTIONS[id] = function_module
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
UserValves = function_module.UserValves
|
||||
|
||||
@@ -333,10 +333,11 @@ def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
return [
|
||||
{"id": "dall-e-2", "name": "DALL·E 2"},
|
||||
{"id": "dall-e-3", "name": "DALL·E 3"},
|
||||
{"id": "gpt-image-1", "name": "GPT-IMAGE 1"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini":
|
||||
return [
|
||||
{"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
{"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"},
|
||||
]
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
# TODO - get models from comfyui
|
||||
@@ -450,7 +451,7 @@ def load_url_image_data(url, headers=None):
|
||||
return None
|
||||
|
||||
|
||||
def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
def upload_image(request, image_data, content_type, metadata, user):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
@@ -459,7 +460,7 @@ def upload_image(request, image_metadata, image_data, content_type, user):
|
||||
"content-type": content_type,
|
||||
},
|
||||
)
|
||||
file_item = upload_file(request, file, user, file_metadata=image_metadata)
|
||||
file_item = upload_file(request, file, metadata=metadata, internal=True, user=user)
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
return url
|
||||
|
||||
@@ -526,7 +527,7 @@ async def image_generations(
|
||||
else:
|
||||
image_data, content_type = load_b64_image_data(image["b64_json"])
|
||||
|
||||
url = upload_image(request, data, image_data, content_type, user)
|
||||
url = upload_image(request, image_data, content_type, data, user)
|
||||
images.append({"url": url})
|
||||
return images
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from open_webui.models.knowledge import (
|
||||
KnowledgeUserResponse,
|
||||
)
|
||||
from open_webui.models.files import Files, FileModel, FileMetadataResponse
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
from open_webui.routers.retrieval import (
|
||||
process_file,
|
||||
ProcessFileForm,
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.models.memories import Memories, MemoryModel
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
from open_webui.utils.auth import get_verified_user
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
@@ -300,6 +302,22 @@ async def update_config(
|
||||
}
|
||||
|
||||
|
||||
def merge_ollama_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_models(request: Request, user: UserModel = None):
|
||||
log.info("get_all_models()")
|
||||
@@ -340,6 +358,8 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
connection_type = api_config.get("connection_type", "local")
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
tags = api_config.get("tags", [])
|
||||
model_ids = api_config.get("model_ids", [])
|
||||
@@ -352,31 +372,18 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
)
|
||||
)
|
||||
|
||||
if prefix_id:
|
||||
for model in response.get("models", []):
|
||||
for model in response.get("models", []):
|
||||
if prefix_id:
|
||||
model["model"] = f"{prefix_id}.{model['model']}"
|
||||
|
||||
if tags:
|
||||
for model in response.get("models", []):
|
||||
if tags:
|
||||
model["tags"] = tags
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
merged_models = {}
|
||||
|
||||
for idx, model_list in enumerate(model_lists):
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
id = model["model"]
|
||||
if id not in merged_models:
|
||||
model["urls"] = [idx]
|
||||
merged_models[id] = model
|
||||
else:
|
||||
merged_models[id]["urls"].append(idx)
|
||||
|
||||
return list(merged_models.values())
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
models = {
|
||||
"models": merge_models_lists(
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
@@ -384,6 +391,19 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||
)
|
||||
}
|
||||
|
||||
loaded_models = await get_ollama_loaded_models(request, user=user)
|
||||
expires_map = {
|
||||
m["name"]: m["expires_at"]
|
||||
for m in loaded_models["models"]
|
||||
if "expires_at" in m
|
||||
}
|
||||
|
||||
for m in models["models"]:
|
||||
if m["name"] in expires_map:
|
||||
# Parse ISO8601 datetime with offset, get unix timestamp as int
|
||||
dt = datetime.fromisoformat(expires_map[m["name"]])
|
||||
m["expires_at"] = int(dt.timestamp())
|
||||
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
@@ -464,6 +484,68 @@ async def get_ollama_tags(
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = []
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and (
|
||||
url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support
|
||||
):
|
||||
request_tasks.append(send_get_request(f"{url}/api/ps", user=user))
|
||||
else:
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
enable = api_config.get("enable", True)
|
||||
key = api_config.get("key", None)
|
||||
|
||||
if enable:
|
||||
request_tasks.append(
|
||||
send_get_request(f"{url}/api/ps", key, user=user)
|
||||
)
|
||||
else:
|
||||
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
|
||||
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
if response:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
|
||||
for model in response.get("models", []):
|
||||
if prefix_id:
|
||||
model["model"] = f"{prefix_id}.{model['model']}"
|
||||
|
||||
models = {
|
||||
"models": merge_ollama_models_lists(
|
||||
map(
|
||||
lambda response: response.get("models", []) if response else None,
|
||||
responses,
|
||||
)
|
||||
)
|
||||
}
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get("/api/version")
|
||||
@router.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
@@ -537,36 +619,74 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
return {"version": False}
|
||||
|
||||
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
request_tasks = [
|
||||
send_get_request(
|
||||
f"{url}/api/ps",
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
url, {}
|
||||
), # Legacy support
|
||||
).get("key", None),
|
||||
user=user,
|
||||
)
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS)
|
||||
]
|
||||
responses = await asyncio.gather(*request_tasks)
|
||||
|
||||
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class ModelNameForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.post("/api/unload")
|
||||
async def unload_model(
|
||||
request: Request,
|
||||
form_data: ModelNameForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
model_name = form_data.name
|
||||
if not model_name:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing 'name' of model to unload."
|
||||
)
|
||||
|
||||
# Refresh/load models if needed, get mapping from name to URLs
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OLLAMA_MODELS
|
||||
|
||||
# Canonicalize model name (if not supplied with version)
|
||||
if ":" not in model_name:
|
||||
model_name = f"{model_name}:latest"
|
||||
|
||||
if model_name not in models:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name)
|
||||
)
|
||||
url_indices = models[model_name]["urls"]
|
||||
|
||||
# Send unload to ALL url_indices
|
||||
results = []
|
||||
errors = []
|
||||
for idx in url_indices:
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(
|
||||
str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
)
|
||||
key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id and model_name.startswith(f"{prefix_id}."):
|
||||
model_name = model_name[len(f"{prefix_id}.") :]
|
||||
|
||||
payload = {"model": model_name, "keep_alive": 0, "prompt": ""}
|
||||
|
||||
try:
|
||||
res = await send_post_request(
|
||||
url=f"{url}/api/generate",
|
||||
payload=json.dumps(payload),
|
||||
stream=False,
|
||||
key=key,
|
||||
user=user,
|
||||
)
|
||||
results.append({"url_idx": idx, "success": True, "response": res})
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to unload model on node {idx}: {e}")
|
||||
errors.append({"url_idx": idx, "success": False, "error": str(e)})
|
||||
|
||||
if len(errors) > 0:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to unload model on {len(errors)} nodes: {errors}",
|
||||
)
|
||||
|
||||
return {"status": True}
|
||||
|
||||
|
||||
@router.post("/api/pull")
|
||||
@router.post("/api/pull/{url_idx}")
|
||||
async def pull_model(
|
||||
@@ -1585,7 +1705,9 @@ async def upload_model(
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
||||
|
||||
filename = os.path.basename(file.filename)
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
# --- P1: save file locally ---
|
||||
@@ -1630,13 +1752,13 @@ async def upload_model(
|
||||
os.remove(file_path)
|
||||
|
||||
# Create model in ollama
|
||||
model_name, ext = os.path.splitext(file.filename)
|
||||
model_name, ext = os.path.splitext(filename)
|
||||
log.info(f"Created Model: {model_name}") # DEBUG
|
||||
|
||||
create_payload = {
|
||||
"model": model_name,
|
||||
# Reference the file by its original name => the uploaded blob's digest
|
||||
"files": {file.filename: f"sha256:{file_hash}"},
|
||||
"files": {filename: f"sha256:{file_hash}"},
|
||||
}
|
||||
log.info(f"Model Payload: {create_payload}") # DEBUG
|
||||
|
||||
@@ -1653,7 +1775,7 @@ async def upload_model(
|
||||
done_msg = {
|
||||
"done": True,
|
||||
"blob": f"sha256:{file_hash}",
|
||||
"name": file.filename,
|
||||
"name": filename,
|
||||
"model_created": model_name,
|
||||
}
|
||||
yield f"data: {json.dumps(done_msg)}\n\n"
|
||||
|
||||
@@ -353,21 +353,22 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
connection_type = api_config.get("connection_type", "external")
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
tags = api_config.get("tags", [])
|
||||
|
||||
if prefix_id:
|
||||
for model in (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
for model in (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
if prefix_id:
|
||||
model["id"] = f"{prefix_id}.{model['id']}"
|
||||
|
||||
if tags:
|
||||
for model in (
|
||||
response if isinstance(response, list) else response.get("data", [])
|
||||
):
|
||||
if tags:
|
||||
model["tags"] = tags
|
||||
|
||||
if connection_type:
|
||||
model["connection_type"] = connection_type
|
||||
|
||||
log.debug(f"get_all_models:responses() {responses}")
|
||||
return responses
|
||||
|
||||
@@ -415,6 +416,7 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
|
||||
"name": model.get("name", model["id"]),
|
||||
"owned_by": "openai",
|
||||
"openai": model,
|
||||
"connection_type": model.get("connection_type", "external"),
|
||||
"urlIdx": idx,
|
||||
}
|
||||
for model in models
|
||||
@@ -461,60 +463,74 @@ async def get_models(
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(url_idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
response_data = await r.json()
|
||||
if api_config.get("azure", False):
|
||||
models = {
|
||||
"data": api_config.get("model_ids", []) or [],
|
||||
"object": "list",
|
||||
}
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
# Check if we're calling OpenAI API based on the URL
|
||||
if "api.openai.com" in url:
|
||||
# Filter models according to the specified conditions
|
||||
response_data["data"] = [
|
||||
model
|
||||
for model in response_data.get("data", [])
|
||||
if not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
models = response_data
|
||||
response_data = await r.json()
|
||||
|
||||
# Check if we're calling OpenAI API based on the URL
|
||||
if "api.openai.com" in url:
|
||||
# Filter models according to the specified conditions
|
||||
response_data["data"] = [
|
||||
model
|
||||
for model in response_data.get("data", [])
|
||||
if not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
models = response_data
|
||||
except aiohttp.ClientError as e:
|
||||
# ClientError covers all aiohttp requests issues
|
||||
log.exception(f"Client error: {str(e)}")
|
||||
@@ -536,6 +552,8 @@ class ConnectionVerificationForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
@router.post("/verify")
|
||||
async def verify_connection(
|
||||
@@ -544,39 +562,64 @@ async def verify_connection(
|
||||
url = form_data.url
|
||||
key = form_data.key
|
||||
|
||||
api_config = form_data.config or {}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
|
||||
) as session:
|
||||
try:
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
if api_config.get("azure", False):
|
||||
headers["api-key"] = key
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
|
||||
async with session.get(
|
||||
url=f"{url}/openai/models?api-version={api_version}",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
async with session.get(
|
||||
f"{url}/models",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
# Extract response error details if available
|
||||
error_detail = f"HTTP Error: {r.status}"
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External Error: {res['error']}"
|
||||
raise Exception(error_detail)
|
||||
|
||||
response_data = await r.json()
|
||||
return response_data
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
# ClientError covers all aiohttp requests issues
|
||||
@@ -590,6 +633,63 @@ async def verify_connection(
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
|
||||
def convert_to_azure_payload(
|
||||
url,
|
||||
payload: dict,
|
||||
):
|
||||
model = payload.get("model", "")
|
||||
|
||||
# Filter allowed parameters based on Azure OpenAI API
|
||||
allowed_params = {
|
||||
"messages",
|
||||
"temperature",
|
||||
"role",
|
||||
"content",
|
||||
"contentPart",
|
||||
"contentPartImage",
|
||||
"enhancements",
|
||||
"dataSources",
|
||||
"n",
|
||||
"stream",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"log_probs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"max_completion_tokens",
|
||||
}
|
||||
|
||||
# Special handling for o-series models
|
||||
if model.startswith("o") and model.endswith("-mini"):
|
||||
# Convert max_tokens to max_completion_tokens for o-series models
|
||||
if "max_tokens" in payload:
|
||||
payload["max_completion_tokens"] = payload["max_tokens"]
|
||||
del payload["max_tokens"]
|
||||
|
||||
# Remove temperature if not 1 for o-series models
|
||||
if "temperature" in payload and payload["temperature"] != 1:
|
||||
log.debug(
|
||||
f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
|
||||
)
|
||||
del payload["temperature"]
|
||||
|
||||
# Filter out unsupported parameters
|
||||
payload = {k: v for k, v in payload.items() if k in allowed_params}
|
||||
|
||||
url = f"{url}/openai/deployments/{model}"
|
||||
return url, payload
|
||||
|
||||
|
||||
@router.post("/chat/completions")
|
||||
async def generate_chat_completion(
|
||||
request: Request,
|
||||
@@ -690,6 +790,38 @@ async def generate_chat_completion(
|
||||
convert_logit_bias_input_to_json(payload["logit_bias"])
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
request_url, payload = convert_to_azure_payload(url, payload)
|
||||
api_version = api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = api_version
|
||||
request_url = f"{request_url}/chat/completions?api-version={api_version}"
|
||||
else:
|
||||
request_url = f"{url}/chat/completions"
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
|
||||
payload = json.dumps(payload)
|
||||
|
||||
r = None
|
||||
@@ -704,30 +836,9 @@ async def generate_chat_completion(
|
||||
|
||||
r = await session.request(
|
||||
method="POST",
|
||||
url=f"{url}/chat/completions",
|
||||
url=request_url,
|
||||
data=payload,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"HTTP-Referer": "https://openwebui.com/",
|
||||
"X-Title": "Open WebUI",
|
||||
}
|
||||
if "openrouter.ai" in url
|
||||
else {}
|
||||
),
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
|
||||
@@ -783,31 +894,53 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
idx = 0
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
||||
), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if api_config.get("azure", False):
|
||||
headers["api-key"] = key
|
||||
headers["api-version"] = (
|
||||
api_config.get("api_version", "") or "2023-03-15-preview"
|
||||
)
|
||||
|
||||
payload = json.loads(body)
|
||||
url, payload = convert_to_azure_payload(url, payload)
|
||||
body = json.dumps(payload).encode()
|
||||
|
||||
request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}"
|
||||
else:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
request_url = f"{url}/{path}"
|
||||
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
r = await session.request(
|
||||
method=request.method,
|
||||
url=f"{url}/{path}",
|
||||
url=request_url,
|
||||
data=body,
|
||||
headers={
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -18,7 +18,7 @@ from pydantic import BaseModel
|
||||
from starlette.responses import FileResponse
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.env import SRC_LOG_LEVELS, AIOHTTP_CLIENT_SESSION_SSL
|
||||
from open_webui.config import CACHE_DIR
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
@@ -69,7 +69,10 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
@@ -89,6 +92,7 @@ async def process_pipeline_inlet_filter(request, payload, user, models):
|
||||
f"{url}/{filter['id']}/filter/inlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
@@ -118,7 +122,10 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
for filter in sorted_filters:
|
||||
urlIdx = filter.get("urlIdx")
|
||||
if urlIdx is None:
|
||||
|
||||
try:
|
||||
urlIdx = int(urlIdx)
|
||||
except:
|
||||
continue
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
@@ -138,6 +145,7 @@ async def process_pipeline_outlet_filter(request, payload, user, models):
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json=request_data,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
payload = await response.json()
|
||||
response.raise_for_status()
|
||||
@@ -197,8 +205,10 @@ async def upload_pipeline(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
|
||||
filename = os.path.basename(file.filename)
|
||||
|
||||
# Check if the uploaded file is a python file
|
||||
if not (file.filename and file.filename.endswith(".py")):
|
||||
if not (filename and filename.endswith(".py")):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only Python (.py) files are allowed.",
|
||||
@@ -206,7 +216,7 @@ async def upload_pipeline(
|
||||
|
||||
upload_folder = f"{CACHE_DIR}/pipelines"
|
||||
os.makedirs(upload_folder, exist_ok=True)
|
||||
file_path = os.path.join(upload_folder, file.filename)
|
||||
file_path = os.path.join(upload_folder, filename)
|
||||
|
||||
r = None
|
||||
try:
|
||||
|
||||
@@ -36,7 +36,7 @@ from open_webui.models.knowledge import Knowledges
|
||||
from open_webui.storage.provider import Storage
|
||||
|
||||
|
||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
||||
# Document loaders
|
||||
from open_webui.retrieval.loaders.main import Loader
|
||||
@@ -352,10 +352,13 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
# Content extraction settings
|
||||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
|
||||
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
|
||||
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||
@@ -371,6 +374,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
# File upload settings
|
||||
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
|
||||
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
|
||||
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
|
||||
# Integration settings
|
||||
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
@@ -383,6 +387,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
|
||||
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
|
||||
@@ -435,6 +440,7 @@ class WebConfig(BaseModel):
|
||||
WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None
|
||||
WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = []
|
||||
BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
|
||||
BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None
|
||||
SEARXNG_QUERY_URL: Optional[str] = None
|
||||
YACY_QUERY_URL: Optional[str] = None
|
||||
YACY_USERNAME: Optional[str] = None
|
||||
@@ -492,10 +498,14 @@ class ConfigForm(BaseModel):
|
||||
# Content extraction settings
|
||||
CONTENT_EXTRACTION_ENGINE: Optional[str] = None
|
||||
PDF_EXTRACT_IMAGES: Optional[bool] = None
|
||||
EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None
|
||||
|
||||
TIKA_SERVER_URL: Optional[str] = None
|
||||
DOCLING_SERVER_URL: Optional[str] = None
|
||||
DOCLING_OCR_ENGINE: Optional[str] = None
|
||||
DOCLING_OCR_LANG: Optional[str] = None
|
||||
DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None
|
||||
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
|
||||
MISTRAL_OCR_API_KEY: Optional[str] = None
|
||||
@@ -514,6 +524,7 @@ class ConfigForm(BaseModel):
|
||||
# File upload settings
|
||||
FILE_MAX_SIZE: Optional[int] = None
|
||||
FILE_MAX_COUNT: Optional[int] = None
|
||||
ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None
|
||||
|
||||
# Integration settings
|
||||
ENABLE_GOOGLE_DRIVE_INTEGRATION: Optional[bool] = None
|
||||
@@ -581,6 +592,16 @@ async def update_rag_config(
|
||||
if form_data.PDF_EXTRACT_IMAGES is not None
|
||||
else request.app.state.config.PDF_EXTRACT_IMAGES
|
||||
)
|
||||
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = (
|
||||
form_data.EXTERNAL_DOCUMENT_LOADER_URL
|
||||
if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None
|
||||
else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL
|
||||
)
|
||||
request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = (
|
||||
form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||
if form_data.EXTERNAL_DOCUMENT_LOADER_API_KEY is not None
|
||||
else request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY
|
||||
)
|
||||
request.app.state.config.TIKA_SERVER_URL = (
|
||||
form_data.TIKA_SERVER_URL
|
||||
if form_data.TIKA_SERVER_URL is not None
|
||||
@@ -601,6 +622,13 @@ async def update_rag_config(
|
||||
if form_data.DOCLING_OCR_LANG is not None
|
||||
else request.app.state.config.DOCLING_OCR_LANG
|
||||
)
|
||||
|
||||
request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = (
|
||||
form_data.DOCLING_DO_PICTURE_DESCRIPTION
|
||||
if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None
|
||||
else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION
|
||||
)
|
||||
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
|
||||
form_data.DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None
|
||||
@@ -688,6 +716,11 @@ async def update_rag_config(
|
||||
if form_data.FILE_MAX_COUNT is not None
|
||||
else request.app.state.config.FILE_MAX_COUNT
|
||||
)
|
||||
request.app.state.config.ALLOWED_FILE_EXTENSIONS = (
|
||||
form_data.ALLOWED_FILE_EXTENSIONS
|
||||
if form_data.ALLOWED_FILE_EXTENSIONS is not None
|
||||
else request.app.state.config.ALLOWED_FILE_EXTENSIONS
|
||||
)
|
||||
|
||||
# Integration settings
|
||||
request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
|
||||
@@ -720,6 +753,9 @@ async def update_rag_config(
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
|
||||
)
|
||||
request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = (
|
||||
form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER
|
||||
)
|
||||
request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL
|
||||
request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL
|
||||
request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME
|
||||
@@ -809,10 +845,13 @@ async def update_rag_config(
|
||||
# Content extraction settings
|
||||
"CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
"EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
"EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
"TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL,
|
||||
"DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL,
|
||||
"DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE,
|
||||
"DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG,
|
||||
"DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||
@@ -828,6 +867,7 @@ async def update_rag_config(
|
||||
# File upload settings
|
||||
"FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE,
|
||||
"FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT,
|
||||
"ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS,
|
||||
# Integration settings
|
||||
"ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
|
||||
"ENABLE_ONEDRIVE_INTEGRATION": request.app.state.config.ENABLE_ONEDRIVE_INTEGRATION,
|
||||
@@ -840,6 +880,7 @@ async def update_rag_config(
|
||||
"WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
"BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
|
||||
"BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER,
|
||||
"SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL,
|
||||
"YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL,
|
||||
"YACY_USERNAME": request.app.state.config.YACY_USERNAME,
|
||||
@@ -1129,10 +1170,13 @@ def process_file(
|
||||
file_path = Storage.get_file(file_path)
|
||||
loader = Loader(
|
||||
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL,
|
||||
EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY,
|
||||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL,
|
||||
DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE,
|
||||
DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG,
|
||||
DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
@@ -1640,13 +1684,29 @@ async def process_web_search(
|
||||
)
|
||||
|
||||
try:
|
||||
loader = get_web_loader(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER:
|
||||
docs = [
|
||||
Document(
|
||||
page_content=result.snippet,
|
||||
metadata={
|
||||
"source": result.link,
|
||||
"title": result.title,
|
||||
"snippet": result.snippet,
|
||||
"link": result.link,
|
||||
},
|
||||
)
|
||||
for result in search_results
|
||||
if hasattr(result, "snippet")
|
||||
]
|
||||
else:
|
||||
loader = get_web_loader(
|
||||
urls,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV,
|
||||
)
|
||||
docs = await loader.aload()
|
||||
|
||||
urls = [
|
||||
doc.metadata.get("source") for doc in docs if doc.metadata.get("source")
|
||||
] # only keep the urls returned by the loader
|
||||
|
||||
@@ -20,10 +20,7 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -195,15 +192,19 @@ async def generate_title(
|
||||
},
|
||||
)
|
||||
|
||||
max_tokens = (
|
||||
models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": task_model_id,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 1000}
|
||||
{"max_tokens": max_tokens}
|
||||
if models[task_model_id].get("owned_by") == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 1000,
|
||||
"max_completion_tokens": max_tokens,
|
||||
}
|
||||
),
|
||||
"metadata": {
|
||||
|
||||
@@ -13,6 +13,8 @@ import pytz
|
||||
from pytz import UTC
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from open_webui.models.users import Users
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
@@ -194,7 +196,17 @@ def get_current_user(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
|
||||
return get_current_user_by_api_key(token)
|
||||
user = get_current_user_by_api_key(token)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
return user
|
||||
|
||||
# auth by jwt token
|
||||
try:
|
||||
@@ -213,6 +225,14 @@ def get_current_user(
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "jwt")
|
||||
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
@@ -234,6 +254,14 @@ def get_current_user_by_api_key(api_key: str):
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
|
||||
return user
|
||||
|
||||
@@ -309,6 +309,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"filter_ids": data.get("filter_ids", []),
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
@@ -330,7 +331,9 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
result, _ = await process_filter_functions(
|
||||
@@ -389,11 +392,8 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
}
|
||||
)
|
||||
|
||||
if action_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[action_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(action_id)
|
||||
request.app.state.FUNCTIONS[action_id] = function_module
|
||||
function_module, _, _ = load_function_module_by_id(action_id)
|
||||
request.app.state.FUNCTIONS[action_id] = function_module
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
|
||||
@@ -9,7 +9,18 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_sorted_filter_ids(model: dict):
|
||||
def get_function_module(request, function_id):
|
||||
"""
|
||||
Get the function module by its ID.
|
||||
"""
|
||||
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
|
||||
return function_module
|
||||
|
||||
|
||||
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
if function is not None:
|
||||
@@ -21,14 +32,23 @@ def get_sorted_filter_ids(model: dict):
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
|
||||
enabled_filter_ids = [
|
||||
active_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in enabled_filter_ids]
|
||||
for filter_id in active_filter_ids:
|
||||
function_module = get_function_module(request, filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None) and (
|
||||
filter_id not in enabled_filter_ids
|
||||
):
|
||||
active_filter_ids.remove(filter_id)
|
||||
continue
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
return filter_ids
|
||||
|
||||
|
||||
@@ -43,12 +63,7 @@ async def process_filter_functions(
|
||||
if not filter:
|
||||
continue
|
||||
|
||||
if filter_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(filter_id)
|
||||
request.app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
function_module = get_function_module(request, filter_id)
|
||||
# Prepare handler function
|
||||
handler = getattr(function_module, filter_type, None)
|
||||
if not handler:
|
||||
|
||||
@@ -43,6 +43,7 @@ from open_webui.routers.pipelines import (
|
||||
process_pipeline_outlet_filter,
|
||||
)
|
||||
from open_webui.routers.files import upload_file
|
||||
from open_webui.routers.memories import query_memory, QueryMemoryForm
|
||||
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
@@ -253,7 +254,12 @@ async def chat_completion_tools_handler(
|
||||
"name": (f"TOOL:{tool_name}"),
|
||||
},
|
||||
"document": [tool_result],
|
||||
"metadata": [{"source": (f"TOOL:{tool_name}")}],
|
||||
"metadata": [
|
||||
{
|
||||
"source": (f"TOOL:{tool_name}"),
|
||||
"parameters": tool_function_params,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -292,6 +298,38 @@ async def chat_completion_tools_handler(
|
||||
return body, {"sources": sources}
|
||||
|
||||
|
||||
async def chat_memory_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
results = await query_memory(
|
||||
request,
|
||||
QueryMemoryForm(
|
||||
**{"content": get_last_user_message(form_data["messages"]), "k": 3}
|
||||
),
|
||||
user,
|
||||
)
|
||||
|
||||
user_context = ""
|
||||
if results and hasattr(results, "documents"):
|
||||
if results.documents and len(results.documents) > 0:
|
||||
for doc_idx, doc in enumerate(results.documents[0]):
|
||||
created_at_date = "Unknown Date"
|
||||
|
||||
if results.metadatas[0][doc_idx].get("created_at"):
|
||||
created_at_timestamp = results.metadatas[0][doc_idx]["created_at"]
|
||||
created_at_date = time.strftime(
|
||||
"%Y-%m-%d", time.localtime(created_at_timestamp)
|
||||
)
|
||||
|
||||
user_context += f"{doc_idx + 1}. [{created_at_date}] {doc}\n"
|
||||
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
f"User Context:\n{user_context}\n", form_data["messages"], append=True
|
||||
)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
async def chat_web_search_handler(
|
||||
request: Request, form_data: dict, extra_params: dict, user
|
||||
):
|
||||
@@ -342,6 +380,11 @@ async def chat_web_search_handler(
|
||||
log.exception(e)
|
||||
queries = [user_message]
|
||||
|
||||
# Check if generated queries are empty
|
||||
if len(queries) == 1 and queries[0].strip() == "":
|
||||
queries = [user_message]
|
||||
|
||||
# Check if queries are not found
|
||||
if len(queries) == 0:
|
||||
await event_emitter(
|
||||
{
|
||||
@@ -653,7 +696,7 @@ def apply_params_to_form_data(form_data, model):
|
||||
convert_logit_bias_input_to_json(params["logit_bias"])
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing logit_bias: {e}")
|
||||
log.exception(f"Error parsing logit_bias: {e}")
|
||||
|
||||
return form_data
|
||||
|
||||
@@ -751,9 +794,12 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
raise e
|
||||
|
||||
try:
|
||||
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
form_data, flags = await process_filter_functions(
|
||||
@@ -768,6 +814,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
|
||||
features = form_data.pop("features", None)
|
||||
if features:
|
||||
if "memory" in features and features["memory"]:
|
||||
form_data = await chat_memory_handler(
|
||||
request, form_data, extra_params, user
|
||||
)
|
||||
|
||||
if "web_search" in features and features["web_search"]:
|
||||
form_data = await chat_web_search_handler(
|
||||
request, form_data, extra_params, user
|
||||
@@ -870,6 +921,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
for doc_context, doc_meta in zip(
|
||||
source["document"], source["metadata"]
|
||||
):
|
||||
source_name = source.get("source", {}).get("name", None)
|
||||
citation_id = (
|
||||
doc_meta.get("source", None)
|
||||
or source.get("source", {}).get("id", None)
|
||||
@@ -877,7 +929,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
)
|
||||
if citation_id not in citation_idx:
|
||||
citation_idx[citation_id] = len(citation_idx) + 1
|
||||
context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'
|
||||
context_string += (
|
||||
f'<source id="{citation_idx[citation_id]}"'
|
||||
+ (f' name="{source_name}"' if source_name else "")
|
||||
+ f">{doc_context}</source>\n"
|
||||
)
|
||||
|
||||
context_string = context_string.strip()
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
@@ -944,21 +1000,36 @@ async def process_chat_response(
|
||||
message = message_map.get(metadata["message_id"]) if message_map else None
|
||||
|
||||
if message:
|
||||
messages = get_message_list(message_map, message.get("id"))
|
||||
message_list = get_message_list(message_map, message.get("id"))
|
||||
|
||||
# Remove reasoning details and files from the messages.
|
||||
# Remove details tags and files from the messages.
|
||||
# as get_message_list creates a new list, it does not affect
|
||||
# the original messages outside of this handler
|
||||
for message in messages:
|
||||
message["content"] = re.sub(
|
||||
r"<details\s+type=\"reasoning\"[^>]*>.*?<\/details>",
|
||||
"",
|
||||
message["content"],
|
||||
flags=re.S,
|
||||
).strip()
|
||||
|
||||
if message.get("files"):
|
||||
message["files"] = []
|
||||
messages = []
|
||||
for message in message_list:
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
content = item["text"]
|
||||
break
|
||||
|
||||
if isinstance(content, str):
|
||||
content = re.sub(
|
||||
r"<details\b[^>]*>.*?<\/details>|!\[.*?\]\(.*?\)",
|
||||
"",
|
||||
content,
|
||||
flags=re.S | re.I,
|
||||
).strip()
|
||||
|
||||
messages.append(
|
||||
{
|
||||
**message,
|
||||
"role": message["role"],
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
if tasks and messages:
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
@@ -1171,7 +1242,9 @@ async def process_chat_response(
|
||||
}
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(model)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
# Streaming response
|
||||
|
||||
@@ -130,7 +130,9 @@ def prepend_to_first_user_message_content(
|
||||
return messages
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
def add_or_update_system_message(
|
||||
content: str, messages: list[dict], append: bool = False
|
||||
):
|
||||
"""
|
||||
Adds a new system message at the beginning of the messages list
|
||||
or updates the existing system message at the beginning.
|
||||
@@ -141,7 +143,10 @@ def add_or_update_system_message(content: str, messages: list[dict]):
|
||||
"""
|
||||
|
||||
if messages and messages[0].get("role") == "system":
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
if append:
|
||||
messages[0]["content"] = f"{messages[0]['content']}\n{content}"
|
||||
else:
|
||||
messages[0]["content"] = f"{content}\n{messages[0]['content']}"
|
||||
else:
|
||||
# Insert at the beginning
|
||||
messages.insert(0, {"role": "system", "content": content})
|
||||
|
||||
@@ -49,6 +49,7 @@ async def get_all_base_models(request: Request, user: UserModel = None):
|
||||
"created": int(time.time()),
|
||||
"owned_by": "ollama",
|
||||
"ollama": model,
|
||||
"connection_type": model.get("connection_type", "local"),
|
||||
"tags": model.get("tags", []),
|
||||
}
|
||||
for model in ollama_models["models"]
|
||||
@@ -110,6 +111,14 @@ async def get_all_models(request, user: UserModel = None):
|
||||
for function in Functions.get_functions_by_type("action", active_only=True)
|
||||
]
|
||||
|
||||
global_filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
]
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
]
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id is None:
|
||||
@@ -125,13 +134,20 @@ async def get_all_models(request, user: UserModel = None):
|
||||
model["name"] = custom_model.name
|
||||
model["info"] = custom_model.model_dump()
|
||||
|
||||
# Set action_ids and filter_ids
|
||||
action_ids = []
|
||||
filter_ids = []
|
||||
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
action_ids.extend(
|
||||
model["info"]["meta"].get("actionIds", [])
|
||||
)
|
||||
filter_ids.extend(
|
||||
model["info"]["meta"].get("filterIds", [])
|
||||
)
|
||||
|
||||
model["action_ids"] = action_ids
|
||||
model["filter_ids"] = filter_ids
|
||||
else:
|
||||
models.remove(model)
|
||||
|
||||
@@ -140,7 +156,9 @@ async def get_all_models(request, user: UserModel = None):
|
||||
):
|
||||
owned_by = "openai"
|
||||
pipe = None
|
||||
|
||||
action_ids = []
|
||||
filter_ids = []
|
||||
|
||||
for model in models:
|
||||
if (
|
||||
@@ -154,9 +172,13 @@ async def get_all_models(request, user: UserModel = None):
|
||||
|
||||
if custom_model.meta:
|
||||
meta = custom_model.meta.model_dump()
|
||||
|
||||
if "actionIds" in meta:
|
||||
action_ids.extend(meta["actionIds"])
|
||||
|
||||
if "filterIds" in meta:
|
||||
filter_ids.extend(meta["filterIds"])
|
||||
|
||||
models.append(
|
||||
{
|
||||
"id": f"{custom_model.id}",
|
||||
@@ -168,6 +190,7 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"preset": True,
|
||||
**({"pipe": pipe} if pipe is not None else {}),
|
||||
"action_ids": action_ids,
|
||||
"filter_ids": filter_ids,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -181,8 +204,11 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"id": f"{function.id}.{action['id']}",
|
||||
"name": action.get("name", f"{function.name} ({action['id']})"),
|
||||
"description": function.meta.description,
|
||||
"icon_url": action.get(
|
||||
"icon_url", function.meta.manifest.get("icon_url", None)
|
||||
"icon": action.get(
|
||||
"icon_url",
|
||||
function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
),
|
||||
}
|
||||
for action in actions
|
||||
@@ -193,16 +219,28 @@ async def get_all_models(request, user: UserModel = None):
|
||||
"id": function.id,
|
||||
"name": function.name,
|
||||
"description": function.meta.description,
|
||||
"icon_url": function.meta.manifest.get("icon_url", None),
|
||||
"icon": function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
}
|
||||
]
|
||||
|
||||
# Process filter_ids to get the filters
|
||||
def get_filter_items_from_module(function, module):
|
||||
return [
|
||||
{
|
||||
"id": function.id,
|
||||
"name": function.name,
|
||||
"description": function.meta.description,
|
||||
"icon": function.meta.manifest.get("icon_url", None)
|
||||
or getattr(module, "icon_url", None)
|
||||
or getattr(module, "icon", None),
|
||||
}
|
||||
]
|
||||
|
||||
def get_function_module_by_id(function_id):
|
||||
if function_id in request.app.state.FUNCTIONS:
|
||||
function_module = request.app.state.FUNCTIONS[function_id]
|
||||
else:
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
function_module, _, _ = load_function_module_by_id(function_id)
|
||||
request.app.state.FUNCTIONS[function_id] = function_module
|
||||
return function_module
|
||||
|
||||
for model in models:
|
||||
@@ -211,6 +249,11 @@ async def get_all_models(request, user: UserModel = None):
|
||||
for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
|
||||
if action_id in enabled_action_ids
|
||||
]
|
||||
filter_ids = [
|
||||
filter_id
|
||||
for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
|
||||
if filter_id in enabled_filter_ids
|
||||
]
|
||||
|
||||
model["actions"] = []
|
||||
for action_id in action_ids:
|
||||
@@ -222,6 +265,20 @@ async def get_all_models(request, user: UserModel = None):
|
||||
model["actions"].extend(
|
||||
get_action_items_from_module(action_function, function_module)
|
||||
)
|
||||
|
||||
model["filters"] = []
|
||||
for filter_id in filter_ids:
|
||||
filter_function = Functions.get_function_by_id(filter_id)
|
||||
if filter_function is None:
|
||||
raise Exception(f"Filter not found: {filter_id}")
|
||||
|
||||
function_module = get_function_module_by_id(filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None):
|
||||
model["filters"].extend(
|
||||
get_filter_items_from_module(filter_function, function_module)
|
||||
)
|
||||
|
||||
log.debug(f"get_all_models() returned {len(models)} models")
|
||||
|
||||
request.app.state.MODELS = {model["id"]: model for model in models}
|
||||
|
||||
@@ -41,6 +41,7 @@ from open_webui.config import (
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_SSL,
|
||||
WEBUI_NAME,
|
||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||
WEBUI_AUTH_COOKIE_SECURE,
|
||||
@@ -305,8 +306,10 @@ class OAuthManager:
|
||||
get_kwargs["headers"] = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url, **get_kwargs) as resp:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
picture_url, **get_kwargs, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||
@@ -371,7 +374,9 @@ class OAuthManager:
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://api.github.com/user/emails", headers=headers
|
||||
"https://api.github.com/user/emails",
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
emails = await resp.json()
|
||||
@@ -531,5 +536,10 @@ class OAuthManager:
|
||||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
|
||||
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
|
||||
if redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
|
||||
|
||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||
|
||||
@@ -57,6 +57,7 @@ def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": float,
|
||||
"min_p": float,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": float,
|
||||
"presence_penalty": float,
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_task_model_id(
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id].get("owned_by") == "ollama":
|
||||
if models[task_model_id].get("connection_type") == "local":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
|
||||
@@ -37,6 +37,7 @@ from open_webui.models.tools import Tools
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.utils.plugin import load_tool_module_by_id
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA,
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
)
|
||||
@@ -44,6 +45,7 @@ from open_webui.env import (
|
||||
import copy
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
def get_async_tool_function_and_apply_extra_params(
|
||||
@@ -158,7 +160,7 @@ def get_tools(
|
||||
# TODO: Fix hack for OpenAI API
|
||||
# Some times breaks OpenAI but others don't. Leaving the comment
|
||||
for val in spec.get("parameters", {}).get("properties", {}).values():
|
||||
if val["type"] == "str":
|
||||
if val.get("type") == "str":
|
||||
val["type"] = "string"
|
||||
|
||||
# Remove internal reserved parameters (e.g. __id__, __user__)
|
||||
@@ -477,7 +479,7 @@ async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
"specs": convert_openapi_to_tool_payload(res),
|
||||
}
|
||||
|
||||
print("Fetched data:", data)
|
||||
log.info("Fetched data:", data)
|
||||
return data
|
||||
|
||||
|
||||
@@ -510,7 +512,7 @@ async def get_tool_servers_data(
|
||||
results = []
|
||||
for (idx, server, url, _), response in zip(server_entries, responses):
|
||||
if isinstance(response, Exception):
|
||||
print(f"Failed to connect to {url} OpenAPI tool server")
|
||||
log.error(f"Failed to connect to {url} OpenAPI tool server")
|
||||
continue
|
||||
|
||||
results.append(
|
||||
@@ -620,5 +622,5 @@ async def execute_tool_server(
|
||||
|
||||
except Exception as err:
|
||||
error = str(err)
|
||||
print("API Request Error:", error)
|
||||
log.exception("API Request Error:", error)
|
||||
return {"error": error}
|
||||
|
||||
@@ -15,7 +15,7 @@ aiofiles
|
||||
|
||||
sqlalchemy==2.0.38
|
||||
alembic==1.14.0
|
||||
peewee==3.17.9
|
||||
peewee==3.18.1
|
||||
peewee-migrate==1.12.2
|
||||
psycopg2-binary==2.9.9
|
||||
pgvector==0.4.0
|
||||
@@ -37,7 +37,8 @@ asgiref==3.8.1
|
||||
# AI libraries
|
||||
openai
|
||||
anthropic
|
||||
google-generativeai==0.8.4
|
||||
google-genai==1.15.0
|
||||
google-generativeai==0.8.5
|
||||
tiktoken
|
||||
|
||||
langchain==0.3.24
|
||||
@@ -98,7 +99,7 @@ pytube==15.0.0
|
||||
|
||||
extract_msg
|
||||
pydub
|
||||
duckduckgo-search~=8.0.0
|
||||
duckduckgo-search==8.0.2
|
||||
|
||||
## Google Drive
|
||||
google-api-python-client
|
||||
|
||||
Reference in New Issue
Block a user