diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py index bf33d35e9..5351f860e 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py @@ -24,7 +24,6 @@ class MilvusClient: _ids = [] _documents = [] _metadatas = [] - for item in match: _ids.append(item.get("id")) _documents.append(item.get("data", {}).get("text")) @@ -112,12 +111,14 @@ class MilvusClient: def has_collection(self, collection_name: str) -> bool: # Check if the collection exists based on the collection name. + collection_name = collection_name.replace("-", "_") return self.client.has_collection( collection_name=f"{self.collection_prefix}_{collection_name}" ) def delete_collection(self, collection_name: str): # Delete the collection based on the collection name. + collection_name = collection_name.replace("-", "_") return self.client.drop_collection( collection_name=f"{self.collection_prefix}_{collection_name}" ) @@ -126,6 +127,7 @@ class MilvusClient: self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. + collection_name = collection_name.replace("-", "_") result = self.client.search( collection_name=f"{self.collection_prefix}_{collection_name}", data=vectors, @@ -137,9 +139,13 @@ class MilvusClient: def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): # Construct the filter string for querying + collection_name = collection_name.replace("-", "_") + if not self.has_collection(collection_name): + return None + filter_string = " && ".join( [ - f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" + f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items() ] ) @@ -154,38 +160,45 @@ class MilvusClient: offset = 0 remaining = limit - # Loop until there are no more items to fetch or the desired limit is reached - while remaining > 0: - current_fetch = min( - max_limit, remaining - ) # Determine how many items to fetch in this iteration + try: + # Loop until there are no more items to fetch or the desired limit is reached + while remaining > 0: + print("remaining", remaining) + current_fetch = min( + max_limit, remaining + ) # Determine how many items to fetch in this iteration - results = self.client.query( - collection_name=f"{self.collection_prefix}_{collection_name}", - filter=filter_string, - output_fields=["*"], - limit=current_fetch, - offset=offset, - ) + results = self.client.query( + collection_name=f"{self.collection_prefix}_{collection_name}", + filter=filter_string, + output_fields=["*"], + limit=current_fetch, + offset=offset, + ) - if not results: - break + if not results: + break - all_results.extend(results) - results_count = len(results) - remaining -= ( - results_count # Decrease remaining by the number of items fetched - ) - offset += results_count + all_results.extend(results) + results_count = len(results) + remaining -= ( + results_count # Decrease remaining by the number of items fetched + ) + offset += results_count - # Break the loop if the results returned are less than the requested fetch count - if results_count < current_fetch: - break + # Break the loop if the results returned are less than the requested fetch count + if results_count < current_fetch: + break - return self._result_to_get_result(all_results) + print(all_results) + return self._result_to_get_result([all_results]) + except Exception as e: + print(e) + return None def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. + collection_name = collection_name.replace("-", "_") result = self.client.query( collection_name=f"{self.collection_prefix}_{collection_name}", filter='id != ""', @@ -194,6 +207,7 @@ class MilvusClient: def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. + collection_name = collection_name.replace("-", "_") if not self.client.has_collection( collection_name=f"{self.collection_prefix}_{collection_name}" ): @@ -216,6 +230,7 @@ class MilvusClient: def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. + collection_name = collection_name.replace("-", "_") if not self.client.has_collection( collection_name=f"{self.collection_prefix}_{collection_name}" ): @@ -243,7 +258,7 @@ class MilvusClient: filter: Optional[dict] = None, ): # Delete the items from the collection based on the ids. - + collection_name = collection_name.replace("-", "_") if ids: return self.client.delete( collection_name=f"{self.collection_prefix}_{collection_name}", @@ -253,7 +268,7 @@ class MilvusClient: # Convert the filter dictionary to a string using JSON_CONTAINS. filter_string = " && ".join( [ - f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" + f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items() ] ) @@ -265,7 +280,6 @@ class MilvusClient: def reset(self): # Resets the database. This will delete all collections and item entries. - collection_names = self.client.list_collections() for collection_name in collection_names: if collection_name.startswith(self.collection_prefix):