diff --git a/backend/open_webui/apps/rag/utils.py b/backend/open_webui/apps/rag/utils.py index 82daaf438..274a84c92 100644 --- a/backend/open_webui/apps/rag/utils.py +++ b/backend/open_webui/apps/rag/utils.py @@ -48,9 +48,9 @@ class VectorSearchRetriever(BaseRetriever): limit=self.top_k, ) - ids = result["ids"][0] - metadatas = result["metadatas"][0] - documents = result["documents"][0] + ids = result.ids[0] + metadatas = result.metadatas[0] + documents = result.documents[0] results = [] for idx in range(len(ids)): @@ -194,7 +194,7 @@ def query_collection( k=k, embedding_function=embedding_function, ) - results.append(result) + results.append(result.model_dump()) except Exception as e: log.exception(f"Error when querying the collection: {e}") else: @@ -212,7 +212,7 @@ def query_collection_with_hybrid_search( r: float, ) -> dict: results = [] - failed = 0 + error = False for collection_name in collection_names: try: result = query_doc_with_hybrid_search( @@ -228,12 +228,14 @@ def query_collection_with_hybrid_search( log.exception( "Error when querying the collection with " f"hybrid_search: {e}" ) - failed += 1 - if failed == len(collection_names): + error = True + + if error: raise Exception( "Hybrid search failed for all collections. Using " "Non hybrid search as fallback." ) + return merge_and_sort_query_results(results, k=k, reverse=True) diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py index 9115fb9f5..7927d9ad4 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches from typing import Optional -from open_webui.apps.rag.vector.main import VectorItem, QueryResult +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, @@ -47,7 +47,7 @@ class ChromaClient: def search( self, collection_name: str, vectors: list[list[float | int]], limit: int - ) -> Optional[QueryResult]: + ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. collection = self.client.get_collection(name=collection_name) if collection: @@ -56,19 +56,31 @@ class ChromaClient: n_results=limit, ) - return { - "ids": result["ids"], - "distances": result["distances"], - "documents": result["documents"], - "metadatas": result["metadatas"], - } + return SearchResult( + **{ + "ids": result["ids"], + "distances": result["distances"], + "documents": result["documents"], + "metadatas": result["metadatas"], + } + ) return None - def get(self, collection_name: str) -> Optional[QueryResult]: + def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. collection = self.client.get_collection(name=collection_name) if collection: - return collection.get() + + result = collection.get() + + return GetResult( + **{ + "ids": [result["ids"]], + "distances": [result["distances"]], + "documents": [result["documents"]], + "metadatas": [result["metadatas"]], + } + ) return None def insert(self, collection_name: str, items: list[VectorItem]): diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py index 260aa687e..6e98efeec 100644 --- a/backend/open_webui/apps/rag/vector/dbs/milvus.py +++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py @@ -4,7 +4,7 @@ import json from typing import Optional -from open_webui.apps.rag.vector.main import VectorItem, QueryResult +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult from open_webui.config import ( MILVUS_URI, ) @@ -15,7 +15,7 @@ class MilvusClient: self.collection_prefix = "open_webui" self.client = Client(uri=MILVUS_URI) - def _result_to_query_result(self, result) -> QueryResult: + def _result_to_query_result(self, result) -> SearchResult: print(result) ids = [] @@ -40,12 +40,14 @@ class MilvusClient: documents.append(_documents) metadatas.append(_metadatas) - return { - "ids": ids, - "distances": distances, - "documents": documents, - "metadatas": metadatas, - } + return SearchResult( + **{ + "ids": ids, + "distances": distances, + "documents": documents, + "metadatas": metadatas, + } + ) def _create_collection(self, collection_name: str, dimension: int): schema = self.client.create_schema( @@ -94,7 +96,7 @@ class MilvusClient: def search( self, collection_name: str, vectors: list[list[float | int]], limit: int - ) -> Optional[QueryResult]: + ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. result = self.client.search( collection_name=f"{self.collection_prefix}_{collection_name}", @@ -105,10 +107,11 @@ class MilvusClient: return self._result_to_query_result(result) - def get(self, collection_name: str) -> Optional[QueryResult]: + def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. result = self.client.query( collection_name=f"{self.collection_prefix}_{collection_name}", + filter='id != ""', ) return self._result_to_query_result(result) diff --git a/backend/open_webui/apps/rag/vector/main.py b/backend/open_webui/apps/rag/vector/main.py index 5b5a8ea38..f0cf0c038 100644 --- a/backend/open_webui/apps/rag/vector/main.py +++ b/backend/open_webui/apps/rag/vector/main.py @@ -9,8 +9,11 @@ class VectorItem(BaseModel): metadata: Any -class QueryResult(BaseModel): +class GetResult(BaseModel): ids: Optional[List[List[str]]] - distances: Optional[List[List[float | int]]] documents: Optional[List[List[str]]] metadatas: Optional[List[List[Any]]] + + +class SearchResult(GetResult): + distances: Optional[List[List[float | int]]]