fix: milvus

This commit is contained in:
Timothy J. Baek 2024-10-06 17:58:09 -07:00
parent c8e609c3d1
commit 05c15b017d

View File

@ -24,7 +24,6 @@ class MilvusClient:
_ids = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
@ -112,12 +111,14 @@ class MilvusClient:
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
@ -126,6 +127,7 @@ class MilvusClient:
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors,
@ -137,9 +139,13 @@ class MilvusClient:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
# Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
return None
filter_string = " && ".join(
[
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
@ -154,38 +160,45 @@ class MilvusClient:
offset = 0
remaining = limit
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
try:
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
print("remaining", remaining)
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=["*"],
limit=current_fetch,
offset=offset,
)
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=["*"],
limit=current_fetch,
offset=offset,
)
if not results:
break
if not results:
break
all_results.extend(results)
results_count = len(results)
remaining -= (
results_count # Decrease remaining by the number of items fetched
)
offset += results_count
all_results.extend(results)
results_count = len(results)
remaining -= (
results_count # Decrease remaining by the number of items fetched
)
offset += results_count
# Break the loop if the results returned are less than the requested fetch count
if results_count < current_fetch:
break
# Break the loop if the results returned are less than the requested fetch count
if results_count < current_fetch:
break
return self._result_to_get_result(all_results)
print(all_results)
return self._result_to_get_result([all_results])
except Exception as e:
print(e)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection_name = collection_name.replace("-", "_")
result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""',
@ -194,6 +207,7 @@ class MilvusClient:
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
@ -216,6 +230,7 @@ class MilvusClient:
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
@ -243,7 +258,7 @@ class MilvusClient:
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
collection_name = collection_name.replace("-", "_")
if ids:
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
@ -253,7 +268,7 @@ class MilvusClient:
# Convert the filter dictionary to a string using JSON_CONTAINS.
filter_string = " && ".join(
[
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
@ -265,7 +280,6 @@ class MilvusClient:
def reset(self):
# Resets the database. This will delete all collections and item entries.
collection_names = self.client.list_collections()
for collection_name in collection_names:
if collection_name.startswith(self.collection_prefix):