From 17c772831df5766a798f9e2769e0ed1e8bd73df2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 4 Oct 2024 00:23:14 -0700 Subject: [PATCH] refac --- backend/open_webui/apps/retrieval/main.py | 76 ++++++++++++++----- .../open_webui/apps/webui/routers/files.py | 63 +++++++++++---- 2 files changed, 107 insertions(+), 32 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index cd27b5530..f64cad3bb 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -700,17 +700,25 @@ def save_docs_to_vector_db( list(map(lambda x: x.replace("\n", " "), texts)) ) + items = [ + { + "id": str(uuid.uuid4()), + "text": text, + "vector": embeddings[idx], + "metadata": metadatas[idx], + } + for idx, text in enumerate(texts) + ] + + if file_id: + VECTOR_DB_CLIENT.insert( + collection_name=f"file-{file_id}", + items=items, + ) + VECTOR_DB_CLIENT.insert( collection_name=collection_name, - items=[ - { - "id": str(uuid.uuid4()), - "text": text, - "vector": embeddings[idx], - "metadata": metadatas[idx], - } - for idx, text in enumerate(texts) - ], + items=items, ) return True @@ -721,6 +729,7 @@ def save_docs_to_vector_db( class ProcessFileForm(BaseModel): file_id: str + content: Optional[str] = None collection_name: Optional[str] = None @@ -742,29 +751,58 @@ def process_file( PDF_EXTRACT_IMAGES=app.state.config.PDF_EXTRACT_IMAGES, ) - file_path = file.meta.get("path", None) - if file_path: - docs = loader.load(file.filename, file.meta.get("content_type"), file_path) - else: + if form_data.content: docs = [ Document( - page_content=file.data.get("content", ""), + page_content=form_data.content, metadata={ - "name": file.filename, + "name": file.meta.get("name", file.filename), "created_by": file.user_id, **file.meta, }, ) ] - text_content = " ".join([doc.page_content for doc in docs]) - log.debug(f"text_content: {text_content}") - hash = calculate_sha256_string(text_content) + text_content = form_data.content + elif file.data.get("content", None): + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + "name": file.meta.get("name", file.filename), + "created_by": file.user_id, + **file.meta, + }, + ) + ] + text_content = file.data.get("content", "") + else: + file_path = file.meta.get("path", None) + if file_path: + docs = loader.load( + file.filename, file.meta.get("content_type"), file_path + ) + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + "name": file.filename, + "created_by": file.user_id, + **file.meta, + }, + ) + ] + text_content = " ".join([doc.page_content for doc in docs]) + + log.debug(f"text_content: {text_content}") Files.update_file_data_by_id( file.id, {"content": text_content}, ) + + hash = calculate_sha256_string(text_content) Files.update_file_hash_by_id(file.id, hash) try: @@ -772,7 +810,7 @@ def process_file( docs=docs, collection_name=collection_name, metadata={ - "file_id": form_data.file_id, + "file_id": file.id, "name": file.meta.get("name", file.filename), "hash": hash, }, diff --git a/backend/open_webui/apps/webui/routers/files.py b/backend/open_webui/apps/webui/routers/files.py index 17c656be5..4d688b1ba 100644 --- a/backend/open_webui/apps/webui/routers/files.py +++ b/backend/open_webui/apps/webui/routers/files.py @@ -4,6 +4,7 @@ import shutil import uuid from pathlib import Path from typing import Optional +from pydantic import BaseModel from open_webui.apps.webui.models.files import FileForm, FileModel, Files from open_webui.apps.retrieval.main import process_file, ProcessFileForm @@ -154,6 +155,55 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): ) +############################ +# Get File Data Content By Id +############################ + + +@router.get("/{id}/data/content") +async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) + + if file and (file.user_id == user.id or user.role == "admin"): + return {"content": file.data.get("content", "")} + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# Update File Data Content By Id +############################ + + +class ContentForm(BaseModel): + content: str + + +@router.post("/{id}/data/content/update") +async def update_file_data_content_by_id( + id: str, form_data: ContentForm, user=Depends(get_verified_user) +): + file = Files.get_file_by_id(id) + + if file and (file.user_id == user.id or user.role == "admin"): + try: + process_file(ProcessFileForm(file_id=id, content=form_data.content)) + file = Files.get_file_by_id(id=id) + except Exception as e: + log.exception(e) + log.error(f"Error processing file: {file.id}") + + return {"content": file.data.get("content", "")} + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # Get File Content By Id ############################ @@ -182,19 +232,6 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ) -@router.get("/{id}/content/text") -async def get_file_text_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) - - if file and (file.user_id == user.id or user.role == "admin"): - return {"text": file.data.get("content")} - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - @router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): file = Files.get_file_by_id(id)