Merge pull request #12132 from Phlogi/dev-fetch-documents-once

Avoid multiple data fetching
This commit is contained in:
Timothy Jaeryang Baek 2025-03-30 20:44:32 -07:00 committed by GitHub
commit ce0d82b55f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,9 +1,7 @@
import logging
import os
import uuid
from typing import Optional, Union
import asyncio
import requests
import hashlib
@ -12,10 +10,8 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
from open_webui.models.users import UserModel
from open_webui.models.files import Files
@ -102,6 +98,7 @@ def get_doc(collection_name: str, user: UserModel = None):
def query_doc_with_hybrid_search(
collection_name: str,
collection_data,
query: str,
embedding_function,
k: int,
@ -110,11 +107,9 @@ def query_doc_with_hybrid_search(
r: float,
) -> dict:
try:
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
bm25_retriever = BM25Retriever.from_texts(
texts=result.documents[0],
metadatas=result.metadatas[0],
texts=collection_data.documents[0],
metadatas=collection_data.metadatas[0],
)
bm25_retriever.k = k
@ -140,9 +135,9 @@ def query_doc_with_hybrid_search(
result = compression_retriever.invoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
metadatas = [d.metadata for d in result]
distances = [d.metadata.get("score") for d in collection_data]
documents = [d.page_content for d in collection_data]
metadatas = [d.metadata for d in collection_data]
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
if k < k_reranker:
@ -151,7 +146,7 @@ def query_doc_with_hybrid_search(
)
sorted_items = sorted_items[:k]
distances, documents, metadatas = map(list, zip(*sorted_items))
result = {
collection_data = {
"distances": [distances],
"documents": [documents],
"metadatas": [metadatas],
@ -159,9 +154,9 @@ def query_doc_with_hybrid_search(
log.info(
"query_doc_with_hybrid_search:result "
+ f'{result["metadatas"]} {result["distances"]}'
+ f'{collection_data["metadatas"]} {collection_data["distances"]}'
)
return result
return collection_data
except Exception as e:
raise e
@ -282,11 +277,22 @@ def query_collection_with_hybrid_search(
) -> dict:
results = []
error = False
# Fetch each collection's data once up front, sequentially, so the per-query
# hybrid searches below reuse it instead of fetching the same data again.
collection_data = {}
for collection_name in collection_names:
try:
collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
except Exception as e:
log.exception(f"Failed to fetch collection {collection_name}: {e}")
collection_data[collection_name] = None
for collection_name in collection_names:
try:
for query in queries:
result = query_doc_with_hybrid_search(
collection_name=collection_name,
collection_data=collection_data[collection_name],
query=query,
embedding_function=embedding_function,
k=k,