mirror of
https://github.com/open-webui/open-webui
synced 2025-04-05 05:10:46 +00:00
Avoid fetching the same collection data multiple times
This commit is contained in:
parent
4547453141
commit
04bf9ddab2
@ -1,9 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import requests
|
import requests
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
@ -12,10 +10,8 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
|
|||||||
from langchain_community.retrievers import BM25Retriever
|
from langchain_community.retrievers import BM25Retriever
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
|
||||||
from open_webui.config import VECTOR_DB
|
from open_webui.config import VECTOR_DB
|
||||||
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
|
||||||
from open_webui.utils.misc import get_last_user_message, calculate_sha256_string
|
|
||||||
|
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
from open_webui.models.files import Files
|
from open_webui.models.files import Files
|
||||||
@ -102,6 +98,7 @@ def get_doc(collection_name: str, user: UserModel = None):
|
|||||||
|
|
||||||
def query_doc_with_hybrid_search(
|
def query_doc_with_hybrid_search(
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
|
collection_data,
|
||||||
query: str,
|
query: str,
|
||||||
embedding_function,
|
embedding_function,
|
||||||
k: int,
|
k: int,
|
||||||
@ -110,11 +107,9 @@ def query_doc_with_hybrid_search(
|
|||||||
r: float,
|
r: float,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
try:
|
try:
|
||||||
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
|
||||||
|
|
||||||
bm25_retriever = BM25Retriever.from_texts(
|
bm25_retriever = BM25Retriever.from_texts(
|
||||||
texts=result.documents[0],
|
texts=collection_data.documents[0],
|
||||||
metadatas=result.metadatas[0],
|
metadatas=collection_data.metadatas[0],
|
||||||
)
|
)
|
||||||
bm25_retriever.k = k
|
bm25_retriever.k = k
|
||||||
|
|
||||||
@ -140,9 +135,9 @@ def query_doc_with_hybrid_search(
|
|||||||
|
|
||||||
result = compression_retriever.invoke(query)
|
result = compression_retriever.invoke(query)
|
||||||
|
|
||||||
distances = [d.metadata.get("score") for d in result]
|
distances = [d.metadata.get("score") for d in collection_data]
|
||||||
documents = [d.page_content for d in result]
|
documents = [d.page_content for d in collection_data]
|
||||||
metadatas = [d.metadata for d in result]
|
metadatas = [d.metadata for d in collection_data]
|
||||||
|
|
||||||
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
|
# retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
|
||||||
if k < k_reranker:
|
if k < k_reranker:
|
||||||
@ -151,7 +146,7 @@ def query_doc_with_hybrid_search(
|
|||||||
)
|
)
|
||||||
sorted_items = sorted_items[:k]
|
sorted_items = sorted_items[:k]
|
||||||
distances, documents, metadatas = map(list, zip(*sorted_items))
|
distances, documents, metadatas = map(list, zip(*sorted_items))
|
||||||
result = {
|
collection_data = {
|
||||||
"distances": [distances],
|
"distances": [distances],
|
||||||
"documents": [documents],
|
"documents": [documents],
|
||||||
"metadatas": [metadatas],
|
"metadatas": [metadatas],
|
||||||
@ -159,9 +154,9 @@ def query_doc_with_hybrid_search(
|
|||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"query_doc_with_hybrid_search:result "
|
"query_doc_with_hybrid_search:result "
|
||||||
+ f'{result["metadatas"]} {result["distances"]}'
|
+ f'{collection_data["metadatas"]} {collection_data["distances"]}'
|
||||||
)
|
)
|
||||||
return result
|
return collection_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -282,11 +277,22 @@ def query_collection_with_hybrid_search(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
results = []
|
results = []
|
||||||
error = False
|
error = False
|
||||||
|
# Fetch collection data once per collection sequentially
|
||||||
|
# Avoid fetching the same data multiple times later
|
||||||
|
collection_data = {}
|
||||||
|
for collection_name in collection_names:
|
||||||
|
try:
|
||||||
|
collection_data[collection_name] = VECTOR_DB_CLIENT.get(collection_name=collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(f"Failed to fetch collection {collection_name}: {e}")
|
||||||
|
collection_data[collection_name] = None
|
||||||
|
|
||||||
for collection_name in collection_names:
|
for collection_name in collection_names:
|
||||||
try:
|
try:
|
||||||
for query in queries:
|
for query in queries:
|
||||||
result = query_doc_with_hybrid_search(
|
result = query_doc_with_hybrid_search(
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
|
collection_data=collection_data[collection_name],
|
||||||
query=query,
|
query=query,
|
||||||
embedding_function=embedding_function,
|
embedding_function=embedding_function,
|
||||||
k=k,
|
k=k,
|
||||||
|
Loading…
Reference in New Issue
Block a user