From 2fc07fd6a2ebb0ca0dae7f4cf4a5187eda36ca80 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 3 Oct 2024 06:53:21 -0700 Subject: [PATCH] enh: vector db hash collision check --- backend/open_webui/apps/retrieval/main.py | 10 +++++++++ .../apps/retrieval/vector/dbs/chroma.py | 21 +++++++++++++++++++ .../apps/retrieval/vector/dbs/milvus.py | 19 +++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 374cbb8a2..613271596 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -641,6 +641,16 @@ def save_docs_to_vector_db( ) -> bool: log.info(f"save_docs_to_vector_db {docs} {collection_name}") + # Check if entries with the same hash (metadata.hash) already exist + if metadata and "hash" in metadata: + existing_docs = VECTOR_DB_CLIENT.query( + collection_name=collection_name, + filter={"hash": metadata["hash"]}, + ) + if existing_docs: + log.info(f"Document with hash {metadata['hash']} already exists") + return True + if split: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.config.CHUNK_SIZE, diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index a73eb92dc..02d253bf6 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -66,6 +66,27 @@ class ChromaClient: ) return None + def query( + self, collection_name: str, filter: dict, limit: int = 1 + ) -> Optional[SearchResult]: + # Query the items from the collection based on the filter. + collection = self.client.get_collection(name=collection_name) + if collection: + result = collection.query( + where=filter, + n_results=limit, + ) + + return SearchResult( + **{ + "ids": result["ids"], + "distances": result["distances"], + "documents": result["documents"], + "metadatas": result["metadatas"], + } + ) + return None + def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. collection = self.client.get_collection(name=collection_name) diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py index 4c8305ba8..61b3c23c8 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py @@ -135,6 +135,25 @@ class MilvusClient: return self._result_to_search_result(result) + def query( + self, collection_name: str, filter: dict, limit: int = 1 + ) -> Optional[SearchResult]: + # Query the items from the collection based on the filter. + filter_string = " && ".join( + [ + f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" + for key, value in filter.items() + ] + ) + + result = self.client.query( + collection_name=f"{self.collection_prefix}_{collection_name}", + filter=filter_string, + limit=limit, + ) + + return self._result_to_search_result([result]) + def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. result = self.client.query(