From d5fd3b36006b4073af2ce0c04171d0a3034b57d7 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sat, 10 May 2025 18:25:20 +0400
Subject: [PATCH] feat: external reranker

Co-Authored-By: Brendan Campbell <20541191+bcambs09@users.noreply.github.com>
---
Review notes (this section sits after the three-dash line and is not part of
the commit message):

* Fixed in this revision: the routers/retrieval.py hunks accessed
  request.app.state.config.RAG_EXTERNAL_RERANKING_URL and
  RAG_EXTERNAL_RERANKING_API_KEY, while config.py defines and main.py
  registers RAG_EXTERNAL_RERANKER_URL and RAG_EXTERNAL_RERANKER_API_KEY.
  The attribute accesses now use the RERANKER names. The ConfigForm field
  names and the response keys are intentionally left as
  RAG_EXTERNAL_RERANKING_* so the HTTP payload shape does not change.
  The change is line-count neutral; hunk headers and diffstat are unchanged.
* Not changed, worth a follow-up (each would alter hunk line counts):
  - ExternalReranker.predict calls requests.post without a timeout.
  - ExternalReranker.predict reads sentences[0][0] before entering the try
    block, so an empty list raises IndexError instead of returning None.
  - update_rag_config assigns form_data.RAG_RERANKING_MODEL directly, without
    the "is not None" fallback used for the neighbouring settings, so an
    omitted field stores None.
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy; context-line whitespace should be confirmed against the target tree.

 backend/open_webui/config.py                  |  19 +++
 backend/open_webui/main.py                    |  11 ++
 .../open_webui/retrieval/models/external.py   |  58 +++++++
 backend/open_webui/routers/retrieval.py       | 143 +++++++++++-------
 4 files changed, 176 insertions(+), 55 deletions(-)
 create mode 100644 backend/open_webui/retrieval/models/external.py

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 4ac476285..38bd709f1 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1965,6 +1965,12 @@ RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get(
     "RAG_EMBEDDING_PREFIX_FIELD_NAME", None
 )
 
