From bfbfdae1c5c8aa2f5b94d7f9738d70b061b4abe2 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Sun, 31 Mar 2024 22:02:40 +0100 Subject: [PATCH] feat: add backend functions for sharing chats --- backend/apps/web/models/chats.py | 41 ++++++++++++++++++ backend/apps/web/routers/chats.py | 71 +++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index c9d130044..5a49354ba 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -20,6 +20,7 @@ class Chat(Model): title = CharField() chat = TextField() # Save Chat JSON as Text timestamp = DateField() + share_id = CharField(null=True, unique=True) class Meta: database = DB @@ -31,6 +32,7 @@ class ChatModel(BaseModel): title: str chat: str timestamp: int # timestamp in epoch + share_id: Optional[str] = None #################### @@ -52,6 +54,7 @@ class ChatResponse(BaseModel): title: str chat: dict timestamp: int # timestamp in epoch + share_id: Optional[str] = None # id of the chat to be shared class ChatTitleIdResponse(BaseModel): @@ -95,6 +98,44 @@ class ChatTable: except: return None + def insert_shared_chat(self, chat_id: str) -> Optional[ChatModel]: + # Get the existing chat to share + chat = Chat.get(Chat.id == chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": "shared", + "title": chat.title, + "chat": chat.chat, + "timestamp": int(time.time()), + } + ) + shared_result = Chat.create(**shared_chat.model_dump()) + # Update the original chat with the share_id + result = ( + Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute() + ) + + return shared_chat if (shared_result and result) else None + + def update_chat_share_id_by_id( + self, od: str, share_id: Optional[str] + ) -> Optional[ChatModel]: + try: + query = Chat.update( + share_id=share_id, + ).where(Chat.id == id) + query.execute() + + chat = Chat.get(Chat.id == id) + return ChatModel(**model_to_dict(chat)) + except: + return None + def get_chat_lists_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 5f8c61b70..91b8a7343 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -189,6 +189,77 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ return result +############################ +# ShareChatById +############################ + + +@router.post("/{id}/share", response_model=Optional[ChatResponse]) +async def share_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + if chat.share_id: + shared_chat = Chats.get_chat_by_id_and_user_id(chat.share_id, "shared") + return ChatResponse( + **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} + ) + + shared_chat = Chats.insert_shared_chat(chat.id) + if not shared_chat: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + return ChatResponse( + **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# DeletedSharedChatById +############################ + + +@router.delete("/{id}/share", response_model=Optional[bool]) +async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + if chat: + if not chat.share_id: + return False + result = Chats.delete_chat_by_id_and_user_id(chat.share_id, "shared") + update_result = Chats.update_chat_share_id_by_id(chat.id, None) + + return result and update_result + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + + +############################ +# GetSharedChatById +############################ + + +@router.get("/share/{id}", response_model=Optional[ChatResponse]) +async def get_shared_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, "shared") + + if chat: + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + ) + + ############################ # GetChatTagsById ############################