diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 8f0728411..94f4cfae8 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -146,6 +146,13 @@ class GroupTable: except Exception: return None + def get_group_user_ids_by_id(self, id: str) -> Optional[str]: + group = self.get_group_by_id(id) + if group: + return group.user_ids + else: + return None + def update_group_by_id( self, id: str, form_data: GroupUpdateForm, overwrite: bool = False ) -> Optional[GroupModel]: diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 931711b9e..9ba127605 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -154,13 +154,25 @@ class UsersTable: except Exception: return None - def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]: + def get_users( + self, skip: Optional[int] = None, limit: Optional[int] = None + ) -> list[UserModel]: with get_db() as db: - users = ( - db.query(User) - # .offset(skip).limit(limit) - .all() - ) + + query = db.query(User).order_by(User.created_at.desc()) + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + users = query.all() + + return [UserModel.model_validate(user) for user in users] + + def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]: + with get_db() as db: + users = db.query(User).filter(User.id.in_(user_ids)).all() return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: @@ -179,11 +191,15 @@ class UsersTable: try: with get_db() as db: user = db.query(User).filter_by(id=id).first() - return ( - user.settings.get("ui", {}) - .get("notifications", {}) - .get("webhook_url", None) - ) + + if user.settings is None: + return None + else: + return ( + user.settings.get("ui", {}) + .get("notifications", {}) + .get("webhook_url", None) + ) except Exception: return None diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index f97a31c67..32c348ce3 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -3,11 +3,11 @@ import logging from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks from pydantic import BaseModel -from open_webui.socket.main import sio +from open_webui.socket.main import sio, SESSION_POOL from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels, ChannelModel, ChannelForm @@ -16,11 +16,12 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS, WEBUI_URL from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_access +from open_webui.utils.access_control import has_access, get_users_with_access +from open_webui.utils.webhook import post_webhook log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -180,9 +181,39 @@ async def get_channel_messages( ############################ +async def send_notification(channel, message, active_user_ids): + + print(f"Sending notification to {channel=}, {message=}, {active_user_ids=}") + users = get_users_with_access("read", channel.access_control) + + for user in users: + if user.id in active_user_ids: + continue + else: + if user.settings: + webhook_url = user.settings.ui.get("notifications", {}).get( + "webhook_url", None + ) + + if webhook_url: + post_webhook( + webhook_url, + f"#{channel.name} - {WEBUI_URL}/c/{channel.id}\n\n{message.content}", + { + "action": "channel", + "message": message.content, + "title": channel.name, + "url": f"{WEBUI_URL}/c/{channel.id}", + }, + ) + + @router.post("/{id}/messages/post", response_model=Optional[MessageModel]) async def post_new_message( - id: str, form_data: MessageForm, user=Depends(get_verified_user) + id: str, + form_data: MessageForm, + background_tasks: BackgroundTasks, + user=Depends(get_verified_user), ): channel = Channels.get_channel_by_id(id) if not channel: @@ -201,24 +232,44 @@ async def post_new_message( message = Messages.insert_new_message(form_data, channel.id, user.id) if message: + event_data = { + "channel_id": channel.id, + "message_id": message.id, + "data": { + "type": "message", + "data": { + **message.model_dump(), + "user": UserNameResponse(**user.model_dump()).model_dump(), + }, + }, + "user": UserNameResponse(**user.model_dump()).model_dump(), + "channel": channel.model_dump(), + } + await sio.emit( "channel-events", - { - "channel_id": channel.id, - "message_id": message.id, - "data": { - "type": "message", - "data": { - **message.model_dump(), - "user": UserNameResponse(**user.model_dump()).model_dump(), - }, - }, - "user": UserNameResponse(**user.model_dump()).model_dump(), - "channel": channel.model_dump(), - }, + event_data, to=f"channel:{channel.id}", ) + active_session_ids = sio.manager.get_participants( + namespace="/", + room=f"channel:{channel.id}", + ) + + active_user_ids = list( + set( + [ + SESSION_POOL.get(session_id[0]) + for session_id in active_session_ids + ] + ) + ) + + background_tasks.add_task( + send_notification, channel, message, active_user_ids + ) + return MessageModel(**message.model_dump()) except Exception as e: log.exception(e) diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 3b3e75a8b..da61e7fb3 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -1,4 +1,5 @@ from typing import Optional, Union, List, Dict, Any +from open_webui.models.users import Users, UserModel from open_webui.models.groups import Groups import json @@ -93,3 +94,24 @@ def has_access( return user_id in permitted_user_ids or any( group_id in permitted_group_ids for group_id in user_group_ids ) + + +# Get all users with access to a resource +def get_users_with_access( + type: str = "write", access_control: Optional[dict] = None +) -> List[UserModel]: + if access_control is None: + return Users.get_users() + + permission_access = access_control.get(type, {}) + permitted_group_ids = permission_access.get("group_ids", []) + permitted_user_ids = permission_access.get("user_ids", []) + + user_ids_with_access = set(permitted_user_ids) + + for group_id in permitted_group_ids: + group_user_ids = Groups.get_group_user_ids_by_id(group_id) + if group_user_ids: + user_ids_with_access.update(group_user_ids) + + return Users.get_users_by_user_ids(list(user_ids_with_access)) diff --git a/backend/open_webui/utils/webhook.py b/backend/open_webui/utils/webhook.py index 942e62e85..d59244dd3 100644 --- a/backend/open_webui/utils/webhook.py +++ b/backend/open_webui/utils/webhook.py @@ -21,7 +21,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool: elif "https://discord.com/api/webhooks" in url: payload["content"] = ( message - if len(message) > 2000 + if len(message) < 2000 else f"{message[: 2000 - 20]}... (truncated)" ) # Microsoft Teams Webhooks