mirror of
				https://github.com/open-webui/open-webui
				synced 2025-06-26 18:26:48 +00:00 
			
		
		
		
	enh: channel notification
This commit is contained in:
		
							parent
							
								
									0d7d6899b9
								
							
						
					
					
						commit
						d701b69e05
					
				@ -146,6 +146,13 @@ class GroupTable:
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
 | 
			
		||||
        group = self.get_group_by_id(id)
 | 
			
		||||
        if group:
 | 
			
		||||
            return group.user_ids
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def update_group_by_id(
 | 
			
		||||
        self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
 | 
			
		||||
    ) -> Optional[GroupModel]:
 | 
			
		||||
 | 
			
		||||
@ -154,13 +154,25 @@ class UsersTable:
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
 | 
			
		||||
    def get_users(
 | 
			
		||||
        self, skip: Optional[int] = None, limit: Optional[int] = None
 | 
			
		||||
    ) -> list[UserModel]:
 | 
			
		||||
        with get_db() as db:
 | 
			
		||||
            users = (
 | 
			
		||||
                db.query(User)
 | 
			
		||||
                # .offset(skip).limit(limit)
 | 
			
		||||
                .all()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            query = db.query(User).order_by(User.created_at.desc())
 | 
			
		||||
 | 
			
		||||
            if skip:
 | 
			
		||||
                query = query.offset(skip)
 | 
			
		||||
            if limit:
 | 
			
		||||
                query = query.limit(limit)
 | 
			
		||||
 | 
			
		||||
            users = query.all()
 | 
			
		||||
 | 
			
		||||
            return [UserModel.model_validate(user) for user in users]
 | 
			
		||||
 | 
			
		||||
    def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
 | 
			
		||||
        with get_db() as db:
 | 
			
		||||
            users = db.query(User).filter(User.id.in_(user_ids)).all()
 | 
			
		||||
            return [UserModel.model_validate(user) for user in users]
 | 
			
		||||
 | 
			
		||||
    def get_num_users(self) -> Optional[int]:
 | 
			
		||||
@ -179,11 +191,15 @@ class UsersTable:
 | 
			
		||||
        try:
 | 
			
		||||
            with get_db() as db:
 | 
			
		||||
                user = db.query(User).filter_by(id=id).first()
 | 
			
		||||
                return (
 | 
			
		||||
                    user.settings.get("ui", {})
 | 
			
		||||
                    .get("notifications", {})
 | 
			
		||||
                    .get("webhook_url", None)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if user.settings is None:
 | 
			
		||||
                    return None
 | 
			
		||||
                else:
 | 
			
		||||
                    return (
 | 
			
		||||
                        user.settings.get("ui", {})
 | 
			
		||||
                        .get("notifications", {})
 | 
			
		||||
                        .get("webhook_url", None)
 | 
			
		||||
                    )
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,11 +3,11 @@ import logging
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
 | 
			
		||||
from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from open_webui.socket.main import sio
 | 
			
		||||
from open_webui.socket.main import sio, SESSION_POOL
 | 
			
		||||
from open_webui.models.users import Users, UserNameResponse
 | 
			
		||||
 | 
			
		||||
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
 | 
			
		||||
@ -16,11 +16,12 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
 | 
			
		||||
 | 
			
		||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 | 
			
		||||
from open_webui.constants import ERROR_MESSAGES
 | 
			
		||||
from open_webui.env import SRC_LOG_LEVELS
 | 
			
		||||
from open_webui.env import SRC_LOG_LEVELS, WEBUI_URL
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from open_webui.utils.auth import get_admin_user, get_verified_user
 | 
			
		||||
from open_webui.utils.access_control import has_access
 | 
			
		||||
from open_webui.utils.access_control import has_access, get_users_with_access
 | 
			
		||||
from open_webui.utils.webhook import post_webhook
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
 | 
			
		||||
@ -180,9 +181,39 @@ async def get_channel_messages(
 | 
			
		||||
############################
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def send_notification(channel, message, active_user_ids):
 | 
			
		||||
 | 
			
		||||
    print(f"Sending notification to {channel=}, {message=}, {active_user_ids=}")
 | 
			
		||||
    users = get_users_with_access("read", channel.access_control)
 | 
			
		||||
 | 
			
		||||
    for user in users:
 | 
			
		||||
        if user.id in active_user_ids:
 | 
			
		||||
            continue
 | 
			
		||||
        else:
 | 
			
		||||
            if user.settings:
 | 
			
		||||
                webhook_url = user.settings.ui.get("notifications", {}).get(
 | 
			
		||||
                    "webhook_url", None
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if webhook_url:
 | 
			
		||||
                    post_webhook(
 | 
			
		||||
                        webhook_url,
 | 
			
		||||
                        f"#{channel.name} - {WEBUI_URL}/c/{channel.id}\n\n{message.content}",
 | 
			
		||||
                        {
 | 
			
		||||
                            "action": "channel",
 | 
			
		||||
                            "message": message.content,
 | 
			
		||||
                            "title": channel.name,
 | 
			
		||||
                            "url": f"{WEBUI_URL}/c/{channel.id}",
 | 
			
		||||
                        },
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
 | 
			
		||||
async def post_new_message(
 | 
			
		||||
    id: str, form_data: MessageForm, user=Depends(get_verified_user)
 | 
			
		||||
    id: str,
 | 
			
		||||
    form_data: MessageForm,
 | 
			
		||||
    background_tasks: BackgroundTasks,
 | 
			
		||||
    user=Depends(get_verified_user),
 | 
			
		||||
):
 | 
			
		||||
    channel = Channels.get_channel_by_id(id)
 | 
			
		||||
    if not channel:
 | 
			
		||||
@ -201,24 +232,44 @@ async def post_new_message(
 | 
			
		||||
        message = Messages.insert_new_message(form_data, channel.id, user.id)
 | 
			
		||||
 | 
			
		||||
        if message:
 | 
			
		||||
            event_data = {
 | 
			
		||||
                "channel_id": channel.id,
 | 
			
		||||
                "message_id": message.id,
 | 
			
		||||
                "data": {
 | 
			
		||||
                    "type": "message",
 | 
			
		||||
                    "data": {
 | 
			
		||||
                        **message.model_dump(),
 | 
			
		||||
                        "user": UserNameResponse(**user.model_dump()).model_dump(),
 | 
			
		||||
                    },
 | 
			
		||||
                },
 | 
			
		||||
                "user": UserNameResponse(**user.model_dump()).model_dump(),
 | 
			
		||||
                "channel": channel.model_dump(),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            await sio.emit(
 | 
			
		||||
                "channel-events",
 | 
			
		||||
                {
 | 
			
		||||
                    "channel_id": channel.id,
 | 
			
		||||
                    "message_id": message.id,
 | 
			
		||||
                    "data": {
 | 
			
		||||
                        "type": "message",
 | 
			
		||||
                        "data": {
 | 
			
		||||
                            **message.model_dump(),
 | 
			
		||||
                            "user": UserNameResponse(**user.model_dump()).model_dump(),
 | 
			
		||||
                        },
 | 
			
		||||
                    },
 | 
			
		||||
                    "user": UserNameResponse(**user.model_dump()).model_dump(),
 | 
			
		||||
                    "channel": channel.model_dump(),
 | 
			
		||||
                },
 | 
			
		||||
                event_data,
 | 
			
		||||
                to=f"channel:{channel.id}",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            active_session_ids = sio.manager.get_participants(
 | 
			
		||||
                namespace="/",
 | 
			
		||||
                room=f"channel:{channel.id}",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            active_user_ids = list(
 | 
			
		||||
                set(
 | 
			
		||||
                    [
 | 
			
		||||
                        SESSION_POOL.get(session_id[0])
 | 
			
		||||
                        for session_id in active_session_ids
 | 
			
		||||
                    ]
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            background_tasks.add_task(
 | 
			
		||||
                send_notification, channel, message, active_user_ids
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return MessageModel(**message.model_dump())
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        log.exception(e)
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
from typing import Optional, Union, List, Dict, Any
 | 
			
		||||
from open_webui.models.users import Users, UserModel
 | 
			
		||||
from open_webui.models.groups import Groups
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
@ -93,3 +94,24 @@ def has_access(
 | 
			
		||||
    return user_id in permitted_user_ids or any(
 | 
			
		||||
        group_id in permitted_group_ids for group_id in user_group_ids
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Get all users with access to a resource
 | 
			
		||||
def get_users_with_access(
 | 
			
		||||
    type: str = "write", access_control: Optional[dict] = None
 | 
			
		||||
) -> List[UserModel]:
 | 
			
		||||
    if access_control is None:
 | 
			
		||||
        return Users.get_users()
 | 
			
		||||
 | 
			
		||||
    permission_access = access_control.get(type, {})
 | 
			
		||||
    permitted_group_ids = permission_access.get("group_ids", [])
 | 
			
		||||
    permitted_user_ids = permission_access.get("user_ids", [])
 | 
			
		||||
 | 
			
		||||
    user_ids_with_access = set(permitted_user_ids)
 | 
			
		||||
 | 
			
		||||
    for group_id in permitted_group_ids:
 | 
			
		||||
        group_user_ids = Groups.get_group_user_ids_by_id(group_id)
 | 
			
		||||
        if group_user_ids:
 | 
			
		||||
            user_ids_with_access.update(group_user_ids)
 | 
			
		||||
 | 
			
		||||
    return Users.get_users_by_user_ids(list(user_ids_with_access))
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
 | 
			
		||||
        elif "https://discord.com/api/webhooks" in url:
 | 
			
		||||
            payload["content"] = (
 | 
			
		||||
                message
 | 
			
		||||
                if len(message) > 2000
 | 
			
		||||
                if len(message) < 2000
 | 
			
		||||
                else f"{message[: 2000 - 20]}... (truncated)"
 | 
			
		||||
            )
 | 
			
		||||
        # Microsoft Teams Webhooks
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user