From c144122f608759c2b79472e1f6948a7c1600a3d1 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 1 Jan 2026 03:10:09 +0400 Subject: [PATCH] refac/fix: reply to message recursion issue --- backend/open_webui/models/messages.py | 81 +++++++++++++++++++++------ 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 8fe37e58e..c0e2007b2 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -138,7 +138,11 @@ class MessageResponse(MessageReplyToResponse): class MessageTable: def insert_new_message( - self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None + self, + form_data: MessageForm, + channel_id: str, + user_id: str, + db: Optional[Session] = None, ) -> Optional[MessageModel]: with get_db_context(db) as db: channel_member = Channels.join_channel(channel_id, user_id) @@ -170,20 +174,30 @@ class MessageTable: db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]: + def get_message_by_id( + self, + id: str, + include_thread_replies: Optional[bool] = True, + db: Optional[Session] = None, + ) -> Optional[MessageResponse]: with get_db_context(db) as db: message = db.get(Message, id) if not message: return None reply_to_message = ( - self.get_message_by_id(message.reply_to_id, db=db) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) reactions = self.get_reactions_by_message_id(id, db=db) - thread_replies = self.get_thread_replies_by_message_id(id, db=db) + + thread_replies = [] + if include_thread_replies: + thread_replies = self.get_thread_replies_by_message_id(id, db=db) user = Users.get_user_by_id(message.user_id, db=db) return MessageResponse.model_validate( @@ -201,7 +215,9 @@ class MessageTable: } ) - def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]: + def get_thread_replies_by_message_id( + self, id: str, db: Optional[Session] = None + ) -> list[MessageReplyToResponse]: with get_db_context(db) as db: all_messages = ( db.query(Message) @@ -213,7 +229,9 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id, db=db) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) @@ -231,7 +249,9 @@ class MessageTable: ) return messages - def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]: + def get_reply_user_ids_by_message_id( + self, id: str, db: Optional[Session] = None + ) -> list[str]: with get_db_context(db) as db: return [ message.user_id @@ -239,7 +259,11 @@ class MessageTable: ] def get_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None + self, + channel_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageReplyToResponse]: with get_db_context(db) as db: all_messages = ( @@ -254,7 +278,9 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id, db=db) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) @@ -273,7 +299,12 @@ class MessageTable: return messages def get_messages_by_parent_id( - self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None + self, + channel_id: str, + parent_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageReplyToResponse]: with get_db_context(db) as db: message = db.get(Message, parent_id) @@ -297,7 +328,9 @@ class MessageTable: messages = [] for message in all_messages: reply_to_message = ( - self.get_message_by_id(message.reply_to_id, db=db) + self.get_message_by_id( + message.reply_to_id, include_thread_replies=False, db=db + ) if message.reply_to_id else None ) @@ -315,7 +348,9 @@ class MessageTable: ) return messages - def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]: + def get_last_message_by_channel_id( + self, channel_id: str, db: Optional[Session] = None + ) -> Optional[MessageModel]: with get_db_context(db) as db: message = ( db.query(Message) @@ -326,7 +361,11 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def get_pinned_messages_by_channel_id( - self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None + self, + channel_id: str, + skip: int = 0, + limit: int = 50, + db: Optional[Session] = None, ) -> list[MessageModel]: with get_db_context(db) as db: all_messages = ( @@ -359,7 +398,11 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def update_is_pinned_by_id( - self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None + self, + id: str, + is_pinned: bool, + pinned_by: Optional[str] = None, + db: Optional[Session] = None, ) -> Optional[MessageModel]: with get_db_context(db) as db: message = db.get(Message, id) @@ -371,7 +414,11 @@ class MessageTable: return MessageModel.model_validate(message) if message else None def get_unread_message_count( - self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None + self, + channel_id: str, + user_id: str, + last_read_at: Optional[int] = None, + db: Optional[Session] = None, ) -> int: with get_db_context(db) as db: query = db.query(Message).filter( @@ -410,7 +457,9 @@ class MessageTable: db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]: + def get_reactions_by_message_id( + self, id: str, db: Optional[Session] = None + ) -> list[Reactions]: with get_db_context(db) as db: # JOIN User so all user info is fetched in one query results = (