From f2e2b59c181a669a113dcf7f646aafc13defbc44 Mon Sep 17 00:00:00 2001 From: Gabriel Ecegi Date: Fri, 13 Dec 2024 15:29:43 +0100 Subject: [PATCH] Add batching --- backend/open_webui/apps/retrieval/main.py | 114 ++++++++++++++++-- .../apps/webui/routers/knowledge.py | 82 ++++++++++++- 2 files changed, 182 insertions(+), 14 deletions(-) diff --git a/backend/open_webui/apps/retrieval/main.py b/backend/open_webui/apps/retrieval/main.py index cfbc5beee..86ea6bf41 100644 --- a/backend/open_webui/apps/retrieval/main.py +++ b/backend/open_webui/apps/retrieval/main.py @@ -2,16 +2,14 @@ import json import logging -import mimetypes import os import shutil import uuid from datetime import datetime -from pathlib import Path -from typing import Iterator, Optional, Sequence, Union +from typing import List, Optional -from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status +from fastapi import Depends, FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import tiktoken @@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import ( query_doc_with_hybrid_search, ) -from open_webui.apps.webui.models.files import Files +from open_webui.apps.webui.models.files import FileModel, Files from open_webui.config import ( BRAVE_SEARCH_API_KEY, KAGI_SEARCH_API_KEY, @@ -64,7 +62,6 @@ from open_webui.config import ( CONTENT_EXTRACTION_ENGINE, CORS_ALLOW_ORIGIN, ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, ENABLE_RAG_WEB_SEARCH, ENV, @@ -86,7 +83,6 @@ from open_webui.config import ( RAG_RERANKING_MODEL, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - DEFAULT_RAG_TEMPLATE, RAG_TEMPLATE, RAG_TOP_K, RAG_WEB_SEARCH_CONCURRENT_REQUESTS, @@ -118,10 +114,7 @@ from open_webui.env import ( DOCKER, ) from open_webui.utils.misc import ( - calculate_sha256, calculate_sha256_string, - extract_folders_after_data_docs, - sanitize_filename, ) from open_webui.utils.auth import get_admin_user, get_verified_user @@ -1047,6 +1040,106 @@ def process_file( ) +class BatchProcessFilesForm(BaseModel): + files: List[FileModel] + collection_name: str + +class BatchProcessFilesResult(BaseModel): + file_id: str + status: str + error: Optional[str] = None + +class BatchProcessFilesResponse(BaseModel): + results: List[BatchProcessFilesResult] + errors: List[BatchProcessFilesResult] + +@app.post("/process/files/batch") +def process_files_batch( + form_data: BatchProcessFilesForm, + user=Depends(get_verified_user), +) -> BatchProcessFilesResponse: + """ + Process a batch of files and save them to the vector database. + """ + results: List[BatchProcessFilesResult] = [] + errors: List[BatchProcessFilesResult] = [] + collection_name = form_data.collection_name + + + # Prepare all documents first + all_docs: List[Document] = [] + for file_request in form_data.files: + try: + file = Files.get_file_by_id(file_request.file_id) + if not file: + log.error(f"process_files_batch: File {file_request.file_id} not found") + raise ValueError(f"File {file_request.file_id} not found") + + text_content = file_request.content + + docs: List[Document] = [ + Document( + page_content=text_content.replace("
", "\n"), + metadata={ + **file.meta, + "name": file_request.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file_request.filename, + }, + ) + ] + + hash = calculate_sha256_string(text_content) + Files.update_file_hash_by_id(file.id, hash) + Files.update_file_data_by_id(file.id, {"content": text_content}) + + all_docs.extend(docs) + results.append(BatchProcessFilesResult( + file_id=file.id, + status="prepared" + )) + + except Exception as e: + log.error(f"process_files_batch: Error processing file {file_request.file_id}: {str(e)}") + errors.append(BatchProcessFilesResult( + file_id=file_request.file_id, + status="failed", + error=str(e) + )) + + # Save all documents in one batch + if all_docs: + try: + save_docs_to_vector_db( + docs=all_docs, + collection_name=collection_name, + add=True + ) + + # Update all files with collection name + for result in results: + Files.update_file_metadata_by_id( + result.file_id, + {"collection_name": collection_name} + ) + result.status = "completed" + + except Exception as e: + log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}") + for result in results: + result.status = "failed" + errors.append(BatchProcessFilesResult( + file_id=result.file_id, + error=str(e) + )) + + return BatchProcessFilesResponse( + results=results, + errors=errors + ) + + class ProcessTextForm(BaseModel): name: str content: str @@ -1509,3 +1602,4 @@ if ENV == "dev": @app.get("/ef/{text}") async def get_embeddings_text(text: str): return {"result": app.state.EMBEDDING_FUNCTION(text)} + diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index d572e83b7..ccc2251d1 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -1,5 +1,4 @@ -import json -from typing import Optional, Union +from typing import List, Optional from pydantic import BaseModel from fastapi import APIRouter, Depends, HTTPException, status, Request import logging @@ -12,11 +11,11 @@ from open_webui.apps.webui.models.knowledge import ( ) from open_webui.apps.webui.models.files import Files, FileModel from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT -from open_webui.apps.retrieval.main import process_file, ProcessFileForm +from open_webui.apps.retrieval.main import BatchProcessFilesForm, process_file, ProcessFileForm, process_files_batch from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.auth import get_verified_user from open_webui.utils.access_control import has_access, has_permission @@ -508,3 +507,78 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []}) return knowledge + + +############################ +# AddFilesToKnowledge +############################ + +@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse]) +def add_files_to_knowledge_batch( + id: str, + form_data: list[KnowledgeFileIdForm], + user=Depends(get_verified_user), +): + """ + Add multiple files to a knowledge base + """ + knowledge = Knowledges.get_knowledge_by_id(id=id) + if not knowledge: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + # Get files content + print(f"files/batch/add - {len(form_data)} files") + files: List[FileModel] = [] + for form in form_data: + file = Files.get_file_by_id(form.file_id) + if not file: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"File {form.file_id} not found", + ) + files.append(file) + + # Process files + result = process_files_batch(BatchProcessFilesForm( + files=files, + collection_name=id + )) + + # Add successful files to knowledge base + data = knowledge.data or {} + existing_file_ids = data.get("file_ids", []) + + # Only add files that were successfully processed + successful_file_ids = [r.file_id for r in result.results if r.status == "completed"] + for file_id in successful_file_ids: + if file_id not in existing_file_ids: + existing_file_ids.append(file_id) + + data["file_ids"] = existing_file_ids + knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data) + + # If there were any errors, include them in the response + if result.errors: + error_details = [f"{err.file_id}: {err.error}" for err in result.errors] + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=Files.get_files_by_ids(existing_file_ids), + warnings={ + "message": "Some files failed to process", + "errors": error_details + } + ) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=Files.get_files_by_ids(existing_file_ids) + )