From ff8a55a861836aa27f00a53638de4865e26cef52 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 24 Mar 2024 00:41:41 -0700 Subject: [PATCH] refac: rag api --- backend/apps/rag/main.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 0ffc32a82..a8606b398 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -303,14 +303,14 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b def store_text_in_vector_db( - text, name, collection_name, overwrite: bool = False + text, metadata, collection_name, overwrite: bool = False ) -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP, add_start_index=True, ) - docs = text_splitter.create_documents([text], metadatas=[{"name": name}]) + docs = text_splitter.create_documents([text], metadatas=[metadata]) return store_docs_in_vector_db(docs, collection_name, overwrite) @@ -493,7 +493,11 @@ def store_text( if collection_name == None: collection_name = calculate_sha256_string(form_data.content) - result = store_text_in_vector_db(form_data.content, form_data.name, collection_name) + result = store_text_in_vector_db( + form_data.content, + metadata={"name": form_data.name, "created_by": user.id}, + collection_name=collection_name, + ) if result: return {"status": True, "collection_name": collection_name}