refac: prompts pagination
This commit is contained in:
@@ -10,7 +10,8 @@ from open_webui.models.prompt_history import PromptHistories
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, func, cast
|
||||
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
@@ -85,7 +86,18 @@ class PromptAccessResponse(PromptUserResponse):
|
||||
write_access: Optional[bool] = False
|
||||
|
||||
|
||||
class PromptListResponse(BaseModel):
|
||||
items: list[PromptUserResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class PromptAccessListResponse(BaseModel):
|
||||
items: list[PromptAccessResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class PromptForm(BaseModel):
|
||||
|
||||
command: str
|
||||
name: str # Changed from title
|
||||
content: str
|
||||
@@ -227,7 +239,109 @@ class PromptsTable:
|
||||
or has_access(user_id, permission, prompt.access_control, user_group_ids)
|
||||
]
|
||||
|
||||
def search_prompts(
|
||||
self,
|
||||
user_id: str,
|
||||
filter: dict = {},
|
||||
skip: int = 0,
|
||||
limit: int = 30,
|
||||
db: Optional[Session] = None,
|
||||
) -> PromptListResponse:
|
||||
with get_db_context(db) as db:
|
||||
from open_webui.models.users import User, UserModel
|
||||
|
||||
# Join with User table for user filtering and sorting
|
||||
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
|
||||
query = query.filter(Prompt.is_active == True)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Prompt.name.ilike(f"%{query_key}%"),
|
||||
Prompt.command.ilike(f"%{query_key}%"),
|
||||
Prompt.content.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
query = query.filter(Prompt.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
query = query.filter(Prompt.user_id != user_id)
|
||||
|
||||
# Apply access control filtering
|
||||
group_ids = filter.get("group_ids", [])
|
||||
filter_user_id = filter.get("user_id")
|
||||
|
||||
if filter_user_id:
|
||||
# User must have access: owner OR public OR explicit access
|
||||
access_conditions = [
|
||||
Prompt.user_id == filter_user_id, # Owner
|
||||
Prompt.access_control == None, # Public
|
||||
]
|
||||
query = query.filter(or_(*access_conditions))
|
||||
|
||||
tag = filter.get("tag")
|
||||
if tag:
|
||||
# Search for tag in JSON array field
|
||||
like_pattern = f'%"{tag.lower()}"%'
|
||||
tags_text = func.lower(cast(Prompt.tags, String))
|
||||
query = query.filter(tags_text.like(like_pattern))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Prompt.name.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.name.desc())
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Prompt.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.created_at.desc())
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Prompt.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.updated_at.desc())
|
||||
else:
|
||||
query = query.order_by(Prompt.updated_at.desc())
|
||||
else:
|
||||
query = query.order_by(Prompt.updated_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
total = query.count()
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
items = query.all()
|
||||
|
||||
prompts = []
|
||||
for prompt, user in items:
|
||||
prompts.append(
|
||||
PromptUserResponse(
|
||||
**PromptModel.model_validate(prompt).model_dump(),
|
||||
user=(
|
||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
||||
if user
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return PromptListResponse(items=prompts, total=total)
|
||||
|
||||
def update_prompt_by_command(
|
||||
|
||||
self,
|
||||
command: str,
|
||||
form_data: PromptForm,
|
||||
|
||||
@@ -5,9 +5,11 @@ from open_webui.models.prompts import (
|
||||
PromptForm,
|
||||
PromptUserResponse,
|
||||
PromptAccessResponse,
|
||||
PromptAccessListResponse,
|
||||
PromptModel,
|
||||
Prompts,
|
||||
)
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.prompt_history import (
|
||||
PromptHistories,
|
||||
PromptHistoryModel,
|
||||
@@ -34,6 +36,8 @@ class PromptMetadataForm(BaseModel):
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
|
||||
############################
|
||||
# GetPrompts
|
||||
@@ -67,26 +71,57 @@ async def get_prompt_tags(
|
||||
return sorted(list(tags))
|
||||
|
||||
|
||||
@router.get("/list", response_model=list[PromptAccessResponse])
|
||||
@router.get("/list", response_model=PromptAccessListResponse)
|
||||
async def get_prompt_list(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
query: Optional[str] = None,
|
||||
view_option: Optional[str] = None,
|
||||
tag: Optional[str] = None,
|
||||
order_by: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
prompts = Prompts.get_prompts(db=db)
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
|
||||
limit = PAGE_ITEM_COUNT
|
||||
|
||||
return [
|
||||
PromptAccessResponse(
|
||||
**prompt.model_dump(),
|
||||
write_access=(
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == prompt.user_id
|
||||
or has_access(user.id, "write", prompt.access_control, db=db)
|
||||
),
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
page = max(1, page)
|
||||
skip = (page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
filter["query"] = query
|
||||
if view_option:
|
||||
filter["view_option"] = view_option
|
||||
if tag:
|
||||
filter["tag"] = tag
|
||||
if order_by:
|
||||
filter["order_by"] = order_by
|
||||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL):
|
||||
groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
if groups:
|
||||
filter["group_ids"] = [group.id for group in groups]
|
||||
|
||||
filter["user_id"] = user.id
|
||||
|
||||
result = Prompts.search_prompts(user.id, filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
return PromptAccessListResponse(
|
||||
items=[
|
||||
PromptAccessResponse(
|
||||
**prompt.model_dump(),
|
||||
write_access=(
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == prompt.user_id
|
||||
or has_access(user.id, "write", prompt.access_control, db=db)
|
||||
),
|
||||
)
|
||||
for prompt in result.items
|
||||
],
|
||||
total=result.total,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
|
||||
Reference in New Issue
Block a user