refac/fix: reply to message recursion issue
This commit is contained in:
@@ -138,7 +138,11 @@ class MessageResponse(MessageReplyToResponse):
|
||||
|
||||
class MessageTable:
|
||||
def insert_new_message(
|
||||
self, form_data: MessageForm, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
self,
|
||||
form_data: MessageForm,
|
||||
channel_id: str,
|
||||
user_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
channel_member = Channels.join_channel(channel_id, user_id)
|
||||
@@ -170,20 +174,30 @@ class MessageTable:
|
||||
db.refresh(result)
|
||||
return MessageModel.model_validate(result) if result else None
|
||||
|
||||
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MessageResponse]:
|
||||
def get_message_by_id(
|
||||
self,
|
||||
id: str,
|
||||
include_thread_replies: Optional[bool] = True,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[MessageResponse]:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, id)
|
||||
if not message:
|
||||
return None
|
||||
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
||||
reactions = self.get_reactions_by_message_id(id, db=db)
|
||||
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
||||
|
||||
thread_replies = []
|
||||
if include_thread_replies:
|
||||
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
||||
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
return MessageResponse.model_validate(
|
||||
@@ -201,7 +215,9 @@ class MessageTable:
|
||||
}
|
||||
)
|
||||
|
||||
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
|
||||
def get_thread_replies_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[MessageReplyToResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
@@ -213,7 +229,9 @@ class MessageTable:
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
@@ -231,7 +249,9 @@ class MessageTable:
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||
def get_reply_user_ids_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
message.user_id
|
||||
@@ -239,7 +259,11 @@ class MessageTable:
|
||||
]
|
||||
|
||||
def get_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
self,
|
||||
channel_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[MessageReplyToResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
@@ -254,7 +278,9 @@ class MessageTable:
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
@@ -273,7 +299,12 @@ class MessageTable:
|
||||
return messages
|
||||
|
||||
def get_messages_by_parent_id(
|
||||
self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
self,
|
||||
channel_id: str,
|
||||
parent_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[MessageReplyToResponse]:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, parent_id)
|
||||
@@ -297,7 +328,9 @@ class MessageTable:
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(message.reply_to_id, db=db)
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
@@ -315,7 +348,9 @@ class MessageTable:
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
|
||||
def get_last_message_by_channel_id(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
message = (
|
||||
db.query(Message)
|
||||
@@ -326,7 +361,11 @@ class MessageTable:
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_pinned_messages_by_channel_id(
|
||||
self, channel_id: str, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
self,
|
||||
channel_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
@@ -359,7 +398,11 @@ class MessageTable:
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def update_is_pinned_by_id(
|
||||
self, id: str, is_pinned: bool, pinned_by: Optional[str] = None, db: Optional[Session] = None
|
||||
self,
|
||||
id: str,
|
||||
is_pinned: bool,
|
||||
pinned_by: Optional[str] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(Message, id)
|
||||
@@ -371,7 +414,11 @@ class MessageTable:
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_unread_message_count(
|
||||
self, channel_id: str, user_id: str, last_read_at: Optional[int] = None, db: Optional[Session] = None
|
||||
self,
|
||||
channel_id: str,
|
||||
user_id: str,
|
||||
last_read_at: Optional[int] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> int:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Message).filter(
|
||||
@@ -410,7 +457,9 @@ class MessageTable:
|
||||
db.refresh(result)
|
||||
return MessageReactionModel.model_validate(result) if result else None
|
||||
|
||||
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
|
||||
def get_reactions_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[Reactions]:
|
||||
with get_db_context(db) as db:
|
||||
# JOIN User so all user info is fetched in one query
|
||||
results = (
|
||||
|
||||
Reference in New Issue
Block a user