refac: retain metadata for collection

This commit is contained in:
Timothy J. Baek 2024-10-05 09:58:46 -07:00
parent 4ca870bf6d
commit 1f9b5b6456
3 changed files with 73 additions and 31 deletions

View File

@ -733,15 +733,10 @@ def process_file(
file = Files.get_file_by_id(form_data.file_id) file = Files.get_file_by_id(form_data.file_id)
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name is None: if collection_name is None:
collection_name = f"file-{file.id}" collection_name = f"file-{file.id}"
loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
)
if form_data.content: if form_data.content:
docs = [ docs = [
Document( Document(
@ -755,21 +750,41 @@ def process_file(
] ]
text_content = form_data.content text_content = form_data.content
elif file.data.get("content", None): elif form_data.collection_name:
docs = [ result = VECTOR_DB_CLIENT.query(
Document( collection_name=f"file-{file.id}", filter={"file_id": file.id}
page_content=file.data.get("content", ""), )
metadata={
"name": file.meta.get("name", file.filename), if result:
"created_by": file.user_id, docs = [
**file.meta, Document(
}, page_content=result.documents[0][idx],
) metadata=result.metadatas[0][idx],
] )
for idx, id in enumerate(result.ids[0])
]
else:
docs = [
Document(
page_content=file.data.get("content", ""),
metadata={
"name": file.meta.get("name", file.filename),
"created_by": file.user_id,
**file.meta,
},
)
]
text_content = file.data.get("content", "") text_content = file.data.get("content", "")
else: else:
file_path = file.meta.get("path", None) file_path = file.meta.get("path", None)
if file_path: if file_path:
loader = Loader(
engine=app.state.config.CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL=app.state.config.TIKA_SERVER_URL,
PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES,
)
docs = loader.load( docs = loader.load(
file.filename, file.meta.get("content_type"), file_path file.filename, file.meta.get("content_type"), file_path
) )

View File

@ -70,10 +70,9 @@ class ChromaClient:
return None return None
def query( def query(
self, collection_name: str, filter: dict, limit: int = 2 self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]: ) -> Optional[GetResult]:
# Query the items from the collection based on the filter. # Query the items from the collection based on the filter.
try: try:
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
if collection: if collection:
@ -82,8 +81,6 @@ class ChromaClient:
limit=limit, limit=limit,
) )
print(result)
return GetResult( return GetResult(
**{ **{
"ids": [result["ids"]], "ids": [result["ids"]],

View File

@ -135,10 +135,8 @@ class MilvusClient:
return self._result_to_search_result(result) return self._result_to_search_result(result)
def query( def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
self, collection_name: str, filter: dict, limit: int = 1 # Construct the filter string for querying
) -> Optional[GetResult]:
# Query the items from the collection based on the filter.
filter_string = " && ".join( filter_string = " && ".join(
[ [
f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')" f"JSON_CONTAINS(metadata[{key}], '{[value] if isinstance(value, str) else value}')"
@ -146,13 +144,45 @@ class MilvusClient:
] ]
) )
result = self.client.query( max_limit = 16383 # The maximum number of records per request
collection_name=f"{self.collection_prefix}_{collection_name}", all_results = []
filter=filter_string,
limit=limit,
)
return self._result_to_get_result([result]) if limit is None:
limit = float("inf") # Use infinity as a placeholder for no limit
# Initialize offset and remaining to handle pagination
offset = 0
remaining = limit
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
current_fetch = min(
max_limit, remaining
) # Determine how many items to fetch in this iteration
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
output_fields=["*"],
limit=current_fetch,
offset=offset,
)
if not results:
break
all_results.extend(results)
results_count = len(results)
remaining -= (
results_count # Decrease remaining by the number of items fetched
)
offset += results_count
# Break the loop if the results returned are less than the requested fetch count
if results_count < current_fetch:
break
return self._result_to_get_result(all_results)
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.