This commit is contained in:
Timothy J. Baek 2024-10-03 20:58:56 -07:00
parent 57360b7a61
commit 124a17e826

View File

@ -49,42 +49,49 @@ class ChromaClient:
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
try:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
query_embeddings=vectors,
n_results=limit,
)
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[GetResult]:
# Query the items from the collection based on the filter.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.get(
where=filter,
limit=limit,
)
return GetResult(
**{
"ids": result["ids"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
try:
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.get(
where=filter,
limit=limit,
)
return GetResult(
**{
"ids": result["ids"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
except Exception as e:
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.