diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 807c87dcc..a5f349278 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -58,6 +58,7 @@ from open_webui.routers import ( pipelines, tasks, auths, + channels, chats, folders, configs, @@ -737,6 +738,8 @@ app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) + +app.include_router(channels.router, prefix="/api/v1/channels", tags=["channels"]) app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) diff --git a/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py new file mode 100644 index 000000000..b203e2f4e --- /dev/null +++ b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py @@ -0,0 +1,47 @@ +"""Add channel table + +Revision ID: 57c599a3cb57 +Revises: 922e7a387820 +Create Date: 2024-12-22 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "57c599a3cb57" +down_revision = "922e7a387820" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "channel", + sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column("user_id", sa.Text()), + sa.Column("name", sa.Text()), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("access_control", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + ) + + op.create_table( + "message", + sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), + sa.Column("user_id", sa.Text()), + sa.Column("channel_id", sa.Text(), nullable=True), + sa.Column("content", sa.Text()), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column("meta", sa.JSON(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + ) + + +def downgrade(): + op.drop_table("channel") + + op.drop_table("message") diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py new file mode 100644 index 000000000..03ffd57c2 --- /dev/null +++ b/backend/open_webui/models/channels.py @@ -0,0 +1,115 @@ +import json +import time +import uuid +from typing import Optional + +from open_webui.internal.db import Base, get_db +from open_webui.models.tags import TagModel, Tag, Tags + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import or_, func, select, and_, text +from sqlalchemy.sql import exists + +#################### +# Channel DB Schema +#################### + + +class Channel(Base): + __tablename__ = "channel" + + id = Column(Text, primary_key=True) + user_id = Column(Text) + + name = Column(Text) + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + access_control = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class ChannelModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + user_id: str + + name: str + data: Optional[dict] = None + meta: Optional[dict] = None + access_control: Optional[dict] = None + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ChannelForm(BaseModel): + name: str + data: Optional[dict] = None + meta: Optional[dict] = None + access_control: Optional[dict] = None + + +class ChannelTable: + def insert_new_channel( + self, form_data: ChannelForm, user_id: str + ) -> Optional[ChannelModel]: + with get_db() as db: + new_channel = Channel( + **{ + **form_data.model_dump(), + "id": str(uuid.uuid4()), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + db.add(new_channel) + db.commit() + return new_channel + + def get_channels(self) -> list[ChannelModel]: + with get_db() as db: + channels = db.query(Channel).all() + return [ChannelModel.model_validate(channel) for channel in channels] + + def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: + with get_db() as db: + channel = db.query(Channel).filter(Channel.id == id).first() + return ChannelModel.model_validate(channel) if channel else None + + def update_channel_by_id( + self, id: str, form_data: ChannelForm + ) -> Optional[ChannelModel]: + with get_db() as db: + channel = db.query(Channel).filter(Channel.id == id).first() + if not channel: + return None + + channel.name = form_data.name + channel.data = form_data.data + channel.meta = form_data.meta + channel.access_control = form_data.access_control + channel.updated_at = int(time.time()) + + db.commit() + return ChannelModel.model_validate(channel) if channel else None + + def delete_channel_by_id(self, id: str): + with get_db() as db: + db.query(Channel).filter(Channel.id == id).delete() + db.commit() + return True + + +Channels = ChannelTable() diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py new file mode 100644 index 000000000..c9161da96 --- /dev/null +++ b/backend/open_webui/models/messages.py @@ -0,0 +1,139 @@ +import json +import time +import uuid +from typing import Optional + +from open_webui.internal.db import Base, get_db +from open_webui.models.tags import TagModel, Tag, Tags + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import or_, func, select, and_, text +from sqlalchemy.sql import exists + +#################### +# Message DB Schema +#################### + + +class Message(Base): + __tablename__ = "message" + id = Column(Text, primary_key=True) + + user_id = Column(Text) + channel_id = Column(Text, nullable=True) + + content = Column(Text) + data = Column(JSON, nullable=True) + meta = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class MessageModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + user_id: str + channel_id: Optional[str] = None + + content: str + data: Optional[dict] = None + meta: Optional[dict] = None + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class MessageForm(BaseModel): + content: str + data: Optional[dict] = None + meta: Optional[dict] = None + + +class MessageTable: + def insert_new_message( + self, form_data: MessageForm, channel_id: str, user_id: str + ) -> Optional[MessageModel]: + with get_db() as db: + id = str(uuid.uuid4()) + message = MessageModel( + **{ + "id": id, + "user_id": user_id, + "channel_id": channel_id, + "content": form_data.content, + "data": form_data.data, + "meta": form_data.meta, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Message(**message.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return MessageModel.model_validate(result) if result else None + + def get_message_by_id(self, id: str) -> Optional[MessageModel]: + with get_db() as db: + message = db.get(Message, id) + return MessageModel.model_validate(message) if message else None + + def get_messages_by_channel_id( + self, channel_id: str, skip: int = 0, limit: int = 50 + ) -> list[MessageModel]: + with get_db() as db: + all_messages = ( + db.query(Message) + .filter_by(channel_id=channel_id) + .order_by(Message.updated_at.desc()) + .limit(limit) + .offset(skip) + .all() + ) + return [MessageModel.model_validate(message) for message in all_messages] + + def get_messages_by_user_id( + self, user_id: str, skip: int = 0, limit: int = 50 + ) -> list[MessageModel]: + with get_db() as db: + all_messages = ( + db.query(Message) + .filter_by(user_id=user_id) + .order_by(Message.updated_at.desc()) + .limit(limit) + .offset(skip) + .all() + ) + return [MessageModel.model_validate(message) for message in all_messages] + + def update_message_by_id( + self, id: str, form_data: MessageForm + ) -> Optional[MessageModel]: + with get_db() as db: + message = db.get(Message, id) + message.content = form_data.content + message.data = form_data.data + message.meta = form_data.meta + message.updated_at = int(time.time()) + db.commit() + db.refresh(message) + return MessageModel.model_validate(message) if message else None + + def delete_message_by_id(self, id: str) -> bool: + with get_db() as db: + db.query(Message).filter_by(id=id).delete() + db.commit() + return True + + +Messages = MessageTable() diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py new file mode 100644 index 000000000..b4b458f25 --- /dev/null +++ b/backend/open_webui/routers/channels.py @@ -0,0 +1,102 @@ +import json +import logging +from typing import Optional + +from open_webui.models.channels import Channels, ChannelModel, ChannelForm +from open_webui.models.messages import Messages, MessageModel, MessageForm + + +from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import SRC_LOG_LEVELS +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel + + +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_permission + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +router = APIRouter() + +############################ +# GetChatList +############################ + + +@router.get("/", response_model=list[ChannelModel]) +async def get_channels(user=Depends(get_verified_user)): + return Channels.get_channels() + + +############################ +# CreateNewChannel +############################ + + +@router.post("/create", response_model=Optional[ChannelModel]) +async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)): + try: + channel = Channels.insert_new_channel(form_data, user.id) + return ChannelModel(**channel.model_dump()) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# GetChannelMessages +############################ + + +@router.post("/{id}/messages", response_model=list[MessageModel]) +async def get_channel_messages(id: str, page: int = 1, user=Depends(get_verified_user)): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if not has_permission(channel.access_control, user): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + limit = 50 + skip = (page - 1) * limit + + return Messages.get_messages_by_channel_id(id, skip, limit) + + +############################ +# PostNewMessage +############################ + + +@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) +async def post_new_message( + id: str, form_data: MessageForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if not has_permission(channel.access_control, user): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + message = Messages.insert_new_message(form_data, channel.id, user.id) + return MessageModel(**message.model_dump()) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + )