refac/fix: reply to message recursion issue

This commit is contained in:
Timothy Jaeryang Baek
2026-01-01 03:10:09 +04:00
parent b67796465e
commit c144122f60

View File

@@ -138,7 +138,11 @@ class MessageResponse(MessageReplyToResponse):
class MessageTable:
def insert_new_message(
self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None
self,
form_data: MessageForm,
channel_id: str,
user_id: str,
db: Optional[Session] = None,
) -> Optional[MessageModel]:
with get_db_context(db) as db:
channel_member = Channels.join_channel(channel_id, user_id)
@@ -170,20 +174,30 @@ class MessageTable:
db.refresh(result)
return MessageModel.model_validate(result) if result else None
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]:
def get_message_by_id(
self,
id: str,
include_thread_replies: Optional[bool] = True,
db: Optional[Session] = None,
) -> Optional[MessageResponse]:
with get_db_context(db) as db:
message = db.get(Message, id)
if not message:
return None
reply_to_message = (
self.get_message_by_id(message.reply_to_id, db=db)
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id
else None
)
reactions = self.get_reactions_by_message_id(id, db=db)
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
thread_replies = []
if include_thread_replies:
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
user = Users.get_user_by_id(message.user_id, db=db)
return MessageResponse.model_validate(
@@ -201,7 +215,9 @@ class MessageTable:
}
)
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
def get_thread_replies_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[MessageReplyToResponse]:
with get_db_context(db) as db:
all_messages = (
db.query(Message)
@@ -213,7 +229,9 @@ class MessageTable:
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id, db=db)
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id
else None
)
@@ -231,7 +249,9 @@ class MessageTable:
)
return messages
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
def get_reply_user_ids_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[str]:
with get_db_context(db) as db:
return [
message.user_id
@@ -239,7 +259,11 @@ class MessageTable:
]
def get_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
self,
channel_id: str,
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> list[MessageReplyToResponse]:
with get_db_context(db) as db:
all_messages = (
@@ -254,7 +278,9 @@ class MessageTable:
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id, db=db)
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id
else None
)
@@ -273,7 +299,12 @@ class MessageTable:
return messages
def get_messages_by_parent_id(
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
self,
channel_id: str,
parent_id: str,
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> list[MessageReplyToResponse]:
with get_db_context(db) as db:
message = db.get(Message, parent_id)
@@ -297,7 +328,9 @@ class MessageTable:
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(message.reply_to_id, db=db)
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id
else None
)
@@ -315,7 +348,9 @@ class MessageTable:
)
return messages
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
def get_last_message_by_channel_id(
self, channel_id: str, db: Optional[Session] = None
) -> Optional[MessageModel]:
with get_db_context(db) as db:
message = (
db.query(Message)
@@ -326,7 +361,11 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id(
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
self,
channel_id: str,
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> list[MessageModel]:
with get_db_context(db) as db:
all_messages = (
@@ -359,7 +398,11 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None
def update_is_pinned_by_id(
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None
self,
id: str,
is_pinned: bool,
pinned_by: Optional[str] = None,
db: Optional[Session] = None,
) -> Optional[MessageModel]:
with get_db_context(db) as db:
message = db.get(Message, id)
@@ -371,7 +414,11 @@ class MessageTable:
return MessageModel.model_validate(message) if message else None
def get_unread_message_count(
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None
self,
channel_id: str,
user_id: str,
last_read_at: Optional[int] = None,
db: Optional[Session] = None,
) -> int:
with get_db_context(db) as db:
query = db.query(Message).filter(
@@ -410,7 +457,9 @@ class MessageTable:
db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
def get_reactions_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[Reactions]:
with get_db_context(db) as db:
# JOIN User so all user info is fetched in one query
results = (