refac: tags

This commit is contained in:
Timothy J. Baek
2024-10-10 23:22:53 -07:00
parent 4adc57fd34
commit acb5dcf30a
10 changed files with 555 additions and 291 deletions

View File

@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
Chats,
ChatTitleIdResponse,
)
from open_webui.apps.webui.models.tags import (
ChatIdTagForm,
ChatIdTagModel,
TagModel,
Tags,
)
from open_webui.apps.webui.models.tags import TagModel, Tags
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS
@@ -126,6 +122,19 @@ async def search_user_chats(
]
############################
# GetPinnedChats
############################
@router.get("/pinned", response_model=list[ChatResponse])
async def get_user_pinned_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**chat.model_dump())
for chat in Chats.get_pinned_chats_by_user_id(user.id)
]
############################
# GetChats
############################
@@ -152,6 +161,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
]
############################
# GetAllTags
############################
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetAllChatsInDB
############################
@@ -220,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
############################
class TagNameForm(BaseModel):
class TagForm(BaseModel):
name: str
class TagFilterForm(TagForm):
skip: Optional[int] = 0
limit: Optional[int] = 50
@router.post("/tags", response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user)
form_data: TagFilterForm, user=Depends(get_verified_user)
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
chats = Chats.get_chat_list_by_user_id_and_tag_name(
user.id, form_data.name, form_data.skip, form_data.limit
)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=list[TagModel])
async def get_all_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatById
############################
@@ -324,12 +330,45 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
return result
############################
# GetPinnedStatusById
############################
@router.get("/{id}/pinned", response_model=Optional[bool])
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return chat.pinned
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# PinChatById
############################
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_pinned_by_id(id)
return chat
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
@@ -353,7 +392,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
@@ -423,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/tags", response_model=list[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
return tags
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -438,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.post("/{id}/tags", response_model=list[TagModel])
async def add_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
tags = chat.meta.get("tags", [])
tag_id = form_data.name.replace(" ", "_").lower()
if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(user.id, form_data)
if tag:
return tag
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
print(tags, tag_id)
if tag_id not in tags:
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
id, user.id, form_data.name
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -465,16 +506,20 @@ async def add_chat_tag_by_id(
############################
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
@router.delete("/{id}/tags", response_model=list[TagModel])
async def delete_tag_by_id_and_tag_name(
id: str, form_data: TagForm, user=Depends(get_verified_user)
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
if result:
return result
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -488,10 +533,17 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
if result:
return result
for tag in chat.meta.get("tags", []):
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
tags = chat.meta.get("tags", [])
return Tags.get_tags_by_ids(tags)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND