diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index c03abb233..6419ac8ee 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -245,6 +245,40 @@ class ChatTable: ) return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chat_title_id_list_by_user_id( + self, + user_id: str, + include_archived: bool = False, + skip: int = 0, + limit: int = 50, + ) -> List[ChatTitleIdResponse]: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # limit cols + .with_entities( + Chat.id, Chat.title, Chat.updated_at, Chat.created_at + ).all() + ) + # result has to be destrctured from sqlalchemy `row` and mapped to a dict since the `ChatModel`is not the returned dataclass. + return list( + map( + lambda row: ChatTitleIdResponse.model_validate( + { + "id": row[0], + "title": row[1], + "updated_at": row[2], + "created_at": row[3], + } + ), + all_chats, + ) + ) + def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index d3ccb9cce..80308a451 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -45,7 +45,7 @@ router = APIRouter() async def get_session_user_chat_list( user=Depends(get_verified_user), skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(user.id, skip, limit) + return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit) ############################