diff --git a/backend/open_webui/apps/webui/models/knowledge.py b/backend/open_webui/apps/webui/models/knowledge.py index 828bc3c2a..698cccda0 100644 --- a/backend/open_webui/apps/webui/models/knowledge.py +++ b/backend/open_webui/apps/webui/models/knowledge.py @@ -126,28 +126,10 @@ class KnowledgeTable: ) -> Optional[KnowledgeModel]: try: with get_db() as db: + knowledge = self.get_knowledge_by_id(id=id) db.query(Knowledge).filter_by(id=id).update( { - **({"name": form_data.name} if form_data.name else {}), - **( - {"description": form_data.description} - if form_data.description - else {} - ), - **( - { - "data": ( - form_data.data - if overwrite - else { - **(self.get_knowledge_by_id(id=id)).data, - **form_data.data, - } - ) - } - if form_data.data - else {} - ), + **form_data.model_dump(exclude_none=True), "updated_at": int(time.time()), } ) diff --git a/backend/open_webui/apps/webui/routers/knowledge.py b/backend/open_webui/apps/webui/routers/knowledge.py index 78faae6e1..29316258d 100644 --- a/backend/open_webui/apps/webui/routers/knowledge.py +++ b/backend/open_webui/apps/webui/routers/knowledge.py @@ -15,6 +15,9 @@ from open_webui.apps.webui.models.files import Files, FileModel from open_webui.constants import ERROR_MESSAGES from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT + + router = APIRouter() ############################ @@ -96,7 +99,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): ############################ -@router.post("/{id}/update", response_model=Optional[KnowledgeResponse]) +@router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( id: str, form_data: KnowledgeUpdateForm, @@ -105,7 +108,13 @@ async def update_knowledge_by_id( knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: - return knowledge + file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -113,6 +122,128 @@ async def update_knowledge_by_id( ) +############################ +# AddFileToKnowledge +############################ + + +class KnowledgeFileIdForm(BaseModel): + file_id: str + + +@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse]) +async def add_file_to_knowledge_by_id( + id: str, + form_data: KnowledgeFileIdForm, + user=Depends(get_admin_user), +): + knowledge = Knowledges.get_knowledge_by_id(id=id) + file = Files.get_file_by_id(form_data.file_id) + if not file: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if knowledge: + data = knowledge.data or {} + file_ids = data.get("file_ids", []) + + if form_data.file_id not in file_ids: + file_ids.append(form_data.file_id) + data["file_ids"] = file_ids + + knowledge = Knowledges.update_knowledge_by_id( + id=id, form_data=KnowledgeUpdateForm(data=data) + ) + + if knowledge: + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("knowledge"), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("file_id"), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# RemoveFileFromKnowledge +############################ + + +class KnowledgeFileIdForm(BaseModel): + file_id: str + + +@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse]) +async def remove_file_from_knowledge_by_id( + id: str, + form_data: KnowledgeFileIdForm, + user=Depends(get_admin_user), +): + knowledge = Knowledges.get_knowledge_by_id(id=id) + file = Files.get_file_by_id(form_data.file_id) + if not file: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + VECTOR_DB_CLIENT.delete( + collection_name=knowledge.id, filter={"file_id": form_data.file_id} + ) + + if knowledge: + data = knowledge.data or {} + file_ids = data.get("file_ids", []) + + if form_data.file_id in file_ids: + file_ids.remove(form_data.file_id) + data["file_ids"] = file_ids + + knowledge = Knowledges.update_knowledge_by_id( + id=id, form_data=KnowledgeUpdateForm(data=data) + ) + + if knowledge: + files = Files.get_files_by_ids(file_ids) + + return KnowledgeFilesResponse( + **knowledge.model_dump(), + files=files, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("knowledge"), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("file_id"), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # DeleteKnowledgeById ############################