This commit is contained in:
Timothy J. Baek 2024-10-03 20:51:21 -07:00
parent 2fc07fd6a2
commit 57360b7a61
2 changed files with 6 additions and 7 deletions

View File

@ -68,19 +68,18 @@ class ChromaClient:
def query( def query(
self, collection_name: str, filter: dict, limit: int = 1 self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[SearchResult]: ) -> Optional[GetResult]:
# Query the items from the collection based on the filter. # Query the items from the collection based on the filter.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
result = collection.query( result = collection.get(
where=filter, where=filter,
n_results=limit, limit=limit,
) )
return SearchResult( return GetResult(
**{ **{
"ids": result["ids"], "ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"], "documents": result["documents"],
"metadatas": result["metadatas"], "metadatas": result["metadatas"],
} }

View File

@ -137,7 +137,7 @@ class MilvusClient:
def query( def query(
self, collection_name: str, filter: dict, limit: int = 1 self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[SearchResult]: ) -> Optional[GetResult]:
# Query the items from the collection based on the filter. # Query the items from the collection based on the filter.
filter_string = " && ".join( filter_string = " && ".join(
[ [
@ -152,7 +152,7 @@ class MilvusClient:
limit=limit, limit=limit,
) )
return self._result_to_search_result([result]) return self._result_to_get_result([result])
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.