This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 18:08:55 -08:00
parent 867c4bc0d0
commit b3987ad41e
2 changed files with 36 additions and 32 deletions

View File

@ -62,8 +62,12 @@ from open_webui.routers import (
users,
utils,
)
from open_webui.retrieval.utils import get_sources_from_files
from open_webui.routers.retrieval import (
get_embedding_function,
update_embedding_model,
update_reranking_model,
)
from open_webui.socket.main import (
@ -73,15 +77,16 @@ from open_webui.socket.main import (
get_event_emitter,
)
from open_webui.internal.db import Session
from backend.open_webui.routers.webui import (
from open_webui.routers.webui import (
app as webui_app,
generate_function_chat_completion,
get_all_models as get_open_webui_models,
)
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel, Users
@ -523,6 +528,34 @@ app.state.sentence_transformer_rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
(
app.state.config.OPENAI_API_BASE_URL
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_BASE_URL
),
(
app.state.config.OPENAI_API_KEY
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else app.state.config.OLLAMA_API_KEY
),
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
update_embedding_model(
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
########################################
#
# IMAGES

View File

@ -155,17 +155,6 @@ def update_reranking_model(
request.app.state.sentence_transformer_rf = None
update_embedding_model(
request.app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
request.app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
##########################################
#
# API routes
@ -176,24 +165,6 @@ update_reranking_model(
router = APIRouter()
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
request.app.state.sentence_transformer_ef,
(
request.app.state.config.OPENAI_API_BASE_URL
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_BASE_URL
),
(
request.app.state.config.OPENAI_API_KEY
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
else request.app.state.config.OLLAMA_API_KEY
),
request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
class CollectionNameForm(BaseModel):
collection_name: Optional[str] = None