+RAG_RERANKING_ENGINE = PersistentConfig(
+    "RAG_RERANKING_ENGINE",
+    "rag.reranking_engine",
+    os.environ.get("RAG_RERANKING_ENGINE", ""),
+)
+
 RAG_RERANKING_MODEL = PersistentConfig(
     "RAG_RERANKING_MODEL",
     "rag.reranking_model",
@@ -1973,6 +1979,7 @@ RAG_RERANKING_MODEL = PersistentConfig(
 if RAG_RERANKING_MODEL.value != "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
 
+
 RAG_RERANKING_MODEL_AUTO_UPDATE = (
     not OFFLINE_MODE
     and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
@@ -1982,6 +1989,18 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
 )
 
+RAG_EXTERNAL_RERANKER_URL = PersistentConfig(
+    "RAG_EXTERNAL_RERANKER_URL",
+    "rag.external_reranker_url",
+    os.environ.get("RAG_EXTERNAL_RERANKER_URL", ""),
+)
+
+RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig(
+    "RAG_EXTERNAL_RERANKER_API_KEY",
+    "rag.external_reranker_api_key",
+    os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
+)
+
 
 RAG_TEXT_SPLITTER = PersistentConfig(
     "RAG_TEXT_SPLITTER",
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index becacf4dd..e5fdace6d 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -188,7 +188,10 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_ENGINE,
     RAG_RERANKING_MODEL,
+    RAG_EXTERNAL_RERANKER_URL,
+    RAG_EXTERNAL_RERANKER_API_KEY,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_EMBEDDING_ENGINE,
@@ -655,7 +658,12 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
+
+app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
+app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
+app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
+
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
 app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
@@ -736,7 +744,10 @@ try:
     )
 
     app.state.rf = get_rf(
+        app.state.config.RAG_RERANKING_ENGINE,
         app.state.config.RAG_RERANKING_MODEL,
+        app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         RAG_RERANKING_MODEL_AUTO_UPDATE,
     )
 except Exception as e:
diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py
new file mode 100644
index 000000000..187d66e38
--- /dev/null
+++ b/backend/open_webui/retrieval/models/external.py
@@ -0,0 +1,58 @@
+import logging
+import requests
+from typing import Optional, List, Tuple
+
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class ExternalReranker:
+    def __init__(
+        self,
+        api_key: str,
+        url: str = "http://localhost:8080/v1/rerank",
+        model: str = "reranker",
+    ):
+        self.api_key = api_key
+        self.url = url
+        self.model = model
+
+    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
+        query = sentences[0][0]
+        docs = [i[1] for i in sentences]
+
+        payload = {
+            "model": self.model,
+            "query": query,
+            "documents": docs,
+            "top_n": len(docs),
+        }
+
+        try:
+            log.info(f"ExternalReranker:predict:model {self.model}")
+            log.info(f"ExternalReranker:predict:query {query}")
+
+            r = requests.post(
+                f"{self.url}",
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {self.api_key}",
+                },
+                json=payload,
+            )
+
+            r.raise_for_status()
+            data = r.json()
+
+            if "results" in data:
+                sorted_results = sorted(data["results"], key=lambda x: x["index"])
+                return [result["relevance_score"] for result in sorted_results]
+            else:
+                log.error("No results found in external reranking response")
+                return None
+
+        except Exception as e:
+            log.exception(f"Error in external reranking: {e}")
+            return None
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index b86d8968d..bdeeb3136 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -137,7 +137,10 @@ def get_ef(
 
 
 def get_rf(
+    engine: str = "",
     reranking_model: Optional[str] = None,
+    external_reranker_url: str = "",
+    external_reranker_api_key: str = "",
     auto_update: bool = False,
 ):
     rf = None
@@ -155,19 +158,33 @@ def get_rf(
                 log.error(f"ColBERT: {e}")
                 raise Exception(ERROR_MESSAGES.DEFAULT(e))
         else:
-            import sentence_transformers
+            if engine == "external":
+                try:
+                    from open_webui.retrieval.models.external import ExternalReranker
+
+                    rf = ExternalReranker(
+                        url=external_reranker_url,
+                        api_key=external_reranker_api_key,
+                        model=reranking_model,
+                    )
+                except Exception as e:
+                    log.error(f"ExternalReranking: {e}")
+                    raise Exception(ERROR_MESSAGES.DEFAULT(e))
+            else:
+                import sentence_transformers
+
+                try:
+                    rf = sentence_transformers.CrossEncoder(
+                        get_model_path(reranking_model, auto_update),
+                        device=DEVICE_TYPE,
+                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+                        backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
+                        model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
+                    )
+                except Exception as e:
+                    log.error(f"CrossEncoder: {e}")
+                    raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
 
-            try:
-                rf = sentence_transformers.CrossEncoder(
-                    get_model_path(reranking_model, auto_update),
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-                    backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
-                    model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
-                )
-            except Exception as e:
-                log.error(f"CrossEncoder: {e}")
-                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
     return rf
 
 
@@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
     }
 
 
-@router.get("/reranking")
-async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
-    return {
-        "status": True,
-        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
-    }
-
-
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
@@ -327,41 +336,6 @@ async def update_embedding_config(
         )
 
 
-class RerankingModelUpdateForm(BaseModel):
-    reranking_model: str
-
-
-@router.post("/reranking/update")
-async def update_reranking_config(
-    request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
-):
-    log.info(
-        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
-    )
-    try:
-        request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
-
-        try:
-            request.app.state.rf = get_rf(
-                request.app.state.config.RAG_RERANKING_MODEL,
-                True,
-            )
-        except Exception as e:
-            log.error(f"Error loading reranking model: {e}")
-            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
-
-        return {
-            "status": True,
-            "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
-        }
-    except Exception as e:
-        log.exception(f"Problem updating reranking model: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
-
-
 @router.get("/config")
 async def get_rag_config(request: Request, user=Depends(get_admin_user)):
     return {
@@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # Reranking settings
+        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
+        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
+        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         # Chunking settings
         "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
         "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
@@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
     DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
     MISTRAL_OCR_API_KEY: Optional[str] = None
 
+    # Reranking settings
+    RAG_RERANKING_MODEL: Optional[str] = None
+    RAG_RERANKING_ENGINE: Optional[str] = None
+    RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
+    RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
+
     # Chunking settings
     TEXT_SPLITTER: Optional[str] = None
     CHUNK_SIZE: Optional[int] = None
@@ -632,6 +617,49 @@ async def update_rag_config(
         else request.app.state.config.MISTRAL_OCR_API_KEY
     )
 
+    # Reranking settings
+    request.app.state.config.RAG_RERANKING_ENGINE = (
+        form_data.RAG_RERANKING_ENGINE
+        if form_data.RAG_RERANKING_ENGINE is not None
+        else request.app.state.config.RAG_RERANKING_ENGINE
+    )
+
+    request.app.state.config.RAG_EXTERNAL_RERANKER_URL = (
+        form_data.RAG_EXTERNAL_RERANKING_URL
+        if form_data.RAG_EXTERNAL_RERANKING_URL is not None
+        else request.app.state.config.RAG_EXTERNAL_RERANKER_URL
+    )
+
+    request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = (
+        form_data.RAG_EXTERNAL_RERANKING_API_KEY
+        if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
+        else request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY
+    )
+
+    log.info(
+        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
+    )
+    try:
+        request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
+
+        try:
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                True,
+            )
+        except Exception as e:
+            log.error(f"Error loading reranking model: {e}")
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
+    except Exception as e:
+        log.exception(f"Problem updating reranking model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
     # Chunking settings
     request.app.state.config.TEXT_SPLITTER = (
         form_data.TEXT_SPLITTER
@@ -788,6 +816,11 @@ async def update_rag_config(
         "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
         "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
         "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
+        # Reranking settings
+        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
+        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
+        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
         # Chunking settings
         "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
         "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,