diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 7564475d4..c7f338066 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -12,6 +12,7 @@ from open_webui.apps.webui.routers import ( chats, folders, configs, + groups, files, functions, memories, @@ -85,7 +86,11 @@ from open_webui.utils.payload import ( from open_webui.utils.tools import get_tools -app = FastAPI(docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None) +app = FastAPI( + docs_url="/docs" if ENV == "dev" else None, + openapi_url="/openapi.json" if ENV == "dev" else None, + redoc_url=None, +) log = logging.getLogger(__name__) @@ -161,13 +166,15 @@ app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(tools.router, prefix="/tools", tags=["tools"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/folders", tags=["folders"]) + +app.include_router(groups.router, prefix="/groups", tags=["groups"]) +app.include_router(files.router, prefix="/files", tags=["files"]) +app.include_router(functions.router, prefix="/functions", tags=["functions"]) app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) -app.include_router(folders.router, prefix="/folders", tags=["folders"]) -app.include_router(files.router, prefix="/files", tags=["files"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/open_webui/apps/webui/models/groups.py b/backend/open_webui/apps/webui/models/groups.py new file mode 100644 index 000000000..e8cfea00f --- /dev/null +++ b/backend/open_webui/apps/webui/models/groups.py @@ -0,0 +1,169 @@ +import json +import logging +import time +from typing import Optional +import uuid + +from open_webui.apps.webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS + +from open_webui.apps.webui.models.files import FileMetadataResponse + + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, JSON + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# UserGroup DB Schema +#################### + + +class Group(Base): + __tablename__ = "group" + + id = Column(Text, unique=True, primary_key=True) + user_id = Column(Text) + + name = Column(Text) + description = Column(Text) + meta = Column(JSON, nullable=True) + + permissions = Column(JSON, nullable=True) + user_ids = Column(JSON, nullable=True) + admin_ids = Column(JSON, nullable=True) + + created_at = Column(BigInteger) + updated_at = Column(BigInteger) + + +class GroupModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + user_id: str + + name: str + description: str + meta: Optional[dict] = None + + permissions: Optional[dict] = None + user_ids: list[str] = [] + admin_ids: list[str] = [] + + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class GroupResponse(BaseModel): + id: str + user_id: str + name: str + description: str + permissions: Optional[dict] = None + meta: Optional[dict] = None + user_ids: list[str] = [] + admin_ids: list[str] = [] + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + +class GroupForm(BaseModel): + name: str + description: str + + +class GroupUpdateForm(GroupForm): + permissions: Optional[dict] = None + user_ids: Optional[list[str]] = None + admin_ids: Optional[list[str]] = None + + +class GroupTable: + def insert_new_group( + self, user_id: str, form_data: GroupForm + ) -> Optional[GroupModel]: + with get_db() as db: + group = GroupModel( + **{ + **form_data.model_dump(), + "id": str(uuid.uuid4()), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + try: + result = Groups(**group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return GroupModel.model_validate(result) + else: + return None + + except Exception: + return None + + def get_groups(self) -> list[GroupModel]: + with get_db() as db: + return [ + GroupModel.model_validate(group) + for group in db.query(Groups).order_by(Groups.updated_at.desc()).all() + ] + + def get_group_by_id(self, id: str) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Groups).filter_by(id=id).first() + return GroupModel.model_validate(group) if group else None + except Exception: + return None + + def update_group_by_id( + self, id: str, form_data: GroupUpdateForm, overwrite: bool = False + ) -> Optional[GroupModel]: + try: + with get_db() as db: + db.query(Groups).filter_by(id=id).update( + { + **form_data.model_dump(exclude_none=True), + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_group_by_id(id=id) + except Exception as e: + log.exception(e) + return None + + def delete_group_by_id(self, id: str) -> bool: + try: + with get_db() as db: + db.query(Groups).filter_by(id=id).delete() + db.commit() + return True + except Exception: + return False + + def delete_all_groups(self) -> bool: + with get_db() as db: + try: + db.query(Groups).delete() + db.commit() + + return True + except Exception: + return False + + +Groups = GroupTable() diff --git a/backend/open_webui/apps/webui/routers/groups.py b/backend/open_webui/apps/webui/routers/groups.py new file mode 100644 index 000000000..984cece57 --- /dev/null +++ b/backend/open_webui/apps/webui/routers/groups.py @@ -0,0 +1,117 @@ +import os +from pathlib import Path +from typing import Optional + +from open_webui.apps.webui.models.groups import ( + Groups, + GroupForm, + GroupUpdateForm, + GroupResponse, +) + +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.utils.utils import get_admin_user, get_verified_user + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=list[GroupResponse]) +async def get_groups(user=Depends(get_admin_user)): + return Groups.get_groups() + + +############################ +# CreateNewGroup +############################ + + +@router.post("/create", response_model=Optional[GroupResponse]) +async def create_new_function(form_data: GroupForm, user=Depends(get_admin_user)): + try: + group = Groups.insert_new_group(user.id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# GetGroupById +############################ + + +@router.get("/id/{id}", response_model=Optional[GroupResponse]) +async def get_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateGroupById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[GroupResponse]) +async def update_group_by_id( + id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) +): + try: + group = Groups.update_group_by_id(id, form_data) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteGroupById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_group_by_id(id: str, user=Depends(get_admin_user)): + try: + result = Groups.delete_group_by_id(id) + if result: + return result + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting group"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + )