From 04bf9ddab2f61fd521a3dc7f3997ea195334c7a6 Mon Sep 17 00:00:00 2001
From: Phlogi <Phlogi@users.noreply.github.com>
Date: Thu, 27 Mar 2025 19:05:20 +0100
Subject: [PATCH] Avoid multiple data fetching

---
 backend/open_webui/retrieval/utils.py | 34 ++++++++++++++++-----------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index aa7b57dfc..767d5cce8 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,9 +1,7 @@
 import logging
 import os
-import uuid
 from typing import Optional, Union
 
-import asyncio
 import requests
 import hashlib
 
@@ -12,10 +10,8 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 
-
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
 
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
@@ -102,6 +98,7 @@ def get_doc(collection_name: str, user: UserModel = None):
 
 def query_doc_with_hybrid_search(
     collection_name: str,
+    collection_data,
     query: str,
     embedding_function,
     k: int,
@@ -110,11 +107,9 @@ def query_doc_with_hybrid_search(
     r: float,
 ) -> dict:
     try:
-        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
-
         bm25_retriever = BM25Retriever.from_texts(
-            texts=result.documents[0],
-            metadatas=result.metadatas[0],
+            texts=collection_data.documents[0],
+            metadatas=collection_data.metadatas[0],
         )
         bm25_retriever.k = k
 
@@ -140,9 +135,9 @@ def query_doc_with_hybrid_search(
 
         result = compression_retriever.invoke(query)
 
-        distances = [d.metadata.get("score") for d in result]
-        documents = [d.page_content for d in result]
-        metadatas = [d.metadata for d in result]
+        distances = [d.metadata.get("score") for d in collection_data]
+        documents = [d.page_content for d in collection_data]
+        metadatas = [d.metadata for d in collection_data]
 
         # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
         if k < k_reranker:
@@ -151,7 +146,7 @@ def query_doc_with_hybrid_search(
             )
             sorted_items = sorted_items[:k]
             distances, documents, metadatas = map(list, zip(*sorted_items))
-        result = {
+        collection_data = {
             "distances": [distances],
             "documents": [documents],
             "metadatas": [metadatas],
@@ -159,9 +154,9 @@ def query_doc_with_hybrid_search(
 
         log.info(
             "query_doc_with_hybrid_search:result "
-            + f'{result["metadatas"]} {result["distances"]}'
+            + f'{collection_data["metadatas"]} {collection_data["distances"]}'
         )
-        return result
+        return collection_data
     except Exception as e:
         raise e
 
@@ -282,11 +277,22 @@ def query_collection_with_hybrid_search(
 ) -> dict:
     results = []
     error = False
+    # Fetch collection data once per collection sequentially
+    # Avoid fetching the same data multiple times later
+    collection_data = {}
+    for collection_name in collection_names:
+        try:
+            collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
+        except Exception as e:
+            log.exception(f"Failed to fetch collection {collection_name}: {e}")
+            collection_data[collection_name] = None
+
     for collection_name in collection_names:
         try:
             for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
+                    collection_data=collection_data[collection_name],
                     query=query,
                     embedding_function=embedding_function,
                     k=k,