refac: db group
This commit is contained in:
@@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
and_,
|
||||
func,
|
||||
ForeignKey,
|
||||
cast,
|
||||
or_,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -41,7 +52,6 @@ class Group(Base):
|
||||
|
||||
|
||||
class GroupModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
@@ -56,6 +66,8 @@ class GroupModel(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GroupMember(Base):
|
||||
__tablename__ = "group_member"
|
||||
@@ -84,17 +96,8 @@ class GroupMemberModel(BaseModel):
|
||||
####################
|
||||
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
description: str
|
||||
permissions: Optional[dict] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
class GroupResponse(GroupModel):
|
||||
member_count: Optional[int] = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class GroupForm(BaseModel):
|
||||
@@ -112,6 +115,11 @@ class GroupUpdateForm(GroupForm):
|
||||
pass
|
||||
|
||||
|
||||
class GroupListResponse(BaseModel):
|
||||
items: list[GroupResponse] = []
|
||||
total: int = 0
|
||||
|
||||
|
||||
class GroupTable:
|
||||
def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm
|
||||
@@ -140,13 +148,87 @@ class GroupTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_groups(self) -> list[GroupModel]:
|
||||
def get_all_groups(self) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
return [GroupModel.model_validate(group) for group in groups]
|
||||
|
||||
def get_groups(self, filter) -> list[GroupResponse]:
|
||||
with get_db() as db:
|
||||
query = db.query(Group)
|
||||
|
||||
if filter:
|
||||
if "query" in filter:
|
||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||
if "member_id" in filter:
|
||||
query = query.join(
|
||||
GroupMember, GroupMember.group_id == Group.id
|
||||
).filter(GroupMember.user_id == filter["member_id"])
|
||||
|
||||
if "share" in filter:
|
||||
share_value = filter["share"]
|
||||
json_share = Group.data["config"]["share"].as_boolean()
|
||||
|
||||
if share_value:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Group.data.is_(None),
|
||||
json_share.is_(None),
|
||||
json_share == True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(
|
||||
and_(Group.data.isnot(None), json_share == False)
|
||||
)
|
||||
groups = query.order_by(Group.updated_at.desc()).all()
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
GroupResponse.model_validate(
|
||||
{
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
"member_count": self.get_group_member_count_by_id(group.id),
|
||||
}
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def search_groups(
|
||||
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
|
||||
) -> GroupListResponse:
|
||||
with get_db() as db:
|
||||
query = db.query(Group)
|
||||
|
||||
if filter:
|
||||
if "query" in filter:
|
||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||
if "member_id" in filter:
|
||||
query = query.join(
|
||||
GroupMember, GroupMember.group_id == Group.id
|
||||
).filter(GroupMember.user_id == filter["member_id"])
|
||||
|
||||
if "share" in filter:
|
||||
# 'share' is stored in data JSON, support both sqlite and postgres
|
||||
share_value = filter["share"]
|
||||
print("Filtering by share:", share_value)
|
||||
query = query.filter(
|
||||
Group.data.op("->>")("share") == str(share_value)
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
query = query.order_by(Group.updated_at.desc())
|
||||
groups = query.offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"items": [
|
||||
GroupResponse.model_validate(
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
member_count=self.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
for group in groups
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
|
||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
@@ -293,7 +375,7 @@ class GroupTable:
|
||||
) -> list[GroupModel]:
|
||||
|
||||
# check for existing groups
|
||||
existing_groups = self.get_groups()
|
||||
existing_groups = self.get_all_groups()
|
||||
existing_group_names = {group.name for group in existing_groups}
|
||||
|
||||
new_groups = []
|
||||
|
||||
Reference in New Issue
Block a user