From 2b78e613a490f86425899bc145adb6b17dc607e5 Mon Sep 17 00:00:00 2001 From: Aryan Kothari <87589047+thearyadev@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:08:15 -0400 Subject: [PATCH 1/3] add func to get chat list with more specific sql query --- backend/apps/webui/models/chats.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index c03abb233..b668d97d7 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -244,6 +244,32 @@ class ChatTable: .all() ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chat_title_id_list_by_user_id( + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 50, + ) -> List[ChatTitleIdResponse]: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # limit cols + .with_entities(Chat.id, Chat.title, Chat.updated_at, Chat.created_at) + .all() + ) + # result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. + return list(map(lambda row: ChatTitleIdResponse.model_validate({ + "id": row[0], + "title": row[1], + "updated_at": row[2], + "created_at": row[3] + }), all_chats)) + def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 From a0667dfd1b095900a1400f953a99b461cc82baa5 Mon Sep 17 00:00:00 2001 From: Aryan Kothari <87589047+thearyadev@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:09:22 -0400 Subject: [PATCH 2/3] change `/chats/` and `/chats/list` to utilize new function --- backend/apps/webui/routers/chats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index d3ccb9cce..80308a451 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -45,7 +45,7 @@ router = APIRouter() async def get_session_user_chat_list( user=Depends(get_verified_user), skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(user.id, skip, limit) + return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit) ############################ From f531a51e91e83aadb4c862215fbb40b158077c0f Mon Sep 17 00:00:00 2001 From: Aryan Kothari <87589047+thearyadev@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:45:47 -0400 Subject: [PATCH 3/3] chore: formatting --- backend/apps/webui/models/chats.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index b668d97d7..6419ac8ee 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -244,13 +244,14 @@ class ChatTable: .all() ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chat_title_id_list_by_user_id( self, user_id: str, include_archived: bool = False, skip: int = 0, limit: int = 50, - ) -> List[ChatTitleIdResponse]: + ) -> List[ChatTitleIdResponse]: with get_db() as db: query = db.query(Chat).filter_by(user_id=user_id) if not include_archived: @@ -259,17 +260,24 @@ class ChatTable: all_chats = ( query.order_by(Chat.updated_at.desc()) # limit cols - .with_entities(Chat.id, Chat.title, Chat.updated_at, Chat.created_at) - .all() + .with_entities( + Chat.id, Chat.title, Chat.updated_at, Chat.created_at + ).all() ) # result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. - return list(map(lambda row: ChatTitleIdResponse.model_validate({ - "id": row[0], - "title": row[1], - "updated_at": row[2], - "created_at": row[3] - }), all_chats)) - + return list( + map( + lambda row: ChatTitleIdResponse.model_validate( + { + "id": row[0], + "title": row[1], + "updated_at": row[2], + "created_at": row[3], + } + ), + all_chats, + ) + ) def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50