From b291271df3e7b9d54934005de169d500e8759bb0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 3 Oct 2024 22:22:22 -0700 Subject: [PATCH] refac --- backend/open_webui/apps/retrieval/main.py | 14 ++-- backend/open_webui/apps/retrieval/utils.py | 2 +- .../open_webui/apps/webui/routers/files.py | 10 ++- .../apps/webui/routers/knowledge.py | 10 ++- backend/open_webui/constants.py | 1 + src/lib/apis/knowledge/index.ts | 70 +++++++++++++++++++ .../components/chat/Controls/Controls.svelte | 2 +- src/lib/components/chat/MessageInput.svelte | 30 ++------ .../chat/Messages/UserMessage.svelte | 2 +- src/lib/components/common/FileItem.svelte | 8 +-- .../components/common/FileItemModal.svelte | 22 +++--- .../workspace/Knowledge/Collection.svelte | 60 ++++++++-------- 12 files changed, 152 insertions(+), 79 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index 7b45ccff5..c9ba33211 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -731,7 +731,7 @@ def process_file( collection_name = form_data.collection_name if collection_name is None: - collection_name = file.id + collection_name = f"file-{file.id}" loader = Loader( engine=app.state.config.CONTENT_EXTRACTION_ENGINE, @@ -758,12 +758,11 @@ def process_file( log.debug(f"text_content: {text_content}") hash = calculate_sha256_string(text_content) - res = Files.update_file_data_by_id( + Files.update_file_data_by_id( file.id, {"content": text_content}, ) - print(res) - Files.update_file_hash_by_id(form_data.file_id, hash) + Files.update_file_hash_by_id(file.id, hash) try: result = save_docs_to_vector_db( @@ -778,6 +777,13 @@ def process_file( ) if result: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + ) + return { "status": True, "collection_name": collection_name, diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 12c30edbb..c671b03b4 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -319,7 +319,7 @@ def get_rag_context( for file in files: if file.get("context") == "full": context = { - "documents": [[file.get("file").get("content")]], + "documents": [[file.get("file").get("data", {}).get("content")]], "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], } else: diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index 1204dea9d..17c656be5 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -6,7 +6,8 @@ from pathlib import Path from typing import Optional from open_webui.apps.webui.models.files import FileForm, FileModel, Files -from open_webui.apps.webui.models.knowledge import Knowledges +from open_webui.apps.retrieval.main import process_file, ProcessFileForm + from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS @@ -61,6 +62,13 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): ), ) + try: + process_file(ProcessFileForm(file_id=id)) + file = Files.get_file_by_id(id=id) + except Exception as e: + log.exception(e) + log.error(f"Error processing file: {file.id}") + if file: return file else: diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 29316258d..88ca8c398 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -17,7 +17,6 @@ from open_webui.utils.utils import get_admin_user, get_verified_user from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT - router = APIRouter() ############################ @@ -132,7 +131,7 @@ class KnowledgeFileIdForm(BaseModel): @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) -async def add_file_to_knowledge_by_id( +def add_file_to_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_admin_user), @@ -144,6 +143,11 @@ async def add_file_to_knowledge_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.NOT_FOUND, ) + if not file.data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.FILE_NOT_PROCESSED, + ) if knowledge: data = knowledge.data or {} @@ -191,7 +195,7 @@ class KnowledgeFileIdForm(BaseModel): @router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse]) -async def remove_file_from_knowledge_by_id( +def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, user=Depends(get_admin_user), diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index e8c456b9e..0326ae96e 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -95,6 +95,7 @@ class ERROR_MESSAGES(str, Enum): ) DUPLICATE_CONTENT = "The content provided is a duplicate. Please ensure that the content is unique before proceeding." + FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding." class TASKS(str, Enum): diff --git a/src/lib/apis/knowledge/index.ts b/src/lib/apis/knowledge/index.ts index 511924a04..a0ba83a0e 100644 --- a/src/lib/apis/knowledge/index.ts +++ b/src/lib/apis/knowledge/index.ts @@ -138,6 +138,76 @@ export const updateKnowledgeById = async (token: string, id: string, form: Knowl return res; }; +export const addFileToKnowledgeById = async (token: string, id: string, fileId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/${id}/file/add`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + file_id: fileId + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const removeFileFromKnowledgeById = async (token: string, id: string, fileId: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/knowledge/${id}/file/remove`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + file_id: fileId + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteKnowledgeById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/chat/Controls/Controls.svelte b/src/lib/components/chat/Controls/Controls.svelte index f5807b9b8..01118e8a1 100644 --- a/src/lib/components/chat/Controls/Controls.svelte +++ b/src/lib/components/chat/Controls/Controls.svelte @@ -35,7 +35,7 @@ {#each chatFiles as file, fileIdx} item.status !== null); } @@ -143,27 +144,6 @@ files = files.filter((item) => item.status !== null); } }; - - const processFileItem = async (fileItem) => { - try { - const res = await processFile(localStorage.token, fileItem.id); - if (res) { - fileItem.status = 'processed'; - fileItem.collection_name = res.collection_name; - fileItem.file = { - ...fileItem.file, - content: res.content - }; - - files = files; - } - } catch (e) { - // We keep the file in the files list even if it fails to process - fileItem.status = 'processed'; - files = files; - } - }; - const inputFilesHandler = async (inputFiles) => { inputFiles.forEach((file) => { console.log(file, file.name.split('.').at(-1)); @@ -456,7 +436,7 @@ {:else} {:else} -{#if file} - +{#if item} + {/if}