From 9e7b7a895e77149fac9c99338f5a7b51549c8ac3 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Tue, 18 Jun 2024 13:50:18 -0700
Subject: [PATCH] refac: file upload

---
Review notes (this section is ignored by git am):
 * rag/main.py: process_doc read `collection_name` before assigning it, so
   every request failed with UnboundLocalError (reported as HTTP 400). The
   name is now always derived from the file hash, and the file is opened
   with a context manager so it is closed even when hashing raises.
 * apis/files/index.ts: uploadFile no longer forces a JSON Content-Type on a
   FormData body; fetch must generate the multipart header and boundary.
 * apis/rag/index.ts: processDocToVectorDB now declares its JSON body with
   Content-Type: application/json.
 * All fixes are line-count neutral, so hunk headers and the diffstat are
   unchanged. Blob hashes on the `index` lines were not recomputed.

 backend/apps/rag/main.py                    |  54 +++++++++
 backend/apps/webui/routers/files.py         |  13 +-
 src/lib/apis/files/index.ts                 | 125 ++++++++++++++++++++
 src/lib/apis/rag/index.ts                   |  30 +++++
 src/lib/components/chat/MessageInput.svelte | 124 +++++++++----------
 5 files changed, 285 insertions(+), 61 deletions(-)
 create mode 100644 src/lib/apis/files/index.ts

diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 49146a215..3bd7303bd 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
     DocumentForm,
     DocumentResponse,
 )
+from apps.webui.models.files import (
+    Files,
+)
 
 from apps.rag.utils import (
     get_model_path,
@@ -1131,6 +1134,57 @@ def store_doc(
             )
 
 
+class ProcessDocForm(BaseModel):
+    file_id: str
+
+
+@app.post("/process/doc")
+def process_doc(
+    form_data: ProcessDocForm,
+    user=Depends(get_current_user),
+):
+    try:
+        file = Files.get_file_by_id(form_data.file_id)
+        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
+
+        # Derive the collection name from the file's content hash; `with`
+        # closes the handle even if hashing raises.
+        with open(file_path, "rb") as f:
+            collection_name = calculate_sha256(f)[:63]
+
+        loader, known_type = get_loader(
+            file.filename, file.meta.get("content_type"), file_path
+        )
+        data = loader.load()
+
+        try:
+            result = store_data_in_vector_db(data, collection_name)
+
+            if result:
+                return {
+                    "status": True,
+                    "collection_name": collection_name,
+                    "known_type": known_type,
+                }
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=e,
+            )
+    except Exception as e:
+        log.exception(e)
+        if "No pandoc was found" in str(e):
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+
+
 class TextRAGForm(BaseModel):
     name: str
     content: str
diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py
index 773386059..a231f7bb1 100644
--- a/backend/apps/webui/routers/files.py
+++ b/backend/apps/webui/routers/files.py
@@ -61,7 +61,18 @@ def upload_file(
         f.close()
 
         file = Files.insert_new_file(
-            user.id, FileForm(**{"id": id, "filename": filename})
+            user.id,
+            FileForm(
+                **{
+                    "id": id,
+                    "filename": filename,
+                    "meta": {
+                        "content_type": file.content_type,
+                        "size": len(contents),
+                        "path": file_path,
+                    },
+                }
+            ),
         )
 
         if file:
diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts
new file mode 100644
index 000000000..30222b5d5
--- /dev/null
+++ b/src/lib/apis/files/index.ts
@@ -0,0 +1,125 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const uploadFile = async (token: string, file: File) => {
+	const data = new FormData();
+	data.append('file', file);
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			// No Content-Type: fetch sets multipart/form-data with its boundary for FormData
+			authorization: `Bearer ${token}`
+		},
+		body: data
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFiles = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFileById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteFileById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}`, {
+		method: 'DELETE',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index ca68827a3..5639830c1 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -164,6 +164,36 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings
 	return res;
 };
 
+export const processDocToVectorDB = async (token: string, file_id: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/process/doc`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			// Body is JSON, so declare it; otherwise fetch labels a string body text/plain
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({ file_id })
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
 	const data = new FormData();
 	data.append('file', file);
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte
index 7f0beb9c7..f6bc595b2 100644
--- a/src/lib/components/chat/MessageInput.svelte
+++ b/src/lib/components/chat/MessageInput.svelte
@@ -15,10 +15,13 @@
 	import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
 
 	import {
+		processDocToVectorDB,
 		uploadDocToVectorDB,
 		uploadWebToVectorDB,
 		uploadYoutubeTranscriptionToVectorDB
 	} from '$lib/apis/rag';
