diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index f5e480003..f72ed79b3 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list( return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + if user.role == "pending": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role == "user": + chat = Chats.get_chat_by_share_id(share_id) + elif user.role == "admin": + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + ############################ # GetChats ############################ @@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ) +############################ +# GetChatsByTags +############################ + + +class TagNameForm(BaseModel): + name: str + skip: Optional[int] = 0 + limit: Optional[int] = 50 + + +@router.post("/tags", response_model=List[ChatTitleIdResponse]) +async def get_user_chat_list_by_tag_name( + form_data: TagNameForm, user=Depends(get_current_user) +): + + print(form_data) + chat_ids = [ + chat_id_tag.chat_id + for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( + form_data.name, user.id + ) + ] + + chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) + + if len(chats) == 0: + Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + + return chats + + +############################ +# GetAllTags +############################ + + +@router.get("/tags/all", response_model=List[TagModel]) +async def get_all_tags(user=Depends(get_current_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetChatById ############################ @@ -274,79 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): ) -############################ -# GetSharedChatById -############################ - - -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) - elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) - - if chat: - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - -############################ -# GetAllTags -############################ - - -@router.get("/tags/all", response_model=List[TagModel]) -async def get_all_tags(user=Depends(get_current_user)): - try: - tags = Tags.get_tags_by_user_id(user.id) - return tags - except Exception as e: - log.exception(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() - ) - - -############################ -# GetChatsByTags -############################ - - -class TagNameForm(BaseModel): - name: str - - -@router.post("/tags", response_model=List[ChatTitleIdResponse]) -async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, - user=Depends(get_current_user), - skip: int = 0, - limit: int = 50, -): - chat_ids = [ - chat_id_tag.chat_id - for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - form_data.name, user.id - ) - ] - - chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit) - - if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) - - return chats - - ############################ # GetChatTagsById ############################