fix: do not crash on invalid embedding model

This commit is contained in:
Timothy J. Baek 2024-11-03 01:08:04 -08:00
parent 6027c13227
commit 380e080f6d

View File

@ -101,7 +101,13 @@ from open_webui.config import (
AppConfig, AppConfig,
) )
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE, DOCKER, BING_SEARCH_V7_ENDPOINT, BING_SEARCH_V7_SUBSCRIPTION_KEY from open_webui.env import (
SRC_LOG_LEVELS,
DEVICE_TYPE,
DOCKER,
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
)
from open_webui.utils.misc import ( from open_webui.utils.misc import (
calculate_sha256, calculate_sha256,
calculate_sha256_string, calculate_sha256_string,
@ -176,6 +182,7 @@ app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
def update_embedding_model( def update_embedding_model(
embedding_model: str, embedding_model: str,
auto_update: bool = False, auto_update: bool = False,
@ -183,11 +190,15 @@ def update_embedding_model(
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "": if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
app.state.sentence_transformer_ef = SentenceTransformer( try:
get_model_path(embedding_model, auto_update), app.state.sentence_transformer_ef = SentenceTransformer(
device=DEVICE_TYPE, get_model_path(embedding_model, auto_update),
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, device=DEVICE_TYPE,
) trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
)
except Exception as e:
log.debug(f"Error loading SentenceTransformer: {e}")
app.state.sentence_transformer_ef = None
else: else:
app.state.sentence_transformer_ef = None app.state.sentence_transformer_ef = None
@ -637,23 +648,21 @@ async def update_query_settings(
#################################### ####################################
def _get_docs_info( def _get_docs_info(docs: list[Document]) -> str:
docs: list[Document]
) -> str:
docs_info = set() docs_info = set()
# Trying to select relevant metadata identifying the document. # Trying to select relevant metadata identifying the document.
for doc in docs: for doc in docs:
metadata = getattr(doc, 'metadata', {}) metadata = getattr(doc, "metadata", {})
doc_name = metadata.get('name', '') doc_name = metadata.get("name", "")
if not doc_name: if not doc_name:
doc_name = metadata.get('title', '') doc_name = metadata.get("title", "")
if not doc_name: if not doc_name:
doc_name = metadata.get('source', '') doc_name = metadata.get("source", "")
if doc_name: if doc_name:
docs_info.add(doc_name) docs_info.add(doc_name)
return ', '.join(docs_info) return ", ".join(docs_info)
def save_docs_to_vector_db( def save_docs_to_vector_db(
@ -664,7 +673,9 @@ def save_docs_to_vector_db(
split: bool = True, split: bool = True,
add: bool = False, add: bool = False,
) -> bool: ) -> bool:
log.info(f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}") log.info(
f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
)
# Check if entries with the same hash (metadata.hash) already exist # Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata: if metadata and "hash" in metadata:
@ -1155,7 +1166,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
elif engine == "bing": elif engine == "bing":
return search_bing( return search_bing(
BING_SEARCH_V7_SUBSCRIPTION_KEY, BING_SEARCH_V7_SUBSCRIPTION_KEY,
BING_SEARCH_V7_ENDPOINT, BING_SEARCH_V7_ENDPOINT,
str(DEFAULT_LOCALE), str(DEFAULT_LOCALE),
query, query,