mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: tags
This commit is contained in:
@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
|
||||
Chats,
|
||||
ChatTitleIdResponse,
|
||||
)
|
||||
from open_webui.apps.webui.models.tags import (
|
||||
ChatIdTagForm,
|
||||
ChatIdTagModel,
|
||||
TagModel,
|
||||
Tags,
|
||||
)
|
||||
from open_webui.apps.webui.models.tags import TagModel, Tags
|
||||
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
@@ -126,6 +122,19 @@ async def search_user_chats(
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedChats
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/pinned", response_model=list[ChatResponse])
|
||||
async def get_user_pinned_chats(user=Depends(get_verified_user)):
|
||||
return [
|
||||
ChatResponse(**chat.model_dump())
|
||||
for chat in Chats.get_pinned_chats_by_user_id(user.id)
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetChats
|
||||
############################
|
||||
@@ -152,6 +161,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/all/tags", response_model=list[TagModel])
|
||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetAllChatsInDB
|
||||
############################
|
||||
@@ -220,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
class TagNameForm(BaseModel):
|
||||
class TagForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class TagFilterForm(TagForm):
|
||||
skip: Optional[int] = 0
|
||||
limit: Optional[int] = 50
|
||||
|
||||
|
||||
@router.post("/tags", response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_tag_name(
|
||||
form_data: TagNameForm, user=Depends(get_verified_user)
|
||||
form_data: TagFilterForm, user=Depends(get_verified_user)
|
||||
):
|
||||
chat_ids = [
|
||||
chat_id_tag.chat_id
|
||||
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
|
||||
form_data.name, user.id
|
||||
)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
|
||||
|
||||
chats = Chats.get_chat_list_by_user_id_and_tag_name(
|
||||
user.id, form_data.name, form_data.skip, form_data.limit
|
||||
)
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
|
||||
return chats
|
||||
|
||||
|
||||
############################
|
||||
# GetAllTags
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/tags/all", response_model=list[TagModel])
|
||||
async def get_all_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetChatById
|
||||
############################
|
||||
@@ -324,12 +330,45 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
# GetPinnedStatusById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/pinned", response_model=Optional[bool])
|
||||
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
return chat.pinned
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# PinChatById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/pin", response_model=Optional[ChatResponse])
|
||||
async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_pinned_by_id(id)
|
||||
return chat
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# CloneChat
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
@router.post("/{id}/clone", response_model=Optional[ChatResponse])
|
||||
async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
@@ -353,7 +392,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
@router.post("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
@@ -423,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||
|
||||
@router.get("/{id}/tags", response_model=list[TagModel])
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if tags != None:
|
||||
return tags
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -438,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
|
||||
async def add_chat_tag_by_id(
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
@router.post("/{id}/tags", response_model=list[TagModel])
|
||||
async def add_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
tags = chat.meta.get("tags", [])
|
||||
tag_id = form_data.name.replace(" ", "_").lower()
|
||||
|
||||
if form_data.tag_name not in tags:
|
||||
tag = Tags.add_tag_to_chat(user.id, form_data)
|
||||
|
||||
if tag:
|
||||
return tag
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
print(tags, tag_id)
|
||||
if tag_id not in tags:
|
||||
Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
id, user.id, form_data.name
|
||||
)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
|
||||
@@ -465,16 +506,20 @@ async def add_chat_tag_by_id(
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/tags", response_model=Optional[bool])
|
||||
async def delete_chat_tag_by_id(
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
|
||||
@router.delete("/{id}/tags", response_model=list[TagModel])
|
||||
async def delete_tag_by_id_and_tag_name(
|
||||
id: str, form_data: TagForm, user=Depends(get_verified_user)
|
||||
):
|
||||
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||
form_data.tag_name, id, user.id
|
||||
)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
|
||||
|
||||
if result:
|
||||
return result
|
||||
if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
@@ -488,10 +533,17 @@ async def delete_chat_tag_by_id(
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
|
||||
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
Chats.delete_all_tags_by_id_and_user_id(id, user.id)
|
||||
|
||||
if result:
|
||||
return result
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
tags = chat.meta.get("tags", [])
|
||||
return Tags.get_tags_by_ids(tags)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
||||
Reference in New Issue
Block a user