mirror of
https://github.com/open-webui/open-webui
synced 2024-12-29 15:25:29 +00:00
wip
This commit is contained in:
parent
3bda1a8b88
commit
ccdf51588e
@ -39,6 +39,13 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||||||
from starlette.responses import Response, StreamingResponse
|
from starlette.responses import Response, StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.socket.main import (
|
||||||
|
app as socket_app,
|
||||||
|
periodic_usage_pool_cleanup,
|
||||||
|
get_event_call,
|
||||||
|
get_event_emitter,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.routers import (
|
from open_webui.routers import (
|
||||||
audio,
|
audio,
|
||||||
images,
|
images,
|
||||||
@ -63,35 +70,19 @@ from open_webui.routers import (
|
|||||||
users,
|
users,
|
||||||
utils,
|
utils,
|
||||||
)
|
)
|
||||||
from open_webui.retrieval.utils import get_sources_from_files
|
|
||||||
from open_webui.routers.retrieval import (
|
from open_webui.routers.retrieval import (
|
||||||
get_embedding_function,
|
get_embedding_function,
|
||||||
update_embedding_model,
|
get_ef,
|
||||||
update_reranking_model,
|
get_rf,
|
||||||
)
|
)
|
||||||
|
from open_webui.retrieval.utils import get_sources_from_files
|
||||||
|
|
||||||
|
|
||||||
from open_webui.socket.main import (
|
|
||||||
app as socket_app,
|
|
||||||
periodic_usage_pool_cleanup,
|
|
||||||
get_event_call,
|
|
||||||
get_event_emitter,
|
|
||||||
)
|
|
||||||
|
|
||||||
from open_webui.internal.db import Session
|
from open_webui.internal.db import Session
|
||||||
|
|
||||||
|
|
||||||
from open_webui.routers.webui import (
|
|
||||||
app as webui_app,
|
|
||||||
generate_function_chat_completion,
|
|
||||||
get_all_models as get_open_webui_models,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
from open_webui.models.users import UserModel, Users
|
from open_webui.models.users import UserModel, Users
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.constants import TASKS
|
from open_webui.constants import TASKS
|
||||||
@ -279,7 +270,7 @@ from open_webui.env import (
|
|||||||
OFFLINE_MODE,
|
OFFLINE_MODE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from open_webui.utils.plugin import load_function_module_by_id
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
get_last_user_message,
|
get_last_user_message,
|
||||||
@ -528,8 +519,8 @@ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
|||||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||||
|
|
||||||
app.state.EMBEDDING_FUNCTION = None
|
app.state.EMBEDDING_FUNCTION = None
|
||||||
app.state.sentence_transformer_ef = None
|
app.state.ef = None
|
||||||
app.state.sentence_transformer_rf = None
|
app.state.rf = None
|
||||||
|
|
||||||
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||||
|
|
||||||
@ -537,29 +528,34 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
|
|||||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
app.state.config.RAG_EMBEDDING_ENGINE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
app.state.sentence_transformer_ef,
|
app.state.ef,
|
||||||
(
|
(
|
||||||
app.state.config.OPENAI_API_BASE_URL
|
app.state.config.RAG_OPENAI_API_BASE_URL
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else app.state.config.OLLAMA_BASE_URL
|
else app.state.config.RAG_OLLAMA_BASE_URL
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
app.state.config.OPENAI_API_KEY
|
app.state.config.RAG_OPENAI_API_KEY
|
||||||
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
else app.state.config.OLLAMA_API_KEY
|
else app.state.config.RAG_OLLAMA_API_KEY
|
||||||
),
|
),
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
update_embedding_model(
|
try:
|
||||||
app.state.config.RAG_EMBEDDING_MODEL,
|
app.state.ef = get_ef(
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
)
|
app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
|
)
|
||||||
|
|
||||||
update_reranking_model(
|
app.state.rf = get_rf(
|
||||||
app.state.config.RAG_RERANKING_MODEL,
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error updating models: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
@ -990,11 +986,11 @@ async def chat_completion_files_handler(
|
|||||||
sources = get_sources_from_files(
|
sources = get_sources_from_files(
|
||||||
files=files,
|
files=files,
|
||||||
queries=queries,
|
queries=queries,
|
||||||
embedding_function=retrieval_app.state.EMBEDDING_FUNCTION,
|
embedding_function=app.state.EMBEDDING_FUNCTION,
|
||||||
k=retrieval_app.state.config.TOP_K,
|
k=app.state.config.TOP_K,
|
||||||
reranking_function=retrieval_app.state.sentence_transformer_rf,
|
reranking_function=app.state.rf,
|
||||||
r=retrieval_app.state.config.RELEVANCE_THRESHOLD,
|
r=app.state.config.RELEVANCE_THRESHOLD,
|
||||||
hybrid_search=retrieval_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug(f"rag_contexts:sources: {sources}")
|
log.debug(f"rag_contexts:sources: {sources}")
|
||||||
|
@ -97,62 +97,58 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|||||||
##########################################
|
##########################################
|
||||||
|
|
||||||
|
|
||||||
def update_embedding_model(
|
def get_ef(
|
||||||
request: Request,
|
engine: str,
|
||||||
embedding_model: str,
|
embedding_model: str,
|
||||||
auto_update: bool = False,
|
auto_update: bool = False,
|
||||||
):
|
):
|
||||||
if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
|
ef = None
|
||||||
|
if embedding_model and engine == "":
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request.app.state.sentence_transformer_ef = SentenceTransformer(
|
ef = SentenceTransformer(
|
||||||
get_model_path(embedding_model, auto_update),
|
get_model_path(embedding_model, auto_update),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Error loading SentenceTransformer: {e}")
|
log.debug(f"Error loading SentenceTransformer: {e}")
|
||||||
request.app.state.sentence_transformer_ef = None
|
|
||||||
else:
|
return ef
|
||||||
request.app.state.sentence_transformer_ef = None
|
|
||||||
|
|
||||||
|
|
||||||
def update_reranking_model(
|
def get_rf(
|
||||||
request: Request,
|
|
||||||
reranking_model: str,
|
reranking_model: str,
|
||||||
auto_update: bool = False,
|
auto_update: bool = False,
|
||||||
):
|
):
|
||||||
|
rf = None
|
||||||
if reranking_model:
|
if reranking_model:
|
||||||
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
|
||||||
try:
|
try:
|
||||||
from open_webui.retrieval.models.colbert import ColBERT
|
from open_webui.retrieval.models.colbert import ColBERT
|
||||||
|
|
||||||
request.app.state.sentence_transformer_rf = ColBERT(
|
rf = ColBERT(
|
||||||
get_model_path(reranking_model, auto_update),
|
get_model_path(reranking_model, auto_update),
|
||||||
env="docker" if DOCKER else None,
|
env="docker" if DOCKER else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"ColBERT: {e}")
|
log.error(f"ColBERT: {e}")
|
||||||
request.app.state.sentence_transformer_rf = None
|
raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
||||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
||||||
else:
|
else:
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request.app.state.sentence_transformer_rf = (
|
rf = sentence_transformers.CrossEncoder(
|
||||||
sentence_transformers.CrossEncoder(
|
get_model_path(reranking_model, auto_update),
|
||||||
get_model_path(reranking_model, auto_update),
|
device=DEVICE_TYPE,
|
||||||
device=DEVICE_TYPE,
|
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
log.error("CrossEncoder error")
|
log.error("CrossEncoder error")
|
||||||
request.app.state.sentence_transformer_rf = None
|
raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
||||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
return rf
|
||||||
else:
|
|
||||||
request.app.state.sentence_transformer_rf = None
|
|
||||||
|
|
||||||
|
|
||||||
##########################################
|
##########################################
|
||||||
@ -261,12 +257,15 @@ async def update_embedding_config(
|
|||||||
form_data.embedding_batch_size
|
form_data.embedding_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)
|
request.app.state.ef = get_ef(
|
||||||
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
|
|
||||||
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
request.app.state.sentence_transformer_ef,
|
request.app.state.ef,
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_BASE_URL
|
request.app.state.config.OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
@ -316,7 +315,14 @@ async def update_reranking_config(
|
|||||||
try:
|
try:
|
||||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||||
|
|
||||||
update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)
|
try:
|
||||||
|
request.app.state.rf = get_rf(
|
||||||
|
request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error loading reranking model: {e}")
|
||||||
|
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
@ -739,7 +745,7 @@ def save_docs_to_vector_db(
|
|||||||
embedding_function = get_embedding_function(
|
embedding_function = get_embedding_function(
|
||||||
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
request.app.state.config.RAG_EMBEDDING_ENGINE,
|
||||||
request.app.state.config.RAG_EMBEDDING_MODEL,
|
request.app.state.config.RAG_EMBEDDING_MODEL,
|
||||||
request.app.state.sentence_transformer_ef,
|
request.app.state.ef,
|
||||||
(
|
(
|
||||||
request.app.state.config.OPENAI_API_BASE_URL
|
request.app.state.config.OPENAI_API_BASE_URL
|
||||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
|
||||||
@ -1286,7 +1292,7 @@ def query_doc_handler(
|
|||||||
query=form_data.query,
|
query=form_data.query,
|
||||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.sentence_transformer_rf,
|
reranking_function=request.app.state.rf,
|
||||||
r=(
|
r=(
|
||||||
form_data.r
|
form_data.r
|
||||||
if form_data.r
|
if form_data.r
|
||||||
@ -1328,7 +1334,7 @@ def query_collection_handler(
|
|||||||
queries=[form_data.query],
|
queries=[form_data.query],
|
||||||
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
embedding_function=request.app.state.EMBEDDING_FUNCTION,
|
||||||
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
|
||||||
reranking_function=request.app.state.sentence_transformer_rf,
|
reranking_function=request.app.state.rf,
|
||||||
r=(
|
r=(
|
||||||
form_data.r
|
form_data.r
|
||||||
if form_data.r
|
if form_data.r
|
||||||
|
Loading…
Reference in New Issue
Block a user