fix/refac: hybrid search

This commit is contained in:
Timothy Jaeryang Baek 2025-03-30 20:48:22 -07:00
parent ce0d82b55f
commit 50b8dec3ac

View File

@ -16,6 +16,8 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.models.files import Files from open_webui.models.files import Files
from open_webui.retrieval.vector.main import GetResult
from open_webui.env import ( from open_webui.env import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
OFFLINE_MODE, OFFLINE_MODE,
@ -98,7 +100,7 @@ def get_doc(collection_name: str, user: UserModel = None):
def query_doc_with_hybrid_search( def query_doc_with_hybrid_search(
collection_name: str, collection_name: str,
collection_data, collection_result: GetResult,
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
@ -108,8 +110,8 @@ def query_doc_with_hybrid_search(
) -> dict: ) -> dict:
try: try:
bm25_retriever = BM25Retriever.from_texts( bm25_retriever = BM25Retriever.from_texts(
texts=collection_data.documents[0], texts=collection_result.documents[0],
metadatas=collection_data.metadatas[0], metadatas=collection_result.metadatas[0],
) )
bm25_retriever.k = k bm25_retriever.k = k
@ -135,9 +137,9 @@ def query_doc_with_hybrid_search(
result = compression_retriever.invoke(query) result = compression_retriever.invoke(query)
distances = [d.metadata.get("score") for d in collection_data] distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in collection_data] documents = [d.page_content for d in result]
metadatas = [d.metadata for d in collection_data] metadatas = [d.metadata for d in result]
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
if k < k_reranker: if k < k_reranker:
@ -146,7 +148,8 @@ def query_doc_with_hybrid_search(
) )
sorted_items = sorted_items[:k] sorted_items = sorted_items[:k]
distances, documents, metadatas = map(list, zip(*sorted_items)) distances, documents, metadatas = map(list, zip(*sorted_items))
collection_data = {
result = {
"distances": [distances], "distances": [distances],
"documents": [documents], "documents": [documents],
"metadatas": [metadatas], "metadatas": [metadatas],
@ -154,9 +157,9 @@ def query_doc_with_hybrid_search(
log.info( log.info(
"query_doc_with_hybrid_search:result " "query_doc_with_hybrid_search:result "
+ f'{collection_data["metadatas"]} {collection_data["distances"]}' + f'{result["metadatas"]} {result["distances"]}'
) )
return collection_data return result
except Exception as e: except Exception as e:
raise e raise e
@ -279,20 +282,22 @@ def query_collection_with_hybrid_search(
error = False error = False
# Fetch collection data once per collection sequentially # Fetch collection data once per collection sequentially
# Avoid fetching the same data multiple times later # Avoid fetching the same data multiple times later
collection_data = {} collection_results = {}
for collection_name in collection_names: for collection_name in collection_names:
try: try:
collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name) collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
except Exception as e: except Exception as e:
log.exception(f"Failed to fetch collection {collection_name}: {e}") log.exception(f"Failed to fetch collection {collection_name}: {e}")
collection_data[collection_name] = None collection_results[collection_name] = None
for collection_name in collection_names: for collection_name in collection_names:
try: try:
for query in queries: for query in queries:
result = query_doc_with_hybrid_search( result = query_doc_with_hybrid_search(
collection_name=collection_name, collection_name=collection_name,
collection_data=collection_data[collection_name], collection_result=collection_results[collection_name],
query=query, query=query,
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,