From dbc352f01b1393dd06f5209552a59aa0affad0eb Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek" <timothyjrbeck@gmail.com>
Date: Mon, 15 Jul 2024 13:05:38 +0200
Subject: [PATCH] refac: documents file handling

---
 backend/apps/rag/main.py                      | 24 +++++++++++----
 backend/apps/webui/routers/files.py           |  2 ++
 .../chat/Messages/CitationsModal.svelte       |  6 ++--
 src/lib/components/workspace/Documents.svelte | 29 +++++++++++++++++--
 4 files changed, 49 insertions(+), 12 deletions(-)

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index c0f8a09ed..8631846ec 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -930,7 +930,9 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
         )
 
 
-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+def store_data_in_vector_db(
+    data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
 
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=app.state.config.CHUNK_SIZE,
@@ -942,7 +944,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
 
     if len(docs) > 0:
         log.info(f"store_data_in_vector_db {docs}")
-        return store_docs_in_vector_db(docs, collection_name, overwrite), None
+        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
 
@@ -956,14 +958,16 @@ def store_text_in_vector_db(
         add_start_index=True,
     )
     docs = text_splitter.create_documents([text], metadatas=[metadata])
-    return store_docs_in_vector_db(docs, collection_name, overwrite)
+    return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite)
 
 
-def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+def store_docs_in_vector_db(
+    docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
+    metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
 
     # ChromaDB does not like datetime formats
     # for meta-data so convert them to string.
@@ -1237,13 +1241,21 @@ def process_doc(
         data = loader.load()
 
         try:
-            result = store_data_in_vector_db(data, collection_name)
+            result = store_data_in_vector_db(
+                data,
+                collection_name,
+                {
+                    "file_id": form_data.file_id,
+                    "name": file.meta.get("name", file.filename),
+                },
+            )
 
             if result:
                 return {
                     "status": True,
                     "collection_name": collection_name,
                     "known_type": known_type,
+                    "filename": file.meta.get("name", file.filename),
                 }
         except Exception as e:
             raise HTTPException(
diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py
index fffe0743c..99fb923a1 100644
--- a/backend/apps/webui/routers/files.py
+++ b/backend/apps/webui/routers/files.py
@@ -58,6 +58,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 
         # replace filename with uuid
         id = str(uuid.uuid4())
+        name = filename
         filename = f"{id}_{filename}"
         file_path = f"{UPLOAD_DIR}/{filename}"
 
@@ -73,6 +74,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
                     "id": id,
                     "filename": filename,
                     "meta": {
+                        "name": name,
                         "content_type": file.content_type,
                         "size": len(contents),
                         "path": file_path,
diff --git a/src/lib/components/chat/Messages/CitationsModal.svelte b/src/lib/components/chat/Messages/CitationsModal.svelte
index b0913e336..47568431b 100644
--- a/src/lib/components/chat/Messages/CitationsModal.svelte
+++ b/src/lib/components/chat/Messages/CitationsModal.svelte
@@ -57,12 +57,12 @@
 						{#if document.source?.name}
 							<div class="text-sm dark:text-gray-400">
 								<a
-									href={document?.source?.url
-										? `${document?.source?.url}/content`
+									href={document?.metadata?.file_id
+										? `/api/v1/files/${document?.metadata?.file_id}/content`
 										: document.source.name}
 									target="_blank"
 								>
-									{document.source.name}
+									{document?.metadata?.name ?? document.source.name}
 								</a>
 							</div>
 						{:else}
diff --git a/src/lib/components/workspace/Documents.svelte b/src/lib/components/workspace/Documents.svelte
index acbf2db47..87873278c 100644
--- a/src/lib/components/workspace/Documents.svelte
+++ b/src/lib/components/workspace/Documents.svelte
@@ -8,14 +8,16 @@
 	import { createNewDoc, deleteDocByName, getDocs } from '$lib/apis/documents';
 
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
-	import { uploadDocToVectorDB } from '$lib/apis/rag';
-	import { transformFileName } from '$lib/utils';
+	import { processDocToVectorDB, uploadDocToVectorDB } from '$lib/apis/rag';
+	import { blobToFile, transformFileName } from '$lib/utils';
 
 	import Checkbox from '$lib/components/common/Checkbox.svelte';
 
 	import EditDocModal from '$lib/components/documents/EditDocModal.svelte';
 	import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte';
 	import AddDocModal from '$lib/components/documents/AddDocModal.svelte';
+	import { transcribeAudio } from '$lib/apis/audio';
+	import { uploadFile } from '$lib/apis/files';
 
 	const i18n = getContext('i18n');
 
@@ -50,7 +52,28 @@
 	};
 
 	const uploadDoc = async (file) => {
-		const res = await uploadDocToVectorDB(localStorage.token, '', file).catch((error) => {
+		console.log(file);
+		// Check if the file is an audio file and transcribe/convert it to text file
+		if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
+			const transcribeRes = await transcribeAudio(localStorage.token, file).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (transcribeRes) {
+				console.log(transcribeRes);
+				const blob = new Blob([transcribeRes.text], { type: 'text/plain' });
+				file = blobToFile(blob, `${file.name}.txt`);
+			}
+		}
+
+		// Upload the file to the server
+		const uploadedFile = await uploadFile(localStorage.token, file).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+
+		const res = await processDocToVectorDB(localStorage.token, uploadedFile.id).catch((error) => {
 			toast.error(error);
 			return null;
 		});