From 2cb0bf44390adeefd8c527962369b24f4b0a4184 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 26 Dec 2023 01:27:43 -0800 Subject: [PATCH] fix: chat return type to dict --- backend/apps/web/models/chats.py | 22 ++++++++++++++-- backend/apps/web/models/users.py | 3 --- backend/apps/web/routers/chats.py | 42 +++++++++++++++++++++++-------- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index cd915d9a3..ad9a28778 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -44,8 +44,12 @@ class ChatForm(BaseModel): chat: dict -class ChatUpdateForm(ChatForm): +class ChatResponse(BaseModel): id: str + user_id: str + title: str + chat: dict + timestamp: int # timestamp in epoch class ChatTitleIdResponse(BaseModel): @@ -77,7 +81,11 @@ class ChatTable: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - query = Chat.update(chat=json.dumps(chat)).where(Chat.id == id) + query = Chat.update( + chat=json.dumps(chat), + title=chat["title"] if "title" in chat else "New Chat", + timestamp=int(time.time()), + ).where(Chat.id == id) query.execute() chat = Chat.get(Chat.id == id) @@ -92,6 +100,7 @@ class ChatTable: ChatModel(**model_to_dict(chat)) for chat in Chat.select() .where(Chat.user_id == user_id) + .order_by(Chat.timestamp.desc()) .limit(limit) .offset(skip) ] @@ -109,5 +118,14 @@ class ChatTable: for chat in Chat.select().limit(limit).offset(skip) ] + def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: + try: + query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) + query.execute() # Remove the rows, return number of rows removed. + + return True + except: + return False + Chats = ChatTable(DB) diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py index 88414999e..782b7f47e 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/web/models/users.py @@ -27,9 +27,6 @@ class User(Model): class UserModel(BaseModel): - class Config: - orm_mode = True - id: str name: str email: str diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 64cee04e2..3191db8be 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -5,12 +5,13 @@ from typing import List, Union, Optional from fastapi import APIRouter from pydantic import BaseModel +import json from apps.web.models.users import Users from apps.web.models.chats import ( ChatModel, + ChatResponse, ChatForm, - ChatUpdateForm, ChatTitleIdResponse, Chats, ) @@ -46,13 +47,14 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch ############################ -@router.post("/new", response_model=Optional[ChatModel]) +@router.post("/new", response_model=Optional[ChatResponse]) async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): token = cred.credentials user = Users.get_user_by_token(token) if user: - return Chats.insert_new_chat(user.id, form_data) + chat = Chats.insert_new_chat(user.id, form_data) + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -65,13 +67,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): ############################ -@router.get("/{id}", response_model=Optional[ChatModel]) +@router.get("/{id}", response_model=Optional[ChatResponse]) async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): token = cred.credentials user = Users.get_user_by_token(token) if user: - return Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -84,17 +87,16 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): ############################ -@router.post("/{id}", response_model=Optional[ChatModel]) -async def update_chat_by_id( - id: str, form_data: ChatUpdateForm, cred=Depends(bearer_scheme) -): +@router.post("/{id}", response_model=Optional[ChatResponse]) +async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): token = cred.credentials user = Users.get_user_by_token(token) if user: chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - return Chats.update_chat_by_id(id, form_data.chat) + chat = Chats.update_chat_by_id(id, form_data.chat) + return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -105,3 +107,23 @@ async def update_chat_by_id( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.INVALID_TOKEN, ) + + +############################ +# DeleteChatById +############################ + + +@router.delete("/{id}", response_model=bool) +async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): + token = cred.credentials + user = Users.get_user_by_token(token) + + if user: + result = Chats.delete_chat_by_id_and_user_id(id, user.id) + return result + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.INVALID_TOKEN, + )