open-webui/backend/open_webui/models/groups.py
Timothy Jaeryang Baek 45f4bc18f8 refac: access controls
2025-01-20 23:20:47 -08:00

212 lines
5.6 KiB
Python

import json
import logging
import time
from typing import Optional
import uuid
from open_webui.internal.db import Base, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
# Module-level logger; its level is taken from the "MODELS" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# UserGroup DB Schema
####################
class Group(Base):
    """ORM mapping for a row of the ``group`` table."""

    __tablename__ = "group"

    # Primary key; GroupTable.insert_new_group fills it with a UUID4 string.
    id = Column(Text, unique=True, primary_key=True)
    # The user id handed to GroupTable.insert_new_group when the row was made.
    user_id = Column(Text)

    name = Column(Text)
    description = Column(Text)

    # Optional JSON payloads; each of these columns may be NULL.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    permissions = Column(JSON, nullable=True)
    # JSON array of member user ids (queried by
    # GroupTable.get_groups_by_member_id).
    user_ids = Column(JSON, nullable=True)

    # Epoch timestamps written as int(time.time()).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
class GroupModel(BaseModel):
    # Pydantic representation of a Group row. from_attributes=True lets
    # model_validate() read the fields straight off an ORM object.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str

    name: str
    description: str

    data: Optional[dict] = None
    meta: Optional[dict] = None

    permissions: Optional[dict] = None
    user_ids: list[str] = []

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
####################
# Forms
####################
class GroupResponse(BaseModel):
    # Response schema carrying the same fields as GroupModel, but without
    # the from_attributes configuration.
    id: str
    user_id: str

    name: str
    description: str

    permissions: Optional[dict] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None

    user_ids: list[str] = []

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
class GroupForm(BaseModel):
    # Payload accepted by GroupTable.insert_new_group: the required name and
    # description plus an optional permissions dict.
    name: str
    description: str

    permissions: Optional[dict] = None
class GroupUpdateForm(GroupForm):
    # Payload accepted by GroupTable.update_group_by_id: every GroupForm field
    # plus an optional replacement member list. Leaving user_ids as None keeps
    # it out of the update, because the update dumps with exclude_none=True.
    user_ids: Optional[list[str]] = None
class GroupTable:
    """Data-access helper for rows of the ``group`` table.

    Sessions are obtained through ``get_db()``. Except for the two list
    queries (``get_groups`` and ``get_groups_by_member_id``), failures are
    reported through the return value (``None`` / ``False``) instead of
    being raised; they are logged so that they are not lost silently.
    """

    def insert_new_group(
        self, user_id: str, form_data: GroupForm
    ) -> Optional[GroupModel]:
        """Create a group from ``form_data`` and store it.

        Args:
            user_id: Stored in the new row's ``user_id`` column.
            form_data: Name, description and optional permissions.

        Returns:
            The stored group, or ``None`` if the insert failed.
        """
        with get_db() as db:
            # Use a single timestamp so created_at and updated_at of a new
            # row can never differ by a clock tick.
            now = int(time.time())
            group = GroupModel(
                **{
                    **form_data.model_dump(exclude_none=True),
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "created_at": now,
                    "updated_at": now,
                }
            )
            try:
                result = Group(**group.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                return GroupModel.model_validate(result)
            except Exception as e:
                log.exception(e)
                return None

    def get_groups(self) -> list[GroupModel]:
        """Return every group, most recently updated first."""
        with get_db() as db:
            rows = db.query(Group).order_by(Group.updated_at.desc()).all()
            return [GroupModel.model_validate(row) for row in rows]

    def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
        """Return the groups whose ``user_ids`` contain ``user_id``.

        Results are ordered by ``updated_at`` descending.
        """
        with get_db() as db:
            rows = (
                db.query(Group)
                .filter(
                    func.json_array_length(Group.user_ids) > 0
                )  # Ensure array exists
                .filter(
                    Group.user_ids.cast(String).like(f'%"{user_id}"%')
                )  # String-based check
                .order_by(Group.updated_at.desc())
                .all()
            )
            groups = [GroupModel.model_validate(row) for row in rows]
            # The LIKE match is only a prefilter: "%" and "_" inside user_id
            # act as wildcards, and some backends (e.g. SQLite's default
            # LIKE) compare ASCII case-insensitively. Confirm exact
            # membership in Python so no unrelated group is returned.
            return [group for group in groups if user_id in group.user_ids]

    def get_group_by_id(self, id: str) -> Optional[GroupModel]:
        """Return the group with primary key ``id``.

        Returns ``None`` when no such row exists or the lookup failed.
        """
        try:
            with get_db() as db:
                group = db.query(Group).filter_by(id=id).first()
                return GroupModel.model_validate(group) if group else None
        except Exception as e:
            log.exception(e)
            return None

    def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
        """Return the member id list of group ``id``, or ``None`` if the
        group could not be loaded."""
        group = self.get_group_by_id(id)
        return group.user_ids if group else None

    def update_group_by_id(
        self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
    ) -> Optional[GroupModel]:
        """Apply the non-``None`` fields of ``form_data`` to group ``id``.

        Args:
            id: Primary key of the group to update.
            form_data: New values. Fields left as ``None`` are skipped
                (``exclude_none=True``), so a field cannot be reset to
                ``None`` through this method.
            overwrite: Accepted for interface compatibility; this method
                does not read it.

        Returns:
            The group as re-read after the update, or ``None`` if the
            update failed or the group does not exist.
        """
        try:
            with get_db() as db:
                db.query(Group).filter_by(id=id).update(
                    {
                        **form_data.model_dump(exclude_none=True),
                        "updated_at": int(time.time()),
                    }
                )
                db.commit()
                return self.get_group_by_id(id=id)
        except Exception as e:
            log.exception(e)
            return None

    def delete_group_by_id(self, id: str) -> bool:
        """Delete group ``id``.

        Returns ``True`` when the statement ran without error (including
        when no row matched) and ``False`` when it raised.
        """
        try:
            with get_db() as db:
                db.query(Group).filter_by(id=id).delete()
                db.commit()
                return True
        except Exception as e:
            log.exception(e)
            return False

    def delete_all_groups(self) -> bool:
        """Delete every group; ``False`` if the delete raised."""
        with get_db() as db:
            try:
                db.query(Group).delete()
                db.commit()
                return True
            except Exception as e:
                log.exception(e)
                return False

    def remove_user_from_all_groups(self, user_id: str) -> bool:
        """Remove ``user_id`` from the ``user_ids`` of every group.

        All affected rows are committed together; if any update fails the
        pending changes are rolled back and ``False`` is returned.
        """
        with get_db() as db:
            try:
                now = int(time.time())
                for group in self.get_groups_by_member_id(user_id):
                    # Rebuild the list instead of list.remove(): this drops
                    # every occurrence (not just the first) and cannot raise
                    # ValueError when the id is unexpectedly absent.
                    remaining = [
                        member_id
                        for member_id in group.user_ids
                        if member_id != user_id
                    ]
                    db.query(Group).filter_by(id=group.id).update(
                        {
                            "user_ids": remaining,
                            "updated_at": now,
                        }
                    )
                db.commit()
                return True
            except Exception as e:
                log.exception(e)
                db.rollback()
                return False
# Shared module-level GroupTable instance.
Groups = GroupTable()