enh: add/remove file from knowledge

This commit is contained in:
Timothy J. Baek 2024-10-03 06:46:20 -07:00
parent 1c01e52f7c
commit 78413d0c2e
2 changed files with 135 additions and 22 deletions

View File

@ -126,28 +126,10 @@ class KnowledgeTable:
) -> Optional[KnowledgeModel]: ) -> Optional[KnowledgeModel]:
try: try:
with get_db() as db: with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
**({"name": form_data.name} if form_data.name else {}), **form_data.model_dump(exclude_none=True),
**(
{"description": form_data.description}
if form_data.description
else {}
),
**(
{
"data": (
form_data.data
if overwrite
else {
**(self.get_knowledge_by_id(id=id)).data,
**form_data.data,
}
)
}
if form_data.data
else {}
),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )

View File

@ -15,6 +15,9 @@ from open_webui.apps.webui.models.files import Files, FileModel
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.utils import get_admin_user, get_verified_user from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
router = APIRouter() router = APIRouter()
############################ ############################
@ -96,7 +99,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
############################ ############################
@router.post("/{id}/update", response_model=Optional[KnowledgeResponse]) @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse])
async def update_knowledge_by_id( async def update_knowledge_by_id(
id: str, id: str,
form_data: KnowledgeUpdateForm, form_data: KnowledgeUpdateForm,
@ -105,7 +108,13 @@ async def update_knowledge_by_id(
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data)
if knowledge: if knowledge:
return knowledge file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_files_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=files,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -113,6 +122,128 @@ async def update_knowledge_by_id(
) )
############################
# AddFileToKnowledge
############################
# Request body shared by the knowledge "file/add" and "file/remove" routes.
# A `#` comment is used instead of a docstring so the generated schema for
# this model is not altered.
class KnowledgeFileIdForm(BaseModel):
    # Identifier of the file to look up via Files.get_file_by_id.
    file_id: str
@router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
async def add_file_to_knowledge_by_id(
    id: str,
    form_data: KnowledgeFileIdForm,
    user=Depends(get_admin_user),
):
    # Append form_data.file_id to the knowledge entry's data["file_ids"] list
    # and return the updated entry together with its file records.
    #
    # Every failure path answers with HTTP 400:
    #   - file or knowledge entry not found    -> ERROR_MESSAGES.NOT_FOUND
    #   - file id already listed on the entry  -> ERROR_MESSAGES.DEFAULT("file_id")
    #   - the update call returned nothing     -> ERROR_MESSAGES.DEFAULT("knowledge")
    #
    # `user` is unused in the body; it is declared so the get_admin_user
    # dependency is resolved before the handler runs.

    # Both lookups happen up front, in the same order as before.
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    file_record = Files.get_file_by_id(form_data.file_id)

    if not file_record:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    knowledge_data = knowledge.data or {}
    file_ids = knowledge_data.get("file_ids", [])

    # Adding the same file twice is rejected rather than silently ignored.
    if form_data.file_id in file_ids:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("file_id"),
        )

    file_ids.append(form_data.file_id)
    knowledge_data["file_ids"] = file_ids

    updated = Knowledges.update_knowledge_by_id(
        id=id, form_data=KnowledgeUpdateForm(data=knowledge_data)
    )
    if not updated:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("knowledge"),
        )

    files = Files.get_files_by_ids(file_ids)
    return KnowledgeFilesResponse(
        **updated.model_dump(),
        files=files,
    )
############################
# RemoveFileFromKnowledge
############################
# Request body for the "file/remove" route.
# NOTE(review): this re-declares KnowledgeFileIdForm with the same single
# field as the declaration in the AddFileToKnowledge section above; the later
# declaration rebinds the module-level name. Consider keeping only one.
class KnowledgeFileIdForm(BaseModel):
    # Identifier of the file to look up via Files.get_file_by_id.
    file_id: str
@router.post("/{id}/file/remove", response_model=Optional[KnowledgeFilesResponse])
async def remove_file_from_knowledge_by_id(
    id: str,
    form_data: KnowledgeFileIdForm,
    user=Depends(get_admin_user),
):
    # Remove form_data.file_id from the knowledge entry's data["file_ids"]
    # list, delete the matching entries from the vector store collection named
    # after the knowledge id, and return the updated entry together with its
    # remaining file records.
    #
    # Every failure path answers with HTTP 400:
    #   - file or knowledge entry not found   -> ERROR_MESSAGES.NOT_FOUND
    #   - file id not listed on the entry     -> ERROR_MESSAGES.DEFAULT("file_id")
    #   - the update call returned nothing    -> ERROR_MESSAGES.DEFAULT("knowledge")
    #
    # `user` is unused in the body; it is declared so the get_admin_user
    # dependency is resolved before the handler runs.
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    file = Files.get_file_by_id(form_data.file_id)

    if not file:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # The existence check must come before any use of `knowledge.id`.
    # Previously the vector-store call dereferenced `knowledge` first, so an
    # unknown knowledge id raised AttributeError instead of this NOT_FOUND
    # response.
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # As before, the vector entries are deleted even when the file id turns
    # out not to be listed on the knowledge entry.
    VECTOR_DB_CLIENT.delete(
        collection_name=knowledge.id, filter={"file_id": form_data.file_id}
    )

    data = knowledge.data or {}
    file_ids = data.get("file_ids", [])

    if form_data.file_id not in file_ids:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("file_id"),
        )

    file_ids.remove(form_data.file_id)
    data["file_ids"] = file_ids

    knowledge = Knowledges.update_knowledge_by_id(
        id=id, form_data=KnowledgeUpdateForm(data=data)
    )
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT("knowledge"),
        )

    files = Files.get_files_by_ids(file_ids)
    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=files,
    )
############################ ############################
# DeleteKnowledgeById # DeleteKnowledgeById
############################ ############################