mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: channel socket integration
This commit is contained in:
@@ -4,8 +4,7 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
@@ -85,6 +84,17 @@ class ChannelTable:
|
||||
channels = db.query(Channel).all()
|
||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
def get_channels_by_user_id(
|
||||
self, user_id: str, permission: str = "read"
|
||||
) -> list[ChannelModel]:
|
||||
channels = self.get_channels()
|
||||
return [
|
||||
channel
|
||||
for channel in channels
|
||||
if channel.user_id == user_id
|
||||
or has_access(user_id, permission, channel.access_control)
|
||||
]
|
||||
|
||||
def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
|
||||
@@ -95,7 +95,7 @@ class MessageTable:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id)
|
||||
.order_by(Message.updated_at.desc())
|
||||
.order_by(Message.updated_at.asc())
|
||||
.limit(limit)
|
||||
.offset(skip)
|
||||
.all()
|
||||
@@ -109,7 +109,7 @@ class MessageTable:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Message.updated_at.desc())
|
||||
.order_by(Message.updated_at.asc())
|
||||
.limit(limit)
|
||||
.offset(skip)
|
||||
.all()
|
||||
|
||||
@@ -2,6 +2,12 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.socket.main import sio
|
||||
from open_webui.models.channels import Channels, ChannelModel, ChannelForm
|
||||
from open_webui.models.messages import Messages, MessageModel, MessageForm
|
||||
|
||||
@@ -9,12 +15,10 @@ from open_webui.models.messages import Messages, MessageModel, MessageForm
|
||||
from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_permission
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
@@ -53,7 +57,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/messages", response_model=list[MessageModel])
|
||||
@router.get("/{id}/messages", response_model=list[MessageModel])
|
||||
async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
@@ -61,7 +65,7 @@ async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if not has_permission(channel.access_control, user):
|
||||
if not has_access(user.id, type="read", access_control=channel.access_control):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
@@ -87,13 +91,25 @@ async def post_new_message(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if not has_permission(channel.access_control, user):
|
||||
if not has_access(user.id, type="read", access_control=channel.access_control):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.insert_new_message(form_data, channel.id, user.id)
|
||||
|
||||
if message:
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
{
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {"message": message.model_dump()},
|
||||
},
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
return MessageModel(**message.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
import time
|
||||
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.channels import Channels
|
||||
from open_webui.env import (
|
||||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
WEBSOCKET_MANAGER,
|
||||
@@ -162,7 +163,6 @@ async def connect(sid, environ, auth):
|
||||
|
||||
@sio.on("user-join")
|
||||
async def user_join(sid, data):
|
||||
# print("user-join", sid, data)
|
||||
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
if not auth or "token" not in auth:
|
||||
@@ -182,6 +182,12 @@ async def user_join(sid, data):
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
|
||||
# Join all the channels
|
||||
channels = Channels.get_channels_by_user_id(user.id)
|
||||
log.debug(f"{channels=}")
|
||||
for channel in channels:
|
||||
await sio.enter_room(sid, f"channel:{channel.id}")
|
||||
|
||||
# print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||
|
||||
await sio.emit("user-count", {"count": len(USER_POOL.items())})
|
||||
|
||||
Reference in New Issue
Block a user