From 4e468dc58cd71392d92539230743bde4256598bc Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 25 Jan 2024 00:24:49 -0800 Subject: [PATCH] refac --- backend/apps/rag/main.py | 125 ++++++++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 42 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index ffa73d000..6da870ea7 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -138,6 +138,87 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): ) +def get_loader(file, file_path): + file_ext = file.filename.split(".")[-1].lower() + known_type = True + + known_source_ext = [ + "go", + "py", + "java", + "sh", + "bat", + "ps1", + "cmd", + "js", + "ts", + "css", + "cpp", + "hpp", + "h", + "c", + "cs", + "sql", + "log", + "ini", + "pl", + "pm", + "r", + "dart", + "dockerfile", + "env", + "php", + "hs", + "hsc", + "lua", + "nginxconf", + "conf", + "m", + "mm", + "plsql", + "perl", + "rb", + "rs", + "db2", + "scala", + "bash", + "swift", + "vue", + "svelte", + ] + + if file_ext == "pdf": + loader = PyPDFLoader(file_path) + elif file_ext == "csv": + loader = CSVLoader(file_path) + elif file_ext == "rst": + loader = UnstructuredRSTLoader(file_path, mode="elements") + elif file_ext == "xml": + loader = UnstructuredXMLLoader(file_path) + elif file_ext == "md": + loader = UnstructuredMarkdownLoader(file_path) + elif file.content_type == "application/epub+zip": + loader = UnstructuredEPubLoader(file_path) + elif ( + file.content_type + == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + or file_ext in ["doc", "docx"] + ): + loader = Docx2txtLoader(file_path) + elif file.content_type in [ + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ] or file_ext in ["xls", "xlsx"]: + loader = UnstructuredExcelLoader(file_path) + elif file_ext in known_source_ext or file.content_type.find("text/") >= 0: + loader = TextLoader(file_path) + else: + loader = TextLoader(file_path) + known_type = False + + return loader, known_type + + @app.post("/doc") def store_doc( collection_name: Optional[str] = Form(None), @@ -147,24 +228,6 @@ def store_doc( # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" print(file.content_type) - - text_xml=["xml"] - octet_markdown=["md"] - known_source_ext=[ - "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", - "css", "cpp", "hpp","h", "c", "cs", "sql", "log", "ini", - "pl", "pm", "r", "dart", "dockerfile", "env", "php", "hs", - "hsc", "lua", "nginxconf", "conf", "m", "mm", "plsql", "perl", - "rb", "rs", "db2", "scala", "bash", "swift", "vue", "svelte" - ] - docx_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document" - known_doc_ext=["doc","docx"] - excel_types=["application/vnd.ms-excel", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"] - known_excel_ext=["xls", "xlsx"] - - file_ext=file.filename.split(".")[-1].lower() - known_type=True - try: filename = file.filename file_path = f"{UPLOAD_DIR}/{filename}" @@ -178,29 +241,7 @@ def store_doc( collection_name = calculate_sha256(f)[:63] f.close() - if file_ext=="pdf": - loader = PyPDFLoader(file_path) - elif (file.content_type ==docx_type or file_ext in known_doc_ext): - loader = Docx2txtLoader(file_path) - elif file_ext=="csv": - loader = CSVLoader(file_path) - elif (file.content_type in excel_types or file_ext in known_excel_ext): - loader = UnstructuredExcelLoader(file_path) - elif file_ext=="rst": - loader = UnstructuredRSTLoader(file_path, mode="elements") - elif file_ext in text_xml: - loader=UnstructuredXMLLoader(file_path) - elif file_ext in known_source_ext or file.content_type.find("text/")>=0: - loader = TextLoader(file_path) - elif file_ext in octet_markdown: - loader = UnstructuredMarkdownLoader(file_path) - elif file.content_type == "application/epub+zip": - loader = UnstructuredEPubLoader(file_path) - else: - loader = TextLoader(file_path) - known_type=False - - + loader, known_type = get_loader(file, file_path) data = loader.load() result = store_data_in_vector_db(data, collection_name) @@ -209,7 +250,7 @@ def store_doc( "status": True, "collection_name": collection_name, "filename": filename, - "known_type":known_type, + "known_type": known_type, } else: raise HTTPException(