mirror of
https://github.com/open-webui/open-webui
synced 2025-06-15 19:05:04 +00:00
feat: external reranker
Co-Authored-By: Brendan Campbell <20541191+bcambs09@users.noreply.github.com>
This commit is contained in:
parent
34ec10a78c
commit
d5fd3b3600
@ -1965,6 +1965,12 @@ RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get(
|
|||||||
"RAG_EMBEDDING_PREFIX_FIELD_NAME", None
|
"RAG_EMBEDDING_PREFIX_FIELD_NAME", None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RAG_RERANKING_ENGINE = PersistentConfig(
|
||||||
|
"RAG_RERANKING_ENGINE",
|
||||||
|
"rag.reranking_engine",
|
||||||
|
os.environ.get("RAG_RERANKING_ENGINE", ""),
|
||||||
|
)
|
||||||
|
|
||||||
RAG_RERANKING_MODEL = PersistentConfig(
|
RAG_RERANKING_MODEL = PersistentConfig(
|
||||||
"RAG_RERANKING_MODEL",
|
"RAG_RERANKING_MODEL",
|
||||||
"rag.reranking_model",
|
"rag.reranking_model",
|
||||||
@ -1973,6 +1979,7 @@ RAG_RERANKING_MODEL = PersistentConfig(
|
|||||||
if RAG_RERANKING_MODEL.value != "":
|
if RAG_RERANKING_MODEL.value != "":
|
||||||
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
|
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")
|
||||||
|
|
||||||
|
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||||
not OFFLINE_MODE
|
not OFFLINE_MODE
|
||||||
and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
|
||||||
@ -1982,6 +1989,18 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
|||||||
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
|
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "True").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RAG_EXTERNAL_RERANKER_URL = PersistentConfig(
|
||||||
|
"RAG_EXTERNAL_RERANKER_URL",
|
||||||
|
"rag.external_reranker_url",
|
||||||
|
os.environ.get("RAG_EXTERNAL_RERANKER_URL", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
RAG_EXTERNAL_RERANKER_API_KEY = PersistentConfig(
|
||||||
|
"RAG_EXTERNAL_RERANKER_API_KEY",
|
||||||
|
"rag.external_reranker_api_key",
|
||||||
|
os.environ.get("RAG_EXTERNAL_RERANKER_API_KEY", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
RAG_TEXT_SPLITTER = PersistentConfig(
|
RAG_TEXT_SPLITTER = PersistentConfig(
|
||||||
"RAG_TEXT_SPLITTER",
|
"RAG_TEXT_SPLITTER",
|
||||||
|
@ -188,7 +188,10 @@ from open_webui.config import (
|
|||||||
RAG_EMBEDDING_MODEL,
|
RAG_EMBEDDING_MODEL,
|
||||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
|
RAG_RERANKING_ENGINE,
|
||||||
RAG_RERANKING_MODEL,
|
RAG_RERANKING_MODEL,
|
||||||
|
RAG_EXTERNAL_RERANKER_URL,
|
||||||
|
RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
RAG_EMBEDDING_ENGINE,
|
RAG_EMBEDDING_ENGINE,
|
||||||
@ -655,7 +658,12 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
|||||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||||
|
|
||||||
|
app.state.config.RAG_RERANKING_ENGINE = RAG_RERANKING_ENGINE
|
||||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||||
|
app.state.config.RAG_EXTERNAL_RERANKER_URL = RAG_EXTERNAL_RERANKER_URL
|
||||||
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY = RAG_EXTERNAL_RERANKER_API_KEY
|
||||||
|
|
||||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||||
|
|
||||||
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||||
@ -736,7 +744,10 @@ try:
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.state.rf = get_rf(
|
app.state.rf = get_rf(
|
||||||
|
app.state.config.RAG_RERANKING_ENGINE,
|
||||||
app.state.config.RAG_RERANKING_MODEL,
|
app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
app.state.config.RAG_EXTERNAL_RERANKER_URL,
|
||||||
|
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
58
backend/open_webui/retrieval/models/external.py
Normal file
58
backend/open_webui/retrieval/models/external.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import logging
|
||||||
|
import requests
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalReranker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str,
|
||||||
|
url: str = "http://localhost:8080/v1/rerank",
|
||||||
|
model: str = "reranker",
|
||||||
|
):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.url = url
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]:
|
||||||
|
query = sentences[0][0]
|
||||||
|
docs = [i[1] for i in sentences]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"query": query,
|
||||||
|
"documents": docs,
|
||||||
|
"top_n": len(docs),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.info(f"ExternalReranker:predict:model {self.model}")
|
||||||
|
log.info(f"ExternalReranker:predict:query {query}")
|
||||||
|
|
||||||
|
r = requests.post(
|
||||||
|
f"{self.url}",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
|
||||||
|
if "results" in data:
|
||||||
|
sorted_results = sorted(data["results"], key=lambda x: x["index"])
|
||||||
|
return [result["relevance_score"] for result in sorted_results]
|
||||||
|
else:
|
||||||
|
log.error("No results found in external reranking response")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Error in external reranking: {e}")
|
||||||
|
return None
|
@ -137,7 +137,10 @@ def get_ef(
|
|||||||
|
|
||||||
|
|
||||||
def get_rf(
|
def get_rf(
|
||||||
|
engine: str = "",
|
||||||
reranking_model: Optional[str] = None,
|
reranking_model: Optional[str] = None,
|
||||||
|
external_reranker_url: str = "",
|
||||||
|
external_reranker_api_key: str = "",
|
||||||
auto_update: bool = False,
|
auto_update: bool = False,
|
||||||
):
|
):
|
||||||
rf = None
|
rf = None
|
||||||
@ -155,19 +158,33 @@ def get_rf(
|
|||||||
log.error(f"ColBERT: {e}")
|
log.error(f"ColBERT: {e}")
|
||||||
raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
||||||
else:
|
else:
|
||||||
import sentence_transformers
|
if engine == "external":
|
||||||
|
try:
|
||||||
|
from open_webui.retrieval.models.external import ExternalReranker
|
||||||
|
|
||||||
|
rf = ExternalReranker(
|
||||||
|
url=external_reranker_url,
|
||||||
|
api_key=external_reranker_api_key,
|
||||||
|
model=reranking_model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"ExternalReranking: {e}")
|
||||||
|
raise Exception(ERROR_MESSAGES.DEFAULT(e))
|
||||||
|
else:
|
||||||
|
import sentence_transformers
|
||||||
|
|
||||||
|
try:
|
||||||
|
rf = sentence_transformers.CrossEncoder(
|
||||||
|
get_model_path(reranking_model, auto_update),
|
||||||
|
device=DEVICE_TYPE,
|
||||||
|
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
|
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
|
||||||
|
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"CrossEncoder: {e}")
|
||||||
|
raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
||||||
|
|
||||||
try:
|
|
||||||
rf = sentence_transformers.CrossEncoder(
|
|
||||||
get_model_path(reranking_model, auto_update),
|
|
||||||
device=DEVICE_TYPE,
|
|
||||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
|
||||||
backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND,
|
|
||||||
model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"CrossEncoder: {e}")
|
|
||||||
raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error"))
|
|
||||||
return rf
|
return rf
|
||||||
|
|
||||||
|
|
||||||
@ -225,14 +242,6 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/reranking")
|
|
||||||
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
|
|
||||||
return {
|
|
||||||
"status": True,
|
|
||||||
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConfigForm(BaseModel):
|
class OpenAIConfigForm(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
key: str
|
key: str
|
||||||
@ -327,41 +336,6 @@ async def update_embedding_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RerankingModelUpdateForm(BaseModel):
|
|
||||||
reranking_model: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reranking/update")
|
|
||||||
async def update_reranking_config(
|
|
||||||
request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
|
|
||||||
):
|
|
||||||
log.info(
|
|
||||||
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
|
|
||||||
|
|
||||||
try:
|
|
||||||
request.app.state.rf = get_rf(
|
|
||||||
request.app.state.config.RAG_RERANKING_MODEL,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error loading reranking model: {e}")
|
|
||||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": True,
|
|
||||||
"reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Problem updating reranking model: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
@ -385,6 +359,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
|||||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||||
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||||
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||||
|
# Reranking settings
|
||||||
|
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
|
||||||
|
"RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
||||||
|
"RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
||||||
# Chunking settings
|
# Chunking settings
|
||||||
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
||||||
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
||||||
@ -521,6 +500,12 @@ class ConfigForm(BaseModel):
|
|||||||
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
|
DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None
|
||||||
MISTRAL_OCR_API_KEY: Optional[str] = None
|
MISTRAL_OCR_API_KEY: Optional[str] = None
|
||||||
|
|
||||||
|
# Reranking settings
|
||||||
|
RAG_RERANKING_MODEL: Optional[str] = None
|
||||||
|
RAG_RERANKING_ENGINE: Optional[str] = None
|
||||||
|
RAG_EXTERNAL_RERANKING_URL: Optional[str] = None
|
||||||
|
RAG_EXTERNAL_RERANKING_API_KEY: Optional[str] = None
|
||||||
|
|
||||||
# Chunking settings
|
# Chunking settings
|
||||||
TEXT_SPLITTER: Optional[str] = None
|
TEXT_SPLITTER: Optional[str] = None
|
||||||
CHUNK_SIZE: Optional[int] = None
|
CHUNK_SIZE: Optional[int] = None
|
||||||
@ -632,6 +617,49 @@ async def update_rag_config(
|
|||||||
else request.app.state.config.MISTRAL_OCR_API_KEY
|
else request.app.state.config.MISTRAL_OCR_API_KEY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reranking settings
|
||||||
|
request.app.state.config.RAG_RERANKING_ENGINE = (
|
||||||
|
form_data.RAG_RERANKING_ENGINE
|
||||||
|
if form_data.RAG_RERANKING_ENGINE is not None
|
||||||
|
else request.app.state.config.RAG_RERANKING_ENGINE
|
||||||
|
)
|
||||||
|
|
||||||
|
request.app.state.config.RAG_EXTERNAL_RERANKING_URL = (
|
||||||
|
form_data.RAG_EXTERNAL_RERANKING_URL
|
||||||
|
if form_data.RAG_EXTERNAL_RERANKING_URL is not None
|
||||||
|
else request.app.state.config.RAG_EXTERNAL_RERANKING_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY = (
|
||||||
|
form_data.RAG_EXTERNAL_RERANKING_API_KEY
|
||||||
|
if form_data.RAG_EXTERNAL_RERANKING_API_KEY is not None
|
||||||
|
else request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL
|
||||||
|
|
||||||
|
try:
|
||||||
|
request.app.state.rf = get_rf(
|
||||||
|
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||||
|
request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
||||||
|
request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error loading reranking model: {e}")
|
||||||
|
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Problem updating reranking model: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||||
|
)
|
||||||
|
|
||||||
# Chunking settings
|
# Chunking settings
|
||||||
request.app.state.config.TEXT_SPLITTER = (
|
request.app.state.config.TEXT_SPLITTER = (
|
||||||
form_data.TEXT_SPLITTER
|
form_data.TEXT_SPLITTER
|
||||||
@ -788,6 +816,11 @@ async def update_rag_config(
|
|||||||
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
"DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||||
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
"DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||||
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY,
|
||||||
|
# Reranking settings
|
||||||
|
"RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL,
|
||||||
|
"RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE,
|
||||||
|
"RAG_EXTERNAL_RERANKING_URL": request.app.state.config.RAG_EXTERNAL_RERANKING_URL,
|
||||||
|
"RAG_EXTERNAL_RERANKING_API_KEY": request.app.state.config.RAG_EXTERNAL_RERANKING_API_KEY,
|
||||||
# Chunking settings
|
# Chunking settings
|
||||||
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
"TEXT_SPLITTER": request.app.state.config.TEXT_SPLITTER,
|
||||||
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
"CHUNK_SIZE": request.app.state.config.CHUNK_SIZE,
|
||||||
|
Loading…
Reference in New Issue
Block a user