Merge pull request #7881 from gabriel-ecegi/dev

feat: Batch Processing for Large-Scale Document Import
This commit is contained in:
Timothy Jaeryang Baek 2024-12-17 13:54:00 -08:00 committed by GitHub
commit 9abae36264
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 181 additions and 6 deletions

View File

@ -1,5 +1,4 @@
import json
from typing import Optional, Union
from typing import List, Optional
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
@ -12,11 +11,11 @@ from open_webui.models.knowledge import (
)
from open_webui.models.files import Files, FileModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import process_file, ProcessFileForm
from open_webui.routers.retrieval import process_file, ProcessFileForm, process_files_batch, BatchProcessFilesForm
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.utils.access_control import has_access, has_permission
@ -514,3 +513,85 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
return knowledge
############################
# AddFilesToKnowledge
############################
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
id: str,
form_data: list[KnowledgeFileIdForm],
user=Depends(get_verified_user),
):
"""
Add multiple files to a knowledge base
"""
knowledge = Knowledges.get_knowledge_by_id(id=id)
if not knowledge:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if knowledge.user_id != user.id and user.role != "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
# Get files content
print(f"files/batch/add - {len(form_data)} files")
files: List[FileModel] = []
for form in form_data:
file = Files.get_file_by_id(form.file_id)
if not file:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File {form.file_id} not found",
)
files.append(file)
# Process files
try:
result = process_files_batch(BatchProcessFilesForm(
files=files,
collection_name=id
))
except Exception as e:
log.error(f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
# Add successful files to knowledge base
data = knowledge.data or {}
existing_file_ids = data.get("file_ids", [])
# Only add files that were successfully processed
successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
for file_id in successful_file_ids:
if file_id not in existing_file_ids:
existing_file_ids.append(file_id)
data["file_ids"] = existing_file_ids
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
# If there were any errors, include them in the response
if result.errors:
error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids),
warnings={
"message": "Some files failed to process",
"errors": error_details
}
)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Files.get_files_by_ids(existing_file_ids)
)

View File

@ -7,7 +7,7 @@ import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from typing import Iterator, List, Optional, Sequence, Union
from fastapi import (
Depends,
@ -28,7 +28,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_core.documents import Document
from open_webui.models.files import Files
from open_webui.models.files import FileModel, Files
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage
@ -1428,3 +1428,97 @@ if ENV == "dev":
@router.get("/ef/{text}")
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
class BatchProcessFilesForm(BaseModel):
files: List[FileModel]
collection_name: str
class BatchProcessFilesResult(BaseModel):
file_id: str
status: str
error: Optional[str] = None
class BatchProcessFilesResponse(BaseModel):
results: List[BatchProcessFilesResult]
errors: List[BatchProcessFilesResult]
@router.post("/process/files/batch")
def process_files_batch(
form_data: BatchProcessFilesForm,
user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
"""
Process a batch of files and save them to the vector database.
"""
results: List[BatchProcessFilesResult] = []
errors: List[BatchProcessFilesResult] = []
collection_name = form_data.collection_name
# Prepare all documents first
all_docs: List[Document] = []
for file in form_data.files:
try:
text_content = file.data.get("content", "")
docs: List[Document] = [
Document(
page_content=text_content.replace("<br/>", "\n"),
metadata={
**file.meta,
"name": file.filename,
"created_by": file.user_id,
"file_id": file.id,
"source": file.filename,
},
)
]
hash = calculate_sha256_string(text_content)
Files.update_file_hash_by_id(file.id, hash)
Files.update_file_data_by_id(file.id, {"content": text_content})
all_docs.extend(docs)
results.append(BatchProcessFilesResult(
file_id=file.id,
status="prepared"
))
except Exception as e:
log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
errors.append(BatchProcessFilesResult(
file_id=file.id,
status="failed",
error=str(e)
))
# Save all documents in one batch
if all_docs:
try:
save_docs_to_vector_db(
docs=all_docs,
collection_name=collection_name,
add=True
)
# Update all files with collection name
for result in results:
Files.update_file_metadata_by_id(
result.file_id,
{"collection_name": collection_name}
)
result.status = "completed"
except Exception as e:
log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
for result in results:
result.status = "failed"
errors.append(BatchProcessFilesResult(
file_id=result.file_id,
error=str(e)
))
return BatchProcessFilesResponse(
results=results,
errors=errors
)