From 380e080f6d30be56fc06183db2d205b0db5bf9cb Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 3 Nov 2024 01:08:04 -0800 Subject: [PATCH] fix: do not crash on invalid embedding model --- backend/open_webui/apps/retrieval/main.py | 43 ++++++++++++++--------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 21953b4d7..091151646 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -101,7 +101,13 @@ from open_webui.config import ( AppConfig, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, BING_SEARCH_V7_ENDPOINT, BING_SEARCH_V7_SUBSCRIPTION_KEY +from open_webui.env import ( + SRC_LOG_LEVELS, + DEVICE_TYPE, + DOCKER, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, +) from open_webui.utils.misc import ( calculate_sha256, calculate_sha256_string, @@ -176,6 +182,7 @@ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS + def update_embedding_model( embedding_model: str, auto_update: bool = False, @@ -183,11 +190,15 @@ def update_embedding_model( if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": from sentence_transformers import SentenceTransformer - app.state.sentence_transformer_ef = SentenceTransformer( - get_model_path(embedding_model, auto_update), - device=DEVICE_TYPE, - trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - ) + try: + app.state.sentence_transformer_ef = SentenceTransformer( + get_model_path(embedding_model, auto_update), + device=DEVICE_TYPE, + trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + ) + except Exception as e: + log.debug(f"Error loading SentenceTransformer: {e}") + app.state.sentence_transformer_ef = None else: app.state.sentence_transformer_ef = None @@ -637,23 +648,21 @@ async def update_query_settings( #################################### -def _get_docs_info( - docs: list[Document] -) -> str: +def _get_docs_info(docs: list[Document]) -> str: docs_info = set() # Trying to select relevant metadata identifying the document. for doc in docs: - metadata = getattr(doc, 'metadata', {}) - doc_name = metadata.get('name', '') + metadata = getattr(doc, "metadata", {}) + doc_name = metadata.get("name", "") if not doc_name: - doc_name = metadata.get('title', '') + doc_name = metadata.get("title", "") if not doc_name: - doc_name = metadata.get('source', '') + doc_name = metadata.get("source", "") if doc_name: docs_info.add(doc_name) - return ', '.join(docs_info) + return ", ".join(docs_info) def save_docs_to_vector_db( @@ -664,7 +673,9 @@ def save_docs_to_vector_db( split: bool = True, add: bool = False, ) -> bool: - log.info(f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}") + log.info( + f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}" + ) # Check if entries with the same hash (metadata.hash) already exist if metadata and "hash" in metadata: @@ -1155,7 +1166,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) elif engine == "bing": return search_bing( - BING_SEARCH_V7_SUBSCRIPTION_KEY, + BING_SEARCH_V7_SUBSCRIPTION_KEY, BING_SEARCH_V7_ENDPOINT, str(DEFAULT_LOCALE), query,