mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	feat: external reranker
Co-Authored-By: Brendan Campbell <20541191+bcambs09@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									34ec10a78c
								
							
						
					
					
						commit
						d5fd3b3600
					
				@ -1965,6 +1965,12 @@ RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get(
 | 
			
		||||
    "RAG_EMBEDDING_PREFIX_FIELD_NAME", None
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
RAG_RERANKING_ENGINE = PersistentConfig(
 | 
			
		||||
    "RAG_RERANKING_ENGINE",
 | 
			
		||||
    "rag.reranking_engine",
 | 
			
		||||
    os.environ.get("RAG_RERANKING_ENGINE", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
RAG_RERANKING_MODEL = PersistentConfig(
 | 
			
		||||
    "RAG_RERANKING_MODEL",
 | 
			
		||||
    "rag.reranking_model",
 | 
			
		||||
@ -1973,6 +1979,7 @@ RAG_RERANKING_MODEL = PersistentConfig(
 | 
			
		||||
if RAG_RERANKING_MODEL.value != "":
 | 
			
		||||
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
RAG_RERANKING_MODEL_AUTO_UPDATE = (
 | 
			
		||||
    not OFFLINE_MODE
 | 
			
		||||
    and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
 | 
			
		||||
@ -1982,6 +1989,18 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
 | 
			
		||||
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
RAG_EXTERNAL_RERANKER_URL = PersistentConfig(
 | 
			
		||||
    "RAG_EXTERNAL_RERANKER_URL",
 | 
			
		||||
    "rag.external_reranker_url",
 | 
			
		||||
    os.environ.get("RAG_EXTERNAL_RERANKER_URL", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig(
 | 
			
		||||
    "RAG_EXTERNAL_RERANKER_API_KEY",
 | 
			
		||||
    "rag.external_reranker_api_key",
 | 
			
		||||
    os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
RAG_TEXT_SPLITTER = PersistentConfig(
 | 
			
		||||
    "RAG_TEXT_SPLITTER",
 | 
			
		||||
 | 
			
		||||
@ -188,7 +188,10 @@ from open_webui.config import (
 | 
			
		||||
    RAG_EMBEDDING_MODEL,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 | 
			
		||||
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    RAG_RERANKING_ENGINE,
 | 
			
		||||
    RAG_RERANKING_MODEL,
 | 
			
		||||
    RAG_EXTERNAL_RERANKER_URL,
 | 
			
		||||
    RAG_EXTERNAL_RERANKER_API_KEY,
 | 
			
		||||
    RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
			
		||||
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
    RAG_EMBEDDING_ENGINE,
 | 
			
		||||
@ -655,7 +658,12 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 | 
			
		||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 | 
			
		||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 | 
			
		||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
 | 
			
		||||
 | 
			
		||||
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
 | 
			
		||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 | 
			
		||||
app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
 | 
			
		||||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
 | 
			
		||||
 | 
			
		||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 | 
			
		||||
 | 
			
		||||
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 | 
			
		||||
@ -736,7 +744,10 @@ try:
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    app.state.rf = get_rf(
 | 
			
		||||
        app.state.config.RAG_RERANKING_ENGINE,
 | 
			
		||||
        app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
        app.state.config.RAG_EXTERNAL_RERANKER_URL,
 | 
			
		||||
        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
 | 
			
		||||
        RAG_RERANKING_MODEL_AUTO_UPDATE,
 | 
			
		||||
    )
 | 
			
		||||
except Exception as e:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										58
									
								
								backend/open_webui/retrieval/models/external.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								backend/open_webui/retrieval/models/external.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,58 @@
 | 
			
		||||
import logging
 | 
			
		||||
import requests
 | 
			
		||||
from typing import Optional, List, Tuple
 | 
			
		||||
 | 
			
		||||
from open_webui.env import SRC_LOG_LEVELS
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ExternalReranker:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        api_key: str,
 | 
			
		||||
        url: str = "http://localhost:8080/v1/rerank",
 | 
			
		||||
        model: str = "reranker",
 | 
			
		||||
    ):
 | 
			
		||||
        self.api_key = api_key
 | 
			
		||||
        self.url = url
 | 
			
		||||
        self.model = model
 | 
			
		||||
 | 
			
		||||
    def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
 | 
			
		||||
        query = sentences[0][0]
 | 
			
		||||
        docs = [i[1] for i in sentences]
 | 
			
		||||
 | 
			
		||||
        payload = {
 | 
			
		||||
            "model": self.model,
 | 
			
		||||
            "query": query,
 | 
			
		||||
            "documents": docs,
 | 
			
		||||
            "top_n": len(docs),
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            log.info(f"ExternalReranker:predict:model {self.model}")
 | 
			
		||||
            log.info(f"ExternalReranker:predict:query {query}")
 | 
			
		||||
 | 
			
		||||
            r = requests.post(
 | 
			
		||||
                f"{self.url}",
 | 
			
		||||
                headers={
 | 
			
		||||
                    "Content-Type": "application/json",
 | 
			
		||||
                    "Authorization": f"Bearer {self.api_key}",
 | 
			
		||||
                },
 | 
			
		||||
                json=payload,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            r.raise_for_status()
 | 
			
		||||
            data = r.json()
 | 
			
		||||
 | 
			
		||||
            if "results" in data:
 | 
			
		||||
                sorted_results = sorted(data["results"], key=lambda x: x["index"])
 | 
			
		||||
                return [result["relevance_score"] for result in sorted_results]
 | 
			
		||||
            else:
 | 
			
		||||
                log.error("No results found in external reranking response")
 | 
			
		||||
                return None
 | 
			
		||||
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.exception(f"Error in external reranking: {e}")
 | 
			
		||||
            return None
 | 
			
		||||
@ -137,7 +137,10 @@ def get_ef(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_rf(
 | 
			
		||||
    engine: str = "",
 | 
			
		||||
    reranking_model: Optional[str] = None,
 | 
			
		||||
    external_reranker_url: str = "",
 | 
			
		||||
    external_reranker_api_key: str = "",
 | 
			
		||||
    auto_update: bool = False,
 | 
			
		||||
):
 | 
			
		||||
    rf = None
 | 
			
		||||
@ -155,19 +158,33 @@ def get_rf(
 | 
			
		||||
                log.error(f"ColBERT: {e}")
 | 
			
		||||
                raise Exception(ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
        else:
 | 
			
		||||
            import sentence_transformers
 | 
			
		||||
            if engine == "external":
 | 
			
		||||
                try:
 | 
			
		||||
                    from open_webui.retrieval.models.external import ExternalReranker
 | 
			
		||||
 | 
			
		||||
                    rf = ExternalReranker(
 | 
			
		||||
                        url=external_reranker_url,
 | 
			
		||||
                        api_key=external_reranker_api_key,
 | 
			
		||||
                        model=reranking_model,
 | 
			
		||||
                    )
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    log.error(f"ExternalReranking: {e}")
 | 
			
		||||
                    raise Exception(ERROR_MESSAGES.DEFAULT(e))
 | 
			
		||||
            else:
 | 
			
		||||
                import sentence_transformers
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    rf = sentence_transformers.CrossEncoder(
 | 
			
		||||
                        get_model_path(reranking_model, auto_update),
 | 
			
		||||
                        device=DEVICE_TYPE,
 | 
			
		||||
                        trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
                        backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
 | 
			
		||||
                        model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
 | 
			
		||||
                    )
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    log.error(f"CrossEncoder: {e}")
 | 
			
		||||
                    raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                rf = sentence_transformers.CrossEncoder(
 | 
			
		||||
                    get_model_path(reranking_model, auto_update),
 | 
			
		||||
                    device=DEVICE_TYPE,
 | 
			
		||||
                    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
 | 
			
		||||
                    backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
 | 
			
		||||
                    model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
 | 
			
		||||
                )
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                log.error(f"CrossEncoder: {e}")
 | 
			
		||||
                raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
 | 
			
		||||
    return rf
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.get("/reranking")
 | 
			
		||||
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
        "status": True,
 | 
			
		||||
        "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OpenAIConfigForm(BaseModel):
 | 
			
		||||
    url: str
 | 
			
		||||
    key: str
 | 
			
		||||
@ -327,41 +336,6 @@ async def update_embedding_config(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RerankingModelUpdateForm(BaseModel):
 | 
			
		||||
    reranking_model: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/reranking/update")
 | 
			
		||||
async def update_reranking_config(
 | 
			
		||||
    request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
 | 
			
		||||
):
 | 
			
		||||
    log.info(
 | 
			
		||||
        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
        request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            request.app.state.rf = get_rf(
 | 
			
		||||
                request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
                True,
 | 
			
		||||
            )
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.error(f"Error loading reranking model: {e}")
 | 
			
		||||
            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            "status": True,
 | 
			
		||||
            "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
        }
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(f"Problem updating reranking model: {e}")
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
 | 
			
		||||
            detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.get("/config")
 | 
			
		||||
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
    return {
 | 
			
		||||
@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
 | 
			
		||||
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
 | 
			
		||||
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
 | 
			
		||||
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
 | 
			
		||||
        # Reranking settings
 | 
			
		||||
        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
 | 
			
		||||
        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
 | 
			
		||||
        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
 | 
			
		||||
        # Chunking settings
 | 
			
		||||
        "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
 | 
			
		||||
        "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
 | 
			
		||||
@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
 | 
			
		||||
    DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
 | 
			
		||||
    MISTRAL_OCR_API_KEY: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    # Reranking settings
 | 
			
		||||
    RAG_RERANKING_MODEL: Optional[str] = None
 | 
			
		||||
    RAG_RERANKING_ENGINE: Optional[str] = None
 | 
			
		||||
    RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
 | 
			
		||||
    RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    # Chunking settings
 | 
			
		||||
    TEXT_SPLITTER: Optional[str] = None
 | 
			
		||||
    CHUNK_SIZE: Optional[int] = None
 | 
			
		||||
@ -632,6 +617,49 @@ async def update_rag_config(
 | 
			
		||||
        else request.app.state.config.MISTRAL_OCR_API_KEY
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Reranking settings
 | 
			
		||||
    request.app.state.config.RAG_RERANKING_ENGINE = (
 | 
			
		||||
        form_data.RAG_RERANKING_ENGINE
 | 
			
		||||
        if form_data.RAG_RERANKING_ENGINE is not None
 | 
			
		||||
        else request.app.state.config.RAG_RERANKING_ENGINE
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    request.app.state.config.RAG_EXTERNAL_RERANKING_URL = (
 | 
			
		||||
        form_data.RAG_EXTERNAL_RERANKING_URL
 | 
			
		||||
        if form_data.RAG_EXTERNAL_RERANKING_URL is not None
 | 
			
		||||
        else request.app.state.config.RAG_EXTERNAL_RERANKING_URL
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY = (
 | 
			
		||||
        form_data.RAG_EXTERNAL_RERANKING_API_KEY
 | 
			
		||||
        if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
 | 
			
		||||
        else request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    log.info(
 | 
			
		||||
        f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
        request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            request.app.state.rf = get_rf(
 | 
			
		||||
                request.app.state.config.RAG_RERANKING_ENGINE,
 | 
			
		||||
                request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
                request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
 | 
			
		||||
                request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
 | 
			
		||||
                True,
 | 
			
		||||
            )
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            log.error(f"Error loading reranking model: {e}")
 | 
			
		||||
            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(f"Problem updating reranking model: {e}")
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
 | 
			
		||||
            detail=ERROR_MESSAGES.DEFAULT(e),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Chunking settings
 | 
			
		||||
    request.app.state.config.TEXT_SPLITTER = (
 | 
			
		||||
        form_data.TEXT_SPLITTER
 | 
			
		||||
@ -788,6 +816,11 @@ async def update_rag_config(
 | 
			
		||||
        "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
 | 
			
		||||
        "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
 | 
			
		||||
        "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
 | 
			
		||||
        # Reranking settings
 | 
			
		||||
        "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
 | 
			
		||||
        "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
 | 
			
		||||
        "RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
 | 
			
		||||
        "RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
 | 
			
		||||
        # Chunking settings
 | 
			
		||||
        "TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
 | 
			
		||||
        "CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user