feat: external reranker

Co-Authored-By: Brendan Campbell <20541191+bcambs09@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek
2025-05-10 18:25:20 +04:00
parent 34ec10a78c
commit d5fd3b3600
4 changed files with 176 additions and 55 deletions

View File

@@ -137,7 +137,10 @@ def get_ef(
def get_rf(
engine: str = "",
reranking_model: Optional[str] = None,
external_reranker_url: str = "",
external_reranker_api_key: str = "",
auto_update: bool = False,
):
rf = None
@@ -155,19 +158,33 @@ def get_rf(
log.error(f"ColBERT: {e}")
raise Exception(ERROR_MESSAGES.DEFAULT(e))
else:
import sentence_transformers
if engine == "external":
try:
from open_webui.retrieval.models.external import ExternalReranker
rf = ExternalReranker(
url=external_reranker_url,
api_key=external_reranker_api_key,
model=reranking_model,
)
except Exception as e:
log.error(f"ExternalReranking: {e}")
raise Exception(ERROR_MESSAGES.DEFAULT(e))
else:
import sentence_transformers
try:
rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
except Exception as e:
log.error(f"CrossEncoder: {e}")
raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
try:
rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, auto_update),
device=DEVICE_TYPE,
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
)
except Exception as e:
log.error(f"CrossEncoder: {e}")
raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
return rf
@@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
}
@router.get("/reranking")
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
return {
"status": True,
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
}
class OpenAIConfigForm(BaseModel):
url: str
key: str
@@ -327,41 +336,6 @@ async def update_embedding_config(
)
class RerankingModelUpdateForm(BaseModel):
reranking_model: str
@router.post("/reranking/update")
async def update_reranking_config(
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
try:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_MODEL,
True,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
return {
"status": True,
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
return {
@@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
# Reranking settings
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
"RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
"RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
# Chunking settings
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
@@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
MISTRAL_OCR_API_KEY: Optional[str] = None
# Reranking settings
RAG_RERANKING_MODEL: Optional[str] = None
RAG_RERANKING_ENGINE: Optional[str] = None
RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
# Chunking settings
TEXT_SPLITTER: Optional[str] = None
CHUNK_SIZE: Optional[int] = None
@@ -632,6 +617,49 @@ async def update_rag_config(
else request.app.state.config.MISTRAL_OCR_API_KEY
)
# Reranking settings
request.app.state.config.RAG_RERANKING_ENGINE = (
form_data.RAG_RERANKING_ENGINE
if form_data.RAG_RERANKING_ENGINE is not None
else request.app.state.config.RAG_RERANKING_ENGINE
)
request.app.state.config.RAG_EXTERNAL_RERANKING_URL = (
form_data.RAG_EXTERNAL_RERANKING_URL
if form_data.RAG_EXTERNAL_RERANKING_URL is not None
else request.app.state.config.RAG_EXTERNAL_RERANKING_URL
)
request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY = (
form_data.RAG_EXTERNAL_RERANKING_API_KEY
if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
else request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY
)
log.info(
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
)
try:
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
try:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
True,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
# Chunking settings
request.app.state.config.TEXT_SPLITTER = (
form_data.TEXT_SPLITTER
@@ -788,6 +816,11 @@ async def update_rag_config(
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
# Reranking settings
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
"RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
"RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
# Chunking settings
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,