refac: comments

This commit is contained in:
Timothy J. Baek 2024-09-10 04:46:40 +01:00
parent 522afbb0a0
commit 0886b3a0a4

View File

@ -37,18 +37,22 @@ class ChromaClient:
) )
def list_collections(self) -> list[str]: def list_collections(self) -> list[str]:
# List all the collections in the database.
collections = self.client.list_collections() collections = self.client.list_collections()
return [collection.name for collection in collections] return [collection.name for collection in collections]
def create_collection(self, collection_name: str): def create_collection(self, collection_name: str):
# Create a new collection based on the collection name.
return self.client.create_collection(name=collection_name) return self.client.create_collection(name=collection_name)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
return self.client.delete_collection(name=collection_name) return self.client.delete_collection(name=collection_name)
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]: ) -> Optional[QueryResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
result = collection.query( result = collection.query(
@ -65,12 +69,14 @@ class ChromaClient:
return None return None
def get(self, collection_name: str) -> Optional[QueryResult]: def get(self, collection_name: str) -> Optional[QueryResult]:
# Get all the items in the collection.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
return collection.get() return collection.get()
return None return None
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection.
collection = self.client.get_or_create_collection(name=collection_name) collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
@ -88,6 +94,7 @@ class ChromaClient:
collection.add(*batch) collection.add(*batch)
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them.
collection = self.client.get_or_create_collection(name=collection_name) collection = self.client.get_or_create_collection(name=collection_name)
ids = [item["id"] for item in items] ids = [item["id"] for item in items]
@ -100,9 +107,11 @@ class ChromaClient:
) )
def delete(self, collection_name: str, ids: list[str]): def delete(self, collection_name: str, ids: list[str]):
# Delete the items from the collection based on the ids.
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
collection.delete(ids=ids) collection.delete(ids=ids)
def reset(self): def reset(self):
# Resets the database. This will delete all collections and item entries.
return self.client.reset() return self.client.reset()