This commit is contained in:
Timothy J. Baek 2024-09-13 01:18:20 -04:00
parent b943b7d337
commit 939bfd153e
4 changed files with 49 additions and 29 deletions

View File

@ -48,9 +48,9 @@ class VectorSearchRetriever(BaseRetriever):
limit=self.top_k, limit=self.top_k,
) )
ids = result["ids"][0] ids = result.ids[0]
metadatas = result["metadatas"][0] metadatas = result.metadatas[0]
documents = result["documents"][0] documents = result.documents[0]
results = [] results = []
for idx in range(len(ids)): for idx in range(len(ids)):
@ -194,7 +194,7 @@ def query_collection(
k=k, k=k,
embedding_function=embedding_function, embedding_function=embedding_function,
) )
results.append(result) results.append(result.model_dump())
except Exception as e: except Exception as e:
log.exception(f"Error when querying the collection: {e}") log.exception(f"Error when querying the collection: {e}")
else: else:
@ -212,7 +212,7 @@ def query_collection_with_hybrid_search(
r: float, r: float,
) -> dict: ) -> dict:
results = [] results = []
failed = 0 error = False
for collection_name in collection_names: for collection_name in collection_names:
try: try:
result = query_doc_with_hybrid_search( result = query_doc_with_hybrid_search(
@ -228,12 +228,14 @@ def query_collection_with_hybrid_search(
log.exception( log.exception(
"Error when querying the collection with " f"hybrid_search: {e}" "Error when querying the collection with " f"hybrid_search: {e}"
) )
failed += 1 error = True
if failed == len(collection_names):
if error:
raise Exception( raise Exception(
"Hybrid search failed for all collections. Using " "Hybrid search failed for all collections. Using "
"Non hybrid search as fallback." "Non hybrid search as fallback."
) )
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)

View File

@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import ( from open_webui.config import (
CHROMA_DATA_PATH, CHROMA_DATA_PATH,
CHROMA_HTTP_HOST, CHROMA_HTTP_HOST,
@ -47,7 +47,7 @@ class ChromaClient:
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
@ -56,19 +56,31 @@ class ChromaClient:
n_results=limit, n_results=limit,
) )
return { return SearchResult(
**{
"ids": result["ids"], "ids": result["ids"],
"distances": result["distances"], "distances": result["distances"],
"documents": result["documents"], "documents": result["documents"],
"metadatas": result["metadatas"], "metadatas": result["metadatas"],
} }
)
return None return None
def get(self, collection_name: str) -> Optional[QueryResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
return collection.get()
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"distances": [result["distances"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
}
)
return None return None
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):

View File

@ -4,7 +4,7 @@ import json
from typing import Optional from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import ( from open_webui.config import (
MILVUS_URI, MILVUS_URI,
) )
@ -15,7 +15,7 @@ class MilvusClient:
self.collection_prefix = "open_webui" self.collection_prefix = "open_webui"
self.client = Client(uri=MILVUS_URI) self.client = Client(uri=MILVUS_URI)
def _result_to_query_result(self, result) -> QueryResult: def _result_to_query_result(self, result) -> SearchResult:
print(result) print(result)
ids = [] ids = []
@ -40,12 +40,14 @@ class MilvusClient:
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return { return SearchResult(
**{
"ids": ids, "ids": ids,
"distances": distances, "distances": distances,
"documents": documents, "documents": documents,
"metadatas": metadatas, "metadatas": metadatas,
} }
)
def _create_collection(self, collection_name: str, dimension: int): def _create_collection(self, collection_name: str, dimension: int):
schema = self.client.create_schema( schema = self.client.create_schema(
@ -94,7 +96,7 @@ class MilvusClient:
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
result = self.client.search( result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
@ -105,10 +107,11 @@ class MilvusClient:
return self._result_to_query_result(result) return self._result_to_query_result(result)
def get(self, collection_name: str) -> Optional[QueryResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
result = self.client.query( result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""',
) )
return self._result_to_query_result(result) return self._result_to_query_result(result)

View File

@ -9,8 +9,11 @@ class VectorItem(BaseModel):
metadata: Any metadata: Any
class QueryResult(BaseModel): class GetResult(BaseModel):
ids: Optional[List[List[str]]] ids: Optional[List[List[str]]]
distances: Optional[List[List[float | int]]]
documents: Optional[List[List[str]]] documents: Optional[List[List[str]]]
metadatas: Optional[List[List[Any]]] metadatas: Optional[List[List[Any]]]
class SearchResult(GetResult):
distances: Optional[List[List[float | int]]]