diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py
index 7f9947d7a..0dff2bc02 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -1,5 +1,4 @@
-import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging
@@ -12,11 +11,11 @@ from open_webui.models.knowledge import (
 )
 from open_webui.models.files import Files, FileModel
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.routers.retrieval import process_file, ProcessFileForm
+from open_webui.routers.retrieval import process_file, ProcessFileForm, process_files_batch, BatchProcessFilesForm
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
@@ -514,3 +513,85 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
 
     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
+@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
+def add_files_to_knowledge_batch(
+    id: str,
+    form_data: list[KnowledgeFileIdForm],
+    user=Depends(get_verified_user),
+):
+    """
+    Add multiple files to a knowledge base
+    """
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    # Get files content
+    log.info(f"files/batch/add - {len(form_data)} files")
+    files: List[FileModel] = []
+    for form in form_data:
+        file = Files.get_file_by_id(form.file_id)
+        if not file:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"File {form.file_id} not found",
+            )
+        files.append(file)
+
+    # Process files
+    try:
+        result = process_files_batch(BatchProcessFilesForm(
+            files=files,
+            collection_name=id
+        ))
+    except Exception as e:
+        log.error(f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e)
+        )
+
+    # Add successful files to knowledge base
+    data = knowledge.data or {}
+    existing_file_ids = data.get("file_ids", [])
+
+    # Only add files that were successfully processed
+    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
+    for file_id in successful_file_ids:
+        if file_id not in existing_file_ids:
+            existing_file_ids.append(file_id)
+
+    data["file_ids"] = existing_file_ids
+    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+
+    # If there were any errors, include them in the response
+    if result.errors:
+        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Files.get_files_by_ids(existing_file_ids),
+            warnings={
+                "message": "Some files failed to process",
+                "errors": error_details
+            }
+        )
+
+    return KnowledgeFilesResponse(
+        **knowledge.model_dump(),
+        files=Files.get_files_by_ids(existing_file_ids)
+    )
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index e577f70f1..1898bfe49 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -7,7 +7,7 @@ import shutil
 import uuid
 from datetime import datetime
 from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import Iterator, List, Optional, Sequence, Union
 
 from fastapi import (
     Depends,
@@ -28,7 +28,7 @@ import tiktoken
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
 from langchain_core.documents import Document
 
-from open_webui.models.files import Files
+from open_webui.models.files import FileModel, Files
 from open_webui.models.knowledge import Knowledges
 from open_webui.storage.provider import Storage
@@ -1428,3 +1428,97 @@ if ENV == "dev":
 
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
+
+class BatchProcessFilesForm(BaseModel):
+    files: List[FileModel]
+    collection_name: str
+
+class BatchProcessFilesResult(BaseModel):
+    file_id: str
+    status: str
+    error: Optional[str] = None
+
+class BatchProcessFilesResponse(BaseModel):
+    results: List[BatchProcessFilesResult]
+    errors: List[BatchProcessFilesResult]
+
+@router.post("/process/files/batch")
+def process_files_batch(
+    form_data: BatchProcessFilesForm,
+    user=Depends(get_verified_user),
+) -> BatchProcessFilesResponse:
+    """
+    Process a batch of files and save them to the vector database.
+    """
+    results: List[BatchProcessFilesResult] = []
+    errors: List[BatchProcessFilesResult] = []
+    collection_name = form_data.collection_name
+
+    # Prepare all documents first
+    all_docs: List[Document] = []
+    for file in form_data.files:
+        try:
+            text_content = file.data.get("content", "")
+
+            docs: List[Document] = [
+                Document(
+                    page_content=text_content.replace("<br/>", "\n"),
+                    metadata={
+                        **file.meta,
+                        "name": file.filename,
+                        "created_by": file.user_id,
+                        "file_id": file.id,
+                        "source": file.filename,
+                    },
+                )
+            ]
+
+            hash = calculate_sha256_string(text_content)
+            Files.update_file_hash_by_id(file.id, hash)
+            Files.update_file_data_by_id(file.id, {"content": text_content})
+
+            all_docs.extend(docs)
+            results.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="prepared"
+            ))
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
+            errors.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="failed",
+                error=str(e)
+            ))
+
+    # Save all documents in one batch
+    if all_docs:
+        try:
+            save_docs_to_vector_db(
+                docs=all_docs,
+                collection_name=collection_name,
+                add=True
+            )
+
+            # Update all files with collection name
+            for result in results:
+                Files.update_file_metadata_by_id(
+                    result.file_id,
+                    {"collection_name": collection_name}
+                )
+                result.status = "completed"
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
+            for result in results:
+                result.status = "failed"
+                errors.append(BatchProcessFilesResult(
+                    file_id=result.file_id,
+                    status="failed",
+                    error=str(e)
+                ))
+
+    return BatchProcessFilesResponse(
+        results=results,
+        errors=errors
+    )