diff --git a/backend/open_webui/apps/rag/main.py b/backend/open_webui/apps/rag/main.py index 32ca6b7e9..a0d4c7d28 100644 --- a/backend/open_webui/apps/rag/main.py +++ b/backend/open_webui/apps/rag/main.py @@ -997,7 +997,7 @@ def store_docs_in_vector_db( try: if overwrite: - if collection_name in VECTOR_DB_CLIENT.list_collections(): + if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): log.info(f"deleting existing collection {collection_name}") VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name) diff --git a/backend/open_webui/apps/rag/vector/dbs/chroma.py b/backend/open_webui/apps/rag/vector/dbs/chroma.py index b04dbd6bc..693b41f23 100644 --- a/backend/open_webui/apps/rag/vector/dbs/chroma.py +++ b/backend/open_webui/apps/rag/vector/dbs/chroma.py @@ -36,10 +36,10 @@ class ChromaClient: database=CHROMA_DATABASE, ) - def list_collections(self) -> list[str]: - # List all the collections in the database. + def has_collection(self, collection_name: str) -> bool: + # Check if the collection exists based on the collection name. collections = self.client.list_collections() - return [collection.name for collection in collections] + return collection_name in [collection.name for collection in collections] def delete_collection(self, collection_name: str): # Delete the collection based on the collection name. diff --git a/backend/open_webui/apps/rag/vector/dbs/milvus.py b/backend/open_webui/apps/rag/vector/dbs/milvus.py index f679b4504..260aa687e 100644 --- a/backend/open_webui/apps/rag/vector/dbs/milvus.py +++ b/backend/open_webui/apps/rag/vector/dbs/milvus.py @@ -80,13 +80,11 @@ class MilvusClient: index_params=index_params, ) - def list_collections(self) -> list[str]: - # List all the collections in the database. - return [ - collection[len(self.collection_prefix) :] - for collection in self.client.list_collections() - if collection.startswith(self.collection_prefix) - ] + def has_collection(self, collection_name: str) -> bool: + # Check if the collection exists based on the collection name. + return self.client.has_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) def delete_collection(self, collection_name: str): # Delete the collection based on the collection name.