+
+	import { uploadFile } from '$lib/apis/files';
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS, WEBUI_BASE_URL } from '$lib/constants';
 
 	import Prompts from './MessageInput/PromptCommands.svelte';
@@ -86,43 +89,70 @@
 		element.scrollTop = element.scrollHeight;
 	};
 
-	const uploadDoc = async (file) => {
+	const uploadFileHandler = async (file) => {
 		console.log(file);
-
-		const doc = {
-			type: 'doc',
-			name: file.name,
-			collection_name: '',
-			upload_status: false,
-			error: ''
-		};
-
-		try {
-			files = [...files, doc];
-
-			if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
-				const res = await transcribeAudio(localStorage.token, file).catch((error) => {
-					toast.error(error);
-					return null;
-				});
-
-				if (res) {
-					console.log(res);
-					const blob = new Blob([res.text], { type: 'text/plain' });
-					file = blobToFile(blob, `${file.name}.txt`);
-				}
-			}
-
-			const res = await uploadDocToVectorDB(localStorage.token, '', file);
+		// Check if the file is an audio file and transcribe/convert it to text file
+		if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
+			const res = await transcribeAudio(localStorage.token, file).catch((error) => {
+				toast.error(error);
+				return null;
+			});
 
 			if (res) {
-				doc.upload_status = true;
-				doc.collection_name = res.collection_name;
+				console.log(res);
+				const blob = new Blob([res.text], { type: 'text/plain' });
+				file = blobToFile(blob, `${file.name}.txt`);
+			}
+		}
+
+		// Upload the file to the server
+		const uploadedFile = await uploadFile(localStorage.token, file).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+
+		if (uploadedFile) {
+			const fileItem = {
+				type: 'file',
+				file: uploadedFile,
+				id: uploadedFile.id,
+				name: file.name,
+				collection_name: '',
+				status: 'uploaded',
+				error: ''
+			};
+			files = [...files, fileItem];
+
+			// TODO: Check if tools & functions have files support to skip this step to delegate file processing
+			// Default Upload to VectorDB
+			if (
+				SUPPORTED_FILE_TYPE.includes(file['type']) ||
+				SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
+			) {
+				processFileItem(fileItem);
+			} else {
+				toast.error(
+					$i18n.t(`Unknown File Type '{{file_type}}', but accepting and treating as plain text`, {
+						file_type: file['type']
+					})
+				);
+				processFileItem(fileItem);
+			}
+		}
+	};
+
+	const processFileItem = async (fileItem) => {
+		try {
+			const res = await processDocToVectorDB(localStorage.token, fileItem.id);
+
+			if (res) {
+				fileItem.status = 'processed';
+				fileItem.collection_name = res.collection_name;
 				files = files;
 			}
 		} catch (e) {
 			// Remove the failed doc from the files array
-			files = files.filter((f) => f.name !== file.name);
+			files = files.filter((f) => f.id !== fileItem.id);
 			toast.error(e);
 		}
 	};
@@ -230,19 +260,8 @@
 								];
 							};
 							reader.readAsDataURL(file);
-						} else if (
-							SUPPORTED_FILE_TYPE.includes(file['type']) ||
-							SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
-						) {
-							uploadDoc(file);
 						} else {
-							toast.error(
-								$i18n.t(
-									`Unknown File Type '{{file_type}}', but accepting and treating as plain text`,
-									{ file_type: file['type'] }
-								)
-							);
-							uploadDoc(file);
+							uploadFileHandler(file);
 						}
 					});
 				} else {
@@ -409,8 +428,6 @@
 							if (['image/gif', 'image/webp', 'image/jpeg', 'image/png'].includes(file['type'])) {
 								if (visionCapableModels.length === 0) {
 									toast.error($i18n.t('Selected model(s) do not support image inputs'));
-									inputFiles = null;
-									filesInputElement.value = '';
 									return;
 								}
 								let reader = new FileReader();
@@ -422,30 +439,17 @@
 											url: `${event.target.result}`
 										}
 									];
-									inputFiles = null;
-									filesInputElement.value = '';
 								};
 								reader.readAsDataURL(file);
-							} else if (
-								SUPPORTED_FILE_TYPE.includes(file['type']) ||
-								SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
-							) {
-								uploadDoc(file);
-								filesInputElement.value = '';
 							} else {
-								toast.error(
-									$i18n.t(
-										`Unknown File Type '{{file_type}}', but accepting and treating as plain text`,
-										{ file_type: file['type'] }
-									)
-								);
-								uploadDoc(file);
-								filesInputElement.value = '';
+								uploadFileHandler(file);
 							}
 						});
 					} else {
 						toast.error($i18n.t(`File not found.`));
 					}
+
+					filesInputElement.value = '';
 				}}
 			/>
 