enh: channel notification

This commit is contained in:
Timothy Jaeryang Baek 2024-12-25 00:53:25 -07:00
parent 0d7d6899b9
commit d701b69e05
5 changed files with 126 additions and 30 deletions

View File

@ -146,6 +146,13 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id( def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:

View File

@ -154,13 +154,25 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]: def get_users(
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[UserModel]:
with get_db() as db: with get_db() as db:
users = (
db.query(User) query = db.query(User).order_by(User.created_at.desc())
# .offset(skip).limit(limit)
.all() if skip:
) query = query.offset(skip)
if limit:
query = query.limit(limit)
users = query.all()
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
with get_db() as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
@ -179,11 +191,15 @@ class UsersTable:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return (
user.settings.get("ui", {}) if user.settings is None:
.get("notifications", {}) return None
.get("webhook_url", None) else:
) return (
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
except Exception: except Exception:
return None return None

View File

@ -3,11 +3,11 @@ import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.socket.main import sio from open_webui.socket.main import sio, SESSION_POOL
from open_webui.models.users import Users, UserNameResponse from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels, ChannelModel, ChannelForm from open_webui.models.channels import Channels, ChannelModel, ChannelForm
@ -16,11 +16,12 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS, WEBUI_URL
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -180,9 +181,39 @@ async def get_channel_messages(
############################ ############################
async def send_notification(channel, message, active_user_ids):
print(f"Sending notification to {channel=}, {message=}, {active_user_ids=}")
users = get_users_with_access("read", channel.access_control)
for user in users:
if user.id in active_user_ids:
continue
else:
if user.settings:
webhook_url = user.settings.ui.get("notifications", {}).get(
"webhook_url", None
)
if webhook_url:
post_webhook(
webhook_url,
f"#{channel.name} - {WEBUI_URL}/c/{channel.id}\n\n{message.content}",
{
"action": "channel",
"message": message.content,
"title": channel.name,
"url": f"{WEBUI_URL}/c/{channel.id}",
},
)
@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) @router.post("/{id}/messages/post", response_model=Optional[MessageModel])
async def post_new_message( async def post_new_message(
id: str, form_data: MessageForm, user=Depends(get_verified_user) id: str,
form_data: MessageForm,
background_tasks: BackgroundTasks,
user=Depends(get_verified_user),
): ):
channel = Channels.get_channel_by_id(id) channel = Channels.get_channel_by_id(id)
if not channel: if not channel:
@ -201,24 +232,44 @@ async def post_new_message(
message = Messages.insert_new_message(form_data, channel.id, user.id) message = Messages.insert_new_message(form_data, channel.id, user.id)
if message: if message:
event_data = {
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
}
await sio.emit( await sio.emit(
"channel-events", "channel-events",
{ event_data,
"channel_id": channel.id,
"message_id": message.id,
"data": {
"type": "message",
"data": {
**message.model_dump(),
"user": UserNameResponse(**user.model_dump()).model_dump(),
},
},
"user": UserNameResponse(**user.model_dump()).model_dump(),
"channel": channel.model_dump(),
},
to=f"channel:{channel.id}", to=f"channel:{channel.id}",
) )
active_session_ids = sio.manager.get_participants(
namespace="/",
room=f"channel:{channel.id}",
)
active_user_ids = list(
set(
[
SESSION_POOL.get(session_id[0])
for session_id in active_session_ids
]
)
)
background_tasks.add_task(
send_notification, channel, message, active_user_ids
)
return MessageModel(**message.model_dump()) return MessageModel(**message.model_dump())
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View File

@ -1,4 +1,5 @@
from typing import Optional, Union, List, Dict, Any from typing import Optional, Union, List, Dict, Any
from open_webui.models.users import Users, UserModel
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
import json import json
@ -93,3 +94,24 @@ def has_access(
return user_id in permitted_user_ids or any( return user_id in permitted_user_ids or any(
group_id in permitted_group_ids for group_id in user_group_ids group_id in permitted_group_ids for group_id in user_group_ids
) )
# Get all users with access to a resource
def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None
) -> List[UserModel]:
if access_control is None:
return Users.get_users()
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
permitted_user_ids = permission_access.get("user_ids", [])
user_ids_with_access = set(permitted_user_ids)
for group_id in permitted_group_ids:
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
if group_user_ids:
user_ids_with_access.update(group_user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access))

View File

@ -21,7 +21,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
elif "https://discord.com/api/webhooks" in url: elif "https://discord.com/api/webhooks" in url:
payload["content"] = ( payload["content"] = (
message message
if len(message) > 2000 if len(message) < 2000
else f"{message[: 2000 - 20]}... (truncated)" else f"{message[: 2000 - 20]}... (truncated)"
) )
# Microsoft Teams Webhooks # Microsoft Teams Webhooks