From e20bb234096219c6931a36fdf9d45128562db746 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sun, 26 May 2024 02:00:31 -0700 Subject: [PATCH] feat: access archived chats as admin --- backend/apps/webui/models/chats.py | 48 +++++++--- backend/apps/webui/routers/chats.py | 88 +++++++++++-------- .../chat/Messages/ResponseMessage.svelte | 2 +- src/lib/utils/index.ts | 4 +- 4 files changed, 90 insertions(+), 52 deletions(-) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index a07ec5e30..d4597f16d 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -191,6 +191,20 @@ class ChatTable: except: return None + def archive_all_chats_by_user_id(self, user_id: str) -> bool: + try: + chats = self.get_chats_by_user_id(user_id) + for chat in chats: + query = Chat.update( + archived=True, + ).where(Chat.id == chat.id) + + query.execute() + + return True + except: + return False + def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: @@ -205,17 +219,31 @@ class ChatTable: ] def get_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 50, ) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == False) - .where(Chat.user_id == user_id) - .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] + if include_archived: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] + else: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.archived == False) + .where(Chat.user_id == user_id) + .order_by(Chat.updated_at.desc()) + # .limit(limit) + # .offset(skip) + ] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index e0c824a8e..02c9335e0 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) async def get_user_chat_list_by_user_id( user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(user_id, skip, limit) + return Chats.get_chat_list_by_user_id( + user_id, include_archived=True, skip=skip, limit=limit + ) ############################ -# GetArchivedChats +# CreateNewChat ############################ -@router.get("/archived", response_model=List[ChatTitleIdResponse]) -async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 -): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) - - -############################ -# GetSharedChatById -############################ - - -@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): - if user.role == "pending": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND - ) - - if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) - elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) - - if chat: +@router.post("/new", response_model=Optional[ChatResponse]) +async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): + try: + chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - else: + except Exception as e: + log.exception(e) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) @@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ############################ -# CreateNewChat +# GetArchivedChats ############################ -@router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): - try: - chat = Chats.insert_new_chat(user.id, form_data) - return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - except Exception as e: - log.exception(e) +@router.get("/archived", response_model=List[ChatTitleIdResponse]) +async def get_archived_session_user_chat_list( + user=Depends(get_current_user), skip: int = 0, limit: int = 50 +): + return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + + +############################ +# ArchiveAllChats +############################ + + +@router.get("/archive/all", response_model=List[ChatTitleIdResponse]) +async def archive_all_chats(user=Depends(get_current_user)): + return Chats.archive_all_chats_by_user_id(user.id) + + +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): + if user.role == "pending": raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role == "user": + chat = Chats.get_chat_by_share_id(share_id) + elif user.role == "admin": + chat = Chats.get_chat_by_id(share_id) + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) diff --git a/src/lib/components/chat/Messages/ResponseMessage.svelte b/src/lib/components/chat/Messages/ResponseMessage.svelte index 15a0501a3..cae705b64 100644 --- a/src/lib/components/chat/Messages/ResponseMessage.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage.svelte @@ -82,7 +82,7 @@ // Open all links in a new tab/window (from https://github.com/markedjs/marked/issues/655#issuecomment-383226346) const origLinkRenderer = renderer.link; - renderer.link = (href, title, text) => { + renderer.link = (href, title, text) => { const html = origLinkRenderer.call(renderer, href, title, text); return html.replace(/^ { }; export const revertSanitizedResponseContent = (content: string) => { - return content - .replaceAll('<', '<') - .replaceAll('>', '>'); + return content.replaceAll('<', '<').replaceAll('>', '>'); }; export const capitalizeFirstLetter = (string) => {