fix: milvus

This commit is contained in:
Timothy J. Baek 2024-10-06 17:58:09 -07:00
parent c8e609c3d1
commit 05c15b017d

View File

@ -24,7 +24,6 @@ class MilvusClient:
_ids = [] _ids = []
_documents = [] _documents = []
_metadatas = [] _metadatas = []
for item in match: for item in match:
_ids.append(item.get("id")) _ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text")) _documents.append(item.get("data", {}).get("text"))
@ -112,12 +111,14 @@ class MilvusClient:
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name. # Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.has_collection( return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
) )
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name. # Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.drop_collection( return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
) )
@ -126,6 +127,7 @@ class MilvusClient:
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
result = self.client.search( result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors, data=vectors,
@ -137,9 +139,13 @@ class MilvusClient:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying # Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
return None
filter_string = " && ".join( filter_string = " && ".join(
[ [
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items() for key, value in filter.items()
] ]
) )
@ -154,8 +160,10 @@ class MilvusClient:
offset = 0 offset = 0
remaining = limit remaining = limit
try:
# Loop until there are no more items to fetch or the desired limit is reached # Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0: while remaining > 0:
print("remaining", remaining)
current_fetch = min( current_fetch = min(
max_limit, remaining max_limit, remaining
) # Determine how many items to fetch in this iteration ) # Determine how many items to fetch in this iteration
@ -182,10 +190,15 @@ class MilvusClient:
if results_count < current_fetch: if results_count < current_fetch:
break break
return self._result_to_get_result(all_results) print(all_results)
return self._result_to_get_result([all_results])
except Exception as e:
print(e)
return None
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
collection_name = collection_name.replace("-", "_")
result = self.client.query( result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""', filter='id != ""',
@ -194,6 +207,7 @@ class MilvusClient:
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection( if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
): ):
@ -216,6 +230,7 @@ class MilvusClient:
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection( if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}" collection_name=f"{self.collection_prefix}_{collection_name}"
): ):
@ -243,7 +258,7 @@ class MilvusClient:
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
# Delete the items from the collection based on the ids. # Delete the items from the collection based on the ids.
collection_name = collection_name.replace("-", "_")
if ids: if ids:
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f"{self.collection_prefix}_{collection_name}",
@ -253,7 +268,7 @@ class MilvusClient:
# Convert the filter dictionary to a string using JSON_CONTAINS. # Convert the filter dictionary to a string using JSON_CONTAINS.
filter_string = " && ".join( filter_string = " && ".join(
[ [
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items() for key, value in filter.items()
] ]
) )
@ -265,7 +280,6 @@ class MilvusClient:
def reset(self): def reset(self):
# Resets the database. This will delete all collections and item entries. # Resets the database. This will delete all collections and item entries.
collection_names = self.client.list_collections() collection_names = self.client.list_collections()
for collection_name in collection_names: for collection_name in collection_names:
if collection_name.startswith(self.collection_prefix): if collection_name.startswith(self.collection_prefix):