From 1f9b5b64563f32364d399ab22e641d040b61ddf3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 5 Oct 2024 09:58:46 -0700 Subject: [PATCH] refac: retain metadata for collection --- backend/open_webui/apps/retrieval/main.py | 49 +++++++++++------- .../apps/retrieval/vector/dbs/chroma.py | 5 +- .../apps/retrieval/vector/dbs/milvus.py | 50 +++++++++++++++---- 3 files changed, 73 insertions(+), 31 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index fe891a8ac..d5096a9b2 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -733,15 +733,10 @@ def process_file( file = Files.get_file_by_id(form_data.file_id) collection_name = form_data.collection_name + if collection_name is None: collection_name = f"file-{file.id}" - loader = Loader( - engine=app.state.config.CONTENT_EXTRACTION_ENGINE, - TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, - PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, - ) - if form_data.content: docs = [ Document( @@ -755,21 +750,41 @@ def process_file( ] text_content = form_data.content - elif file.data.get("content", None): - docs = [ - Document( - page_content=file.data.get("content", ""), - metadata={ - "name": file.meta.get("name", file.filename), - "created_by": file.user_id, - **file.meta, - }, - ) - ] + elif form_data.collection_name: + result = VECTOR_DB_CLIENT.query( + collection_name=f"file-{file.id}", filter={"file_id": file.id} + ) + + if result: + docs = [ + Document( + page_content=result.documents[0][idx], + metadata=result.metadatas[0][idx], + ) + for idx, id in enumerate(result.ids[0]) + ] + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + "name": file.meta.get("name", file.filename), + "created_by": file.user_id, + **file.meta, + }, + ) + ] + text_content = file.data.get("content", "") else: file_path = file.meta.get("path", None) if file_path: + loader = Loader( + engine=app.state.config.CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL, + PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, + ) + docs = loader.load( file.filename, file.meta.get("content_type"), file_path ) diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index 00b4af441..84f80b253 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -70,10 +70,9 @@ class ChromaClient: return None def query( - self, collection_name: str, filter: dict, limit: int = 2 + self, collection_name: str, filter: dict, limit: Optional[int] = None ) -> Optional[GetResult]: # Query the items from the collection based on the filter. - try: collection = self.client.get_collection(name=collection_name) if collection: @@ -82,8 +81,6 @@ class ChromaClient: limit=limit, ) - print(result) - return GetResult( **{ "ids": [result["ids"]], diff --git a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py index b5bbb24b3..bf33d35e9 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/milvus.py @@ -135,10 +135,8 @@ class MilvusClient: return self._result_to_search_result(result) - def query( - self, collection_name: str, filter: dict, limit: int = 1 - ) -> Optional[GetResult]: - # Query the items from the collection based on the filter. + def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + # Construct the filter string for querying filter_string = " && ".join( [ f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" @@ -146,13 +144,45 @@ class MilvusClient: ] ) - result = self.client.query( - collection_name=f"{self.collection_prefix}_{collection_name}", - filter=filter_string, - limit=limit, - ) + max_limit = 16383 # The maximum number of records per request + all_results = [] - return self._result_to_get_result([result]) + if limit is None: + limit = float("inf") # Use infinity as a placeholder for no limit + + # Initialize offset and remaining to handle pagination + offset = 0 + remaining = limit + + # Loop until there are no more items to fetch or the desired limit is reached + while remaining > 0: + current_fetch = min( + max_limit, remaining + ) # Determine how many items to fetch in this iteration + + results = self.client.query( + collection_name=f"{self.collection_prefix}_{collection_name}", + filter=filter_string, + output_fields=["*"], + limit=current_fetch, + offset=offset, + ) + + if not results: + break + + all_results.extend(results) + results_count = len(results) + remaining -= ( + results_count # Decrease remaining by the number of items fetched + ) + offset += results_count + + # Break the loop if the results returned are less than the requested fetch count + if results_count < current_fetch: + break + + return self._result_to_get_result(all_results) def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection.