Merge pull request #7881 from gabriel-ecegi/dev

feat: Batch Processing for Large-Scale Document Import
This commit is contained in:
Timothy Jaeryang Baek 2024-12-17 13:54:00 -08:00 committed by GitHub
commit 9abae36264
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 181 additions and 6 deletions

View File

@ -1,5 +1,4 @@
import json from typing import List, Optional
from typing import Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging import logging
@ -12,11 +11,11 @@ from open_webui.models.knowledge import (
) )
from open_webui.models.files import Files, FileModel from open_webui.models.files import Files, FileModel
from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.routers.retrieval import process_file, ProcessFileForm from open_webui.routers.retrieval import process_file, ProcessFileForm, process_files_batch, BatchProcessFilesForm
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
@ -514,3 +513,85 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []}) knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
return knowledge return knowledge
############################
# AddFilesToKnowledge
############################
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
    id: str,
    form_data: list[KnowledgeFileIdForm],
    user=Depends(get_verified_user),
):
    """
    Add multiple files to a knowledge base in a single request.

    Every referenced file is looked up first; the whole request is rejected
    if any of them is missing. The files are then handed to
    ``process_files_batch`` using the knowledge id as the collection name,
    and only the files reported as ``"completed"`` are appended to the
    knowledge base's ``file_ids``.

    Args:
        id: Identifier of the knowledge base; also used as the collection name.
        form_data: One entry per file to add, each carrying a ``file_id``.
        user: The verified user making the request.

    Returns:
        A ``KnowledgeFilesResponse`` built from the updated knowledge record
        and its files. When some files failed to process, the response is
        additionally given a ``warnings`` entry listing the failures.

    Raises:
        HTTPException: 400 when the knowledge base does not exist, the user
            is neither its owner nor an admin, a file id cannot be found, or
            batch processing raises.
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if knowledge.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Resolve every file up front so a bad id fails the request before any
    # processing work is done.
    log.info(f"files/batch/add - {len(form_data)} files")
    files: list[FileModel] = []
    for form in form_data:
        file = Files.get_file_by_id(form.file_id)
        if not file:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File {form.file_id} not found",
            )
        files.append(file)

    # Process files. This is a direct Python call, not a request, so FastAPI
    # does not resolve the callee's ``Depends`` default; pass the user
    # explicitly instead of leaking the ``Depends`` placeholder into it.
    try:
        result = process_files_batch(
            BatchProcessFilesForm(files=files, collection_name=id),
            user=user,
        )
    except Exception as e:
        log.error(
            f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e

    # Add successful files to the knowledge base, preserving the existing
    # order and skipping ids that are already present.
    data = knowledge.data or {}
    existing_file_ids = list(data.get("file_ids") or [])
    known_file_ids = set(existing_file_ids)
    for processed in result.results:
        # Only files that were successfully processed are attached.
        if processed.status != "completed":
            continue
        if processed.file_id in known_file_ids:
            continue
        known_file_ids.add(processed.file_id)
        existing_file_ids.append(processed.file_id)

    data["file_ids"] = existing_file_ids
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    # If there were any errors, include them in the response.
    if result.errors:
        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
            files=Files.get_files_by_ids(existing_file_ids),
            warnings={
                "message": "Some files failed to process",
                "errors": error_details,
            },
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(existing_file_ids),
    )

View File

@ -7,7 +7,7 @@ import shutil
import uuid import uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Iterator, Optional, Sequence, Union from typing import Iterator, List, Optional, Sequence, Union
from fastapi import ( from fastapi import (
Depends, Depends,
@ -28,7 +28,7 @@ import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_core.documents import Document from langchain_core.documents import Document
from open_webui.models.files import Files from open_webui.models.files import FileModel, Files
from open_webui.models.knowledge import Knowledges from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage from open_webui.storage.provider import Storage
@ -1428,3 +1428,97 @@ if ENV == "dev":
@router.get("/ef/{text}") @router.get("/ef/{text}")
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"): async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return {"result": request.app.state.EMBEDDING_FUNCTION(text)} return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
class BatchProcessFilesForm(BaseModel):
    """Request body for batch file processing."""

    # File records to turn into documents, one document per file.
    files: List[FileModel]
    # Collection the resulting documents are saved under.
    collection_name: str
class BatchProcessFilesResult(BaseModel):
    """Outcome of processing a single file within a batch."""

    # Id of the file this outcome refers to.
    file_id: str
    # Processing state; required, so it must be supplied on construction.
    status: str
    # Error message for a failed file; None when not set.
    error: Optional[str] = None
class BatchProcessFilesResponse(BaseModel):
    """Aggregate result of a batch processing run."""

    # One entry per file that reached the document-preparation stage.
    results: List[BatchProcessFilesResult]
    # One entry per failure recorded during the run.
    errors: List[BatchProcessFilesResult]
@router.post("/process/files/batch")
def process_files_batch(
    form_data: BatchProcessFilesForm,
    user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
    """
    Process a batch of files and save them to the vector database.

    Each file's stored ``content`` is wrapped in a single ``Document``; the
    file's hash and data are refreshed, and all prepared documents are then
    written to the target collection in one call. A failure while preparing
    one file is recorded and does not stop the others. A failure while saving
    marks every prepared file as failed.

    Args:
        form_data: The files to process and the collection to save them to.
        user: The verified user (resolved by FastAPI when called as a route).

    Returns:
        A ``BatchProcessFilesResponse`` whose ``results`` holds one entry per
        prepared file (status ``"completed"`` or ``"failed"``) and whose
        ``errors`` holds one entry per recorded failure.
    """
    results: List[BatchProcessFilesResult] = []
    errors: List[BatchProcessFilesResult] = []
    collection_name = form_data.collection_name

    # Prepare all documents first
    all_docs: List[Document] = []
    for file in form_data.files:
        try:
            text_content = file.data.get("content", "")

            doc = Document(
                page_content=text_content.replace("<br/>", "\n"),
                metadata={
                    **file.meta,
                    "name": file.filename,
                    "created_by": file.user_id,
                    "file_id": file.id,
                    "source": file.filename,
                },
            )

            # Named content_hash to avoid shadowing the builtin ``hash``.
            content_hash = calculate_sha256_string(text_content)
            Files.update_file_hash_by_id(file.id, content_hash)
            Files.update_file_data_by_id(file.id, {"content": text_content})

            # Only queue the document once the file record updates succeeded,
            # so ``results`` and ``all_docs`` always describe the same files.
            all_docs.append(doc)
            results.append(
                BatchProcessFilesResult(file_id=file.id, status="prepared")
            )
        except Exception as e:
            # Best effort per file: record the failure and keep going.
            log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
            errors.append(
                BatchProcessFilesResult(
                    file_id=file.id,
                    status="failed",
                    error=str(e),
                )
            )

    # Save all documents in one batch
    if all_docs:
        try:
            save_docs_to_vector_db(
                docs=all_docs,
                collection_name=collection_name,
                add=True,
            )

            # Update all files with collection name
            for result in results:
                Files.update_file_metadata_by_id(
                    result.file_id,
                    {"collection_name": collection_name},
                )
                result.status = "completed"
        except Exception as e:
            log.error(
                f"process_files_batch: Error saving documents to vector DB: {str(e)}"
            )
            for result in results:
                result.status = "failed"
                # ``status`` is a required field; omitting it here would
                # raise a validation error from inside this handler and
                # turn a reportable failure into an unhandled exception.
                errors.append(
                    BatchProcessFilesResult(
                        file_id=result.file_id,
                        status="failed",
                        error=str(e),
                    )
                )

    return BatchProcessFilesResponse(
        results=results,
        errors=errors,
    )