Add batching

Gabriel Ecegi 2024-12-13 15:29:43 +01:00
parent bfdbb2df69
commit f2e2b59c18
2 changed files with 182 additions and 14 deletions

View File

@@ -2,16 +2,14 @@
import json
import logging
import mimetypes
import os
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional, Sequence, Union
from typing import List, Optional
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import tiktoken
@@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
query_doc_with_hybrid_search,
)
from open_webui.apps.webui.models.files import Files
from open_webui.apps.webui.models.files import FileModel, Files
from open_webui.config import (
BRAVE_SEARCH_API_KEY,
KAGI_SEARCH_API_KEY,
@@ -64,7 +62,6 @@ from open_webui.config import (
CONTENT_EXTRACTION_ENGINE,
CORS_ALLOW_ORIGIN,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
ENV,
@@ -86,7 +83,6 @@ from open_webui.config import (
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
DEFAULT_RAG_TEMPLATE,
RAG_TEMPLATE,
RAG_TOP_K,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -118,10 +114,7 @@ from open_webui.env import (
DOCKER,
)
from open_webui.utils.misc import (
calculate_sha256,
calculate_sha256_string,
extract_folders_after_data_docs,
sanitize_filename,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
@@ -1047,6 +1040,106 @@ def process_file(
)
class BatchProcessFilesForm(BaseModel):
    files: List[FileModel]
    collection_name: str


class BatchProcessFilesResult(BaseModel):
    file_id: str
    status: str
    error: Optional[str] = None


class BatchProcessFilesResponse(BaseModel):
    results: List[BatchProcessFilesResult]
    errors: List[BatchProcessFilesResult]


@app.post("/process/files/batch")
def process_files_batch(
    form_data: BatchProcessFilesForm,
    user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
    """
    Process a batch of files and save them to the vector database.
    """
    results: List[BatchProcessFilesResult] = []
    errors: List[BatchProcessFilesResult] = []
    collection_name = form_data.collection_name

    # Prepare all documents first
    all_docs: List[Document] = []
    for file_request in form_data.files:
        try:
            file = Files.get_file_by_id(file_request.id)
            if not file:
                log.error(f"process_files_batch: File {file_request.id} not found")
                raise ValueError(f"File {file_request.id} not found")

            text_content = file.data.get("content", "")

            docs: List[Document] = [
                Document(
                    page_content=text_content.replace("<br/>", "\n"),
                    metadata={
                        **file.meta,
                        "name": file_request.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file_request.filename,
                    },
                )
            ]

            hash = calculate_sha256_string(text_content)
            Files.update_file_hash_by_id(file.id, hash)
            Files.update_file_data_by_id(file.id, {"content": text_content})

            all_docs.extend(docs)
            results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
        except Exception as e:
            log.error(f"process_files_batch: Error processing file {file_request.id}: {str(e)}")
            errors.append(
                BatchProcessFilesResult(
                    file_id=file_request.id,
                    status="failed",
                    error=str(e),
                )
            )

    # Save all documents in one batch
    if all_docs:
        try:
            save_docs_to_vector_db(
                docs=all_docs,
                collection_name=collection_name,
                add=True,
            )

            # Update all files with collection name
            for result in results:
                Files.update_file_metadata_by_id(
                    result.file_id, {"collection_name": collection_name}
                )
                result.status = "completed"
        except Exception as e:
            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
            for result in results:
                result.status = "failed"
                errors.append(
                    BatchProcessFilesResult(
                        file_id=result.file_id,
                        status="failed",
                        error=str(e),
                    )
                )

    return BatchProcessFilesResponse(results=results, errors=errors)


class ProcessTextForm(BaseModel):
    name: str
    content: str
@@ -1509,3 +1602,4 @@ if ENV == "dev":
    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        return {"result": app.state.EMBEDDING_FUNCTION(text)}

View File

@@ -1,5 +1,4 @@
import json
from typing import Optional, Union
from typing import List, Optional
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Request
import logging
@@ -12,11 +11,11 @@ from open_webui.apps.webui.models.knowledge import (
)
from open_webui.apps.webui.models.files import Files, FileModel
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
from open_webui.apps.retrieval.main import process_file, ProcessFileForm
from open_webui.apps.retrieval.main import BatchProcessFilesForm, process_file, ProcessFileForm, process_files_batch
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.auth import get_verified_user
from open_webui.utils.access_control import has_access, has_permission
@@ -508,3 +507,78 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
    return knowledge
############################
# AddFilesToKnowledge
############################
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
    id: str,
    form_data: list[KnowledgeFileIdForm],
    user=Depends(get_verified_user),
):
    """
    Add multiple files to a knowledge base
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if knowledge.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Get files content
    print(f"files/batch/add - {len(form_data)} files")
    files: List[FileModel] = []
    for form in form_data:
        file = Files.get_file_by_id(form.file_id)
        if not file:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File {form.file_id} not found",
            )
        files.append(file)

    # Process files
    result = process_files_batch(
        BatchProcessFilesForm(files=files, collection_name=id)
    )

    # Add successful files to knowledge base
    data = knowledge.data or {}
    existing_file_ids = data.get("file_ids", [])

    # Only add files that were successfully processed
    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
    for file_id in successful_file_ids:
        if file_id not in existing_file_ids:
            existing_file_ids.append(file_id)

    data["file_ids"] = existing_file_ids
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    # If there were any errors, include them in the response
    if result.errors:
        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
            files=Files.get_files_by_ids(existing_file_ids),
            warnings={
                "message": "Some files failed to process",
                "errors": error_details,
            },
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(existing_file_ids),
    )
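For completeness, a sketch of how a client might call the new batch route. The /api/v1/knowledge prefix, base URL, and bearer-token auth are assumptions about the surrounding deployment, not part of this diff; the request body is a JSON list matching list[KnowledgeFileIdForm].

import requests

# Hypothetical values: server address, API token, knowledge base ID, and file IDs.
base_url = "http://localhost:8080/api/v1/knowledge"
token = "sk-example-token"
knowledge_id = "my-knowledge-id"
file_ids = ["file-1", "file-2"]

resp = requests.post(
    f"{base_url}/{knowledge_id}/files/batch/add",
    headers={"Authorization": f"Bearer {token}"},
    json=[{"file_id": fid} for fid in file_ids],
)
resp.raise_for_status()

payload = resp.json()
# KnowledgeFilesResponse: the updated file list, plus "warnings" only when some files failed.
print([f["id"] for f in payload["files"]])
if payload.get("warnings"):
    print(payload["warnings"]["message"], payload["warnings"]["errors"])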