This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 18:46:29 -08:00
parent 3bda1a8b88
commit ccdf51588e
2 changed files with 70 additions and 68 deletions

View File

@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.routers import ( from open_webui.routers import (
audio, audio,
images, images,
@ -63,35 +70,19 @@ from open_webui.routers import (
users, users,
utils, utils,
) )
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.routers.retrieval import ( from open_webui.routers.retrieval import (
get_embedding_function, get_embedding_function,
update_embedding_model, get_ef,
update_reranking_model, get_rf,
) )
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.internal.db import Session from open_webui.internal.db import Session
from open_webui.routers.webui import (
app as webui_app,
generate_function_chat_completion,
get_all_models as get_open_webui_models,
)
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users from open_webui.models.users import UserModel, Users
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.constants import TASKS from open_webui.constants import TASKS
@ -279,7 +270,7 @@ from open_webui.env import (
OFFLINE_MODE, OFFLINE_MODE,
) )
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.misc import ( from open_webui.utils.misc import (
add_or_update_system_message, add_or_update_system_message,
get_last_user_message, get_last_user_message,
@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.EMBEDDING_FUNCTION = None app.state.EMBEDDING_FUNCTION = None
app.state.sentence_transformer_ef = None app.state.ef = None
app.state.sentence_transformer_rf = None app.state.rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.YOUTUBE_LOADER_TRANSLATION = None
@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
# Load the local embedding (ef) and reranking (rf) models FIRST so that the
# embedding function built below captures a real model instance. (Previously
# EMBEDDING_FUNCTION was constructed while app.state.ef was still None.)
try:
    app.state.ef = get_ef(
        app.state.config.RAG_EMBEDDING_ENGINE,
        app.state.config.RAG_EMBEDDING_MODEL,
        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    )

    app.state.rf = get_rf(
        app.state.config.RAG_RERANKING_MODEL,
        RAG_RERANKING_MODEL_AUTO_UPDATE,
    )
except Exception as e:
    # Best effort at startup: log and continue so the app still boots when a
    # local model cannot be downloaded/loaded.
    log.error(f"Error updating models: {e}")

app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.ef,
    (
        app.state.config.RAG_OPENAI_API_BASE_URL
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_BASE_URL
    ),
    (
        app.state.config.RAG_OPENAI_API_KEY
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_API_KEY
    ),
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
######################################## ########################################
@ -990,11 +986,11 @@ async def chat_completion_files_handler(
sources = get_sources_from_files( sources = get_sources_from_files(
files=files, files=files,
queries=queries, queries=queries,
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=retrieval_app.state.config.TOP_K, k=app.state.config.TOP_K,
reranking_function=retrieval_app.state.sentence_transformer_rf, reranking_function=app.state.rf,
r=retrieval_app.state.config.RELEVANCE_THRESHOLD, r=app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH,
) )
log.debug(f"rag_contexts:sources: {sources}") log.debug(f"rag_contexts:sources: {sources}")

View File

@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
########################################## ##########################################
def get_ef(
    engine: str,
    embedding_model: str,
    auto_update: bool = False,
):
    """Return a local SentenceTransformer embedding model, or None.

    A model is only loaded when *engine* is "" (the local
    sentence-transformers engine) and *embedding_model* is non-empty;
    any load failure is logged and None is returned.
    """
    # Guard clause: nothing to load for remote engines or an empty model name.
    if not embedding_model or engine != "":
        return None

    from sentence_transformers import SentenceTransformer

    try:
        return SentenceTransformer(
            get_model_path(embedding_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    except Exception as e:
        log.debug(f"Error loading SentenceTransformer: {e}")
        return None
def get_rf(
    reranking_model: str,
    auto_update: bool = False,
):
    """Return a reranking model instance, or None when no model is configured.

    ColBERT-style models (currently "jinaai/jina-colbert-v2") are loaded via
    the project's ColBERT wrapper; everything else via
    sentence_transformers.CrossEncoder.

    Raises:
        Exception: wrapped via ERROR_MESSAGES.DEFAULT when a configured
            model fails to load (callers disable hybrid search on failure).
    """
    rf = None
    if reranking_model:
        if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
            try:
                from open_webui.retrieval.models.colbert import ColBERT

                rf = ColBERT(
                    get_model_path(reranking_model, auto_update),
                    # ColBERT needs to know when it runs inside the docker image.
                    env="docker" if DOCKER else None,
                )
            except Exception as e:
                log.error(f"ColBERT: {e}")
                # Chain the cause so the original traceback is preserved.
                raise Exception(ERROR_MESSAGES.DEFAULT(e)) from e
        else:
            import sentence_transformers

            try:
                rf = sentence_transformers.CrossEncoder(
                    get_model_path(reranking_model, auto_update),
                    device=DEVICE_TYPE,
                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
                )
            except Exception as e:
                # Was a bare `except:` that also caught SystemExit /
                # KeyboardInterrupt and discarded the real error.
                log.error(f"CrossEncoder: {e}")
                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) from e
    return rf
########################################## ##########################################
@ -261,12 +257,15 @@ async def update_embedding_config(
form_data.embedding_batch_size form_data.embedding_batch_size
) )
update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL) request.app.state.ef = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
)
request.app.state.EMBEDDING_FUNCTION = get_embedding_function( request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.sentence_transformer_ef, request.app.state.ef,
( (
request.app.state.config.OPENAI_API_BASE_URL request.app.state.config.OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@ -316,7 +315,14 @@ async def update_reranking_config(
try: try:
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True) try:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_MODEL,
True,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
return { return {
"status": True, "status": True,
@ -739,7 +745,7 @@ def save_docs_to_vector_db(
embedding_function = get_embedding_function( embedding_function = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.sentence_transformer_ef, request.app.state.ef,
( (
request.app.state.config.OPENAI_API_BASE_URL request.app.state.config.OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
@ -1286,7 +1292,7 @@ def query_doc_handler(
query=form_data.query, query=form_data.query,
embedding_function=request.app.state.EMBEDDING_FUNCTION, embedding_function=request.app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.sentence_transformer_rf, reranking_function=request.app.state.rf,
r=( r=(
form_data.r form_data.r
if form_data.r if form_data.r
@ -1328,7 +1334,7 @@ def query_collection_handler(
queries=[form_data.query], queries=[form_data.query],
embedding_function=request.app.state.EMBEDDING_FUNCTION, embedding_function=request.app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.sentence_transformer_rf, reranking_function=request.app.state.rf,
r=( r=(
form_data.r form_data.r
if form_data.r if form_data.r