Merge pull request #11497 from mahenning/k_reranker

feat: Added new k_reranker parameter
This commit is contained in:
Timothy Jaeryang Baek 2025-03-26 20:50:43 -07:00 committed by GitHub
commit eaba9d8a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 57 additions and 5 deletions

View File

@ -1685,6 +1685,11 @@ BYPASS_EMBEDDING_AND_RETRIEVAL = PersistentConfig(
RAG_TOP_K = PersistentConfig( RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3")) "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
) )
RAG_TOP_K_RERANKER = PersistentConfig(
"RAG_TOP_K_RERANKER",
"rag.top_k_reranker",
int(os.environ.get("RAG_TOP_K_RERANKER", "3"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig( RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"RAG_RELEVANCE_THRESHOLD", "RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold", "rag.relevance_threshold",

View File

@ -191,6 +191,7 @@ from open_webui.config import (
DOCUMENT_INTELLIGENCE_ENDPOINT, DOCUMENT_INTELLIGENCE_ENDPOINT,
DOCUMENT_INTELLIGENCE_KEY, DOCUMENT_INTELLIGENCE_KEY,
RAG_TOP_K, RAG_TOP_K,
RAG_TOP_K_RERANKER,
RAG_TEXT_SPLITTER, RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME, TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES, PDF_EXTRACT_IMAGES,
@ -552,6 +553,7 @@ app.state.FUNCTIONS = {}
app.state.config.TOP_K = RAG_TOP_K app.state.config.TOP_K = RAG_TOP_K
app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT

View File

@ -106,6 +106,7 @@ def query_doc_with_hybrid_search(
embedding_function, embedding_function,
k: int, k: int,
reranking_function, reranking_function,
k_reranker: int,
r: float, r: float,
) -> dict: ) -> dict:
try: try:
@ -128,7 +129,7 @@ def query_doc_with_hybrid_search(
) )
compressor = RerankCompressor( compressor = RerankCompressor(
embedding_function=embedding_function, embedding_function=embedding_function,
top_n=k, top_n=k_reranker,
reranking_function=reranking_function, reranking_function=reranking_function,
r_score=r, r_score=r,
) )
@ -138,10 +139,20 @@ def query_doc_with_hybrid_search(
) )
result = compression_retriever.invoke(query) result = compression_retriever.invoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
if k < k_reranker:
sorted_items = sorted(zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True)
sorted_items = sorted_items[:k]
distances, documents, metadatas = map(list, zip(*sorted_items))
result = { result = {
"distances": [[d.metadata.get("score") for d in result]], "distances": [distances],
"documents": [[d.page_content for d in result]], "documents": [documents],
"metadatas": [[d.metadata for d in result]], "metadatas": [metadatas],
} }
log.info( log.info(
@ -281,6 +292,7 @@ def query_collection_with_hybrid_search(
embedding_function, embedding_function,
k: int, k: int,
reranking_function, reranking_function,
k_reranker: int,
r: float, r: float,
) -> dict: ) -> dict:
results = [] results = []
@ -294,6 +306,7 @@ def query_collection_with_hybrid_search(
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker,
r=r, r=r,
) )
results.append(result) results.append(result)
@ -354,6 +367,7 @@ def get_sources_from_files(
embedding_function, embedding_function,
k, k,
reranking_function, reranking_function,
k_reranker,
r, r,
hybrid_search, hybrid_search,
full_context=False, full_context=False,
@ -470,6 +484,7 @@ def get_sources_from_files(
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
reranking_function=reranking_function, reranking_function=reranking_function,
k_reranker=k_reranker,
r=r, r=r,
) )
except Exception as e: except Exception as e:

View File

@ -719,6 +719,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
"status": True, "status": True,
"template": request.app.state.config.RAG_TEMPLATE, "template": request.app.state.config.RAG_TEMPLATE,
"k": request.app.state.config.TOP_K, "k": request.app.state.config.TOP_K,
"k_reranker": request.app.state.config.TOP_K_RERANKER,
"r": request.app.state.config.RELEVANCE_THRESHOLD, "r": request.app.state.config.RELEVANCE_THRESHOLD,
"hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
@ -726,6 +727,7 @@ async def get_query_settings(request: Request, user=Depends(get_admin_user)):
class QuerySettingsForm(BaseModel): class QuerySettingsForm(BaseModel):
k: Optional[int] = None k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None r: Optional[float] = None
template: Optional[str] = None template: Optional[str] = None
hybrid: Optional[bool] = None hybrid: Optional[bool] = None
@ -737,6 +739,7 @@ async def update_query_settings(
): ):
request.app.state.config.RAG_TEMPLATE = form_data.template request.app.state.config.RAG_TEMPLATE = form_data.template
request.app.state.config.TOP_K = form_data.k if form_data.k else 4 request.app.state.config.TOP_K = form_data.k if form_data.k else 4
request.app.state.config.TOP_K_RERANKER = form_data.k_reranker or 4
request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = ( request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
@ -747,6 +750,7 @@ async def update_query_settings(
"status": True, "status": True,
"template": request.app.state.config.RAG_TEMPLATE, "template": request.app.state.config.RAG_TEMPLATE,
"k": request.app.state.config.TOP_K, "k": request.app.state.config.TOP_K,
"k_reranker": request.app.state.config.TOP_K_RERANKER,
"r": request.app.state.config.RELEVANCE_THRESHOLD, "r": request.app.state.config.RELEVANCE_THRESHOLD,
"hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
@ -1495,6 +1499,7 @@ class QueryDocForm(BaseModel):
collection_name: str collection_name: str
query: str query: str
k: Optional[int] = None k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None r: Optional[float] = None
hybrid: Optional[bool] = None hybrid: Optional[bool] = None
@ -1515,6 +1520,7 @@ def query_doc_handler(
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
r=( r=(
form_data.r form_data.r
if form_data.r if form_data.r
@ -1543,6 +1549,7 @@ class QueryCollectionsForm(BaseModel):
collection_names: list[str] collection_names: list[str]
query: str query: str
k: Optional[int] = None k: Optional[int] = None
k_reranker: Optional[int] = None
r: Optional[float] = None r: Optional[float] = None
hybrid: Optional[bool] = None hybrid: Optional[bool] = None
@ -1563,6 +1570,7 @@ def query_collection_handler(
), ),
k=form_data.k if form_data.k else request.app.state.config.TOP_K, k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER,
r=( r=(
form_data.r form_data.r
if form_data.r if form_data.r

View File

@ -584,6 +584,7 @@ async def chat_completion_files_handler(
), ),
k=request.app.state.config.TOP_K, k=request.app.state.config.TOP_K,
reranking_function=request.app.state.rf, reranking_function=request.app.state.rf,
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD, r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=request.app.state.config.RAG_FULL_CONTEXT, full_context=request.app.state.config.RAG_FULL_CONTEXT,

View File

@ -76,6 +76,7 @@
template: '', template: '',
r: 0.0, r: 0.0,
k: 4, k: 4,
k_reranker: 4,
hybrid: false hybrid: false
}; };
@ -765,6 +766,23 @@
</div> </div>
</div> </div>
{#if querySettings.hybrid === true}
<div class="mb-2.5 flex w-full justify-between">
<div class="self-center text-xs font-medium">{$i18n.t('Top K Reranker')}</div>
<div class="flex items-center relative">
<input
class="flex-1 w-full rounded-lg text-sm bg-transparent outline-hidden"
type="number"
placeholder={$i18n.t('Enter Top K Reranker')}
bind:value={querySettings.k_reranker}
autocomplete="off"
min="0"
/>
</div>
</div>
{/if}
{#if querySettings.hybrid === true} {#if querySettings.hybrid === true}
<div class=" mb-2.5 flex flex-col w-full justify-between"> <div class=" mb-2.5 flex flex-col w-full justify-between">
<div class=" flex w-full justify-between"> <div class=" flex w-full justify-between">

View File

@ -437,6 +437,7 @@
"Enter timeout in seconds": "Geben Sie den Timeout in Sekunden ein", "Enter timeout in seconds": "Geben Sie den Timeout in Sekunden ein",
"Enter to Send": "'Enter' zum Senden", "Enter to Send": "'Enter' zum Senden",
"Enter Top K": "Geben Sie Top K ein", "Enter Top K": "Geben Sie Top K ein",
"Enter Top K Reranker": "Geben Sie Top K für Reranker ein",
"Enter URL (e.g. http://127.0.0.1:7860/)": "Geben Sie die URL ein (z. B. http://127.0.0.1:7860/)", "Enter URL (e.g. http://127.0.0.1:7860/)": "Geben Sie die URL ein (z. B. http://127.0.0.1:7860/)",
"Enter URL (e.g. http://localhost:11434)": "Geben Sie die URL ein (z. B. http://localhost:11434)", "Enter URL (e.g. http://localhost:11434)": "Geben Sie die URL ein (z. B. http://localhost:11434)",
"Enter your current password": "Geben Sie Ihr aktuelles Passwort ein", "Enter your current password": "Geben Sie Ihr aktuelles Passwort ein",

View File

@ -437,6 +437,7 @@
"Enter timeout in seconds": "", "Enter timeout in seconds": "",
"Enter to Send": "", "Enter to Send": "",
"Enter Top K": "", "Enter Top K": "",
"Enter Top K Reranker": "",
"Enter URL (e.g. http://127.0.0.1:7860/)": "", "Enter URL (e.g. http://127.0.0.1:7860/)": "",
"Enter URL (e.g. http://localhost:11434)": "", "Enter URL (e.g. http://localhost:11434)": "",
"Enter your current password": "", "Enter your current password": "",

View File

@ -437,6 +437,7 @@
"Enter timeout in seconds": "", "Enter timeout in seconds": "",
"Enter to Send": "", "Enter to Send": "",
"Enter Top K": "", "Enter Top K": "",
"Enter Top K Reranker": "",
"Enter URL (e.g. http://127.0.0.1:7860/)": "", "Enter URL (e.g. http://127.0.0.1:7860/)": "",
"Enter URL (e.g. http://localhost:11434)": "", "Enter URL (e.g. http://localhost:11434)": "",
"Enter your current password": "", "Enter your current password": "",