enh: vector db hash collision check

This commit is contained in:
Timothy J. Baek 2024-10-03 06:53:21 -07:00
parent 78413d0c2e
commit 2fc07fd6a2
3 changed files with 50 additions and 0 deletions

View File

@ -641,6 +641,16 @@ def save_docs_to_vector_db(
) -> bool:
log.info(f"save_docs_to_vector_db {docs} {collection_name}")
# Check if entries with the same hash (metadata.hash) already exist
if metadata and "hash" in metadata:
existing_docs = VECTOR_DB_CLIENT.query(
collection_name=collection_name,
filter={"hash": metadata["hash"]},
)
if existing_docs:
log.info(f"Document with hash {metadata['hash']} already exists")
return True
if split:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.config.CHUNK_SIZE,

View File

@ -66,6 +66,27 @@ class ChromaClient:
)
return None
def query(
self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[SearchResult]:
# Query the items from the collection based on the filter.
collection = self.client.get_collection(name=collection_name)
if collection:
result = collection.query(
where=filter,
n_results=limit,
)
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection = self.client.get_collection(name=collection_name)

View File

@ -135,6 +135,25 @@ class MilvusClient:
return self._result_to_search_result(result)
def query(
self, collection_name: str, filter: dict, limit: int = 1
) -> Optional[SearchResult]:
# Query the items from the collection based on the filter.
filter_string = " && ".join(
[
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
for key, value in filter.items()
]
)
result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
limit=limit,
)
return self._result_to_search_result([result])
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
result = self.client.query(