diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 5fc38b4a8..0ffc32a82 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -111,39 +111,6 @@ class StoreWebForm(CollectionNameForm): url: str -def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP - ) - docs = text_splitter.split_documents(data) - - texts = [doc.page_content for doc in docs] - metadatas = [doc.metadata for doc in docs] - - try: - if overwrite: - for collection in CHROMA_CLIENT.list_collections(): - if collection_name == collection.name: - print(f"deleting existing collection {collection_name}") - CHROMA_CLIENT.delete_collection(name=collection_name) - - collection = CHROMA_CLIENT.create_collection( - name=collection_name, - embedding_function=app.state.sentence_transformer_ef, - ) - - collection.add( - documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] - ) - return True - except Exception as e: - print(e) - if e.__class__.__name__ == "UniqueConstraintError": - return True - - return False - - @app.get("/") async def get_status(): return { @@ -325,6 +292,56 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ) +def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.CHUNK_SIZE, + chunk_overlap=app.state.CHUNK_OVERLAP, + add_start_index=True, + ) + docs = text_splitter.split_documents(data) + return store_docs_in_vector_db(docs, collection_name, overwrite) + + +def store_text_in_vector_db( + text, name, collection_name, overwrite: bool = False +) -> bool: + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=app.state.CHUNK_SIZE, + chunk_overlap=app.state.CHUNK_OVERLAP, + add_start_index=True, + ) + docs = text_splitter.create_documents([text], metadatas=[{"name": name}]) + return store_docs_in_vector_db(docs, collection_name, overwrite) + + +def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool: + texts = [doc.page_content for doc in docs] + metadatas = [doc.metadata for doc in docs] + + try: + if overwrite: + for collection in CHROMA_CLIENT.list_collections(): + if collection_name == collection.name: + print(f"deleting existing collection {collection_name}") + CHROMA_CLIENT.delete_collection(name=collection_name) + + collection = CHROMA_CLIENT.create_collection( + name=collection_name, + embedding_function=app.state.sentence_transformer_ef, + ) + + collection.add( + documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts] + ) + return True + except Exception as e: + print(e) + if e.__class__.__name__ == "UniqueConstraintError": + return True + + return False + + def get_loader(filename: str, file_content_type: str, file_path: str): file_ext = filename.split(".")[-1].lower() known_type = True @@ -460,6 +477,33 @@ def store_doc( ) +class TextRAGForm(BaseModel): + name: str + content: str + collection_name: Optional[str] = None + + +@app.post("/text") +def store_text( + form_data: TextRAGForm, + user=Depends(get_current_user), +): + + collection_name = form_data.collection_name + if collection_name == None: + collection_name = calculate_sha256_string(form_data.content) + + result = store_text_in_vector_db(form_data.content, form_data.name, collection_name) + + if result: + return {"status": True, "collection_name": collection_name} + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + @app.get("/scan") def scan_docs_dir(user=Depends(get_admin_user)): for path in Path(DOCS_DIR).rglob("./**/*"): diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py index a3537d4d3..c5e0e8a8b 100644 --- a/backend/apps/rag/utils.py +++ b/backend/apps/rag/utils.py @@ -137,6 +137,8 @@ def rag_messages(docs, messages, template, k, embedding_function): k=k, embedding_function=embedding_function, ) + elif doc["type"] == "text": + context = doc["content"] else: context = query_doc( collection_name=doc["collection_name"],