refac: prompts pagination

This commit is contained in:
Timothy Jaeryang Baek
2026-01-27 23:01:56 +04:00
parent 683438b418
commit 36766f157d
4 changed files with 298 additions and 51 deletions

View File

@@ -10,7 +10,8 @@ from open_webui.models.prompt_history import PromptHistories
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, func, cast
from open_webui.utils.access_control import has_access
@@ -85,7 +86,18 @@ class PromptAccessResponse(PromptUserResponse):
write_access: Optional[bool] = False
class PromptListResponse(BaseModel):
items: list[PromptUserResponse]
total: int
class PromptAccessListResponse(BaseModel):
items: list[PromptAccessResponse]
total: int
class PromptForm(BaseModel):
command: str
name: str # Changed from title
content: str
@@ -227,7 +239,109 @@ class PromptsTable:
or has_access(user_id, permission, prompt.access_control, user_group_ids)
]
def search_prompts(
self,
user_id: str,
filter: dict = {},
skip: int = 0,
limit: int = 30,
db: Optional[Session] = None,
) -> PromptListResponse:
with get_db_context(db) as db:
from open_webui.models.users import User, UserModel
# Join with User table for user filtering and sorting
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
query = query.filter(Prompt.is_active == True)
if filter:
query_key = filter.get("query")
if query_key:
query = query.filter(
or_(
Prompt.name.ilike(f"%{query_key}%"),
Prompt.command.ilike(f"%{query_key}%"),
Prompt.content.ilike(f"%{query_key}%"),
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
)
)
view_option = filter.get("view_option")
if view_option == "created":
query = query.filter(Prompt.user_id == user_id)
elif view_option == "shared":
query = query.filter(Prompt.user_id != user_id)
# Apply access control filtering
group_ids = filter.get("group_ids", [])
filter_user_id = filter.get("user_id")
if filter_user_id:
# User must have access: owner OR public OR explicit access
access_conditions = [
Prompt.user_id == filter_user_id, # Owner
Prompt.access_control == None, # Public
]
query = query.filter(or_(*access_conditions))
tag = filter.get("tag")
if tag:
# Search for tag in JSON array field
like_pattern = f'%"{tag.lower()}"%'
tags_text = func.lower(cast(Prompt.tags, String))
query = query.filter(tags_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
if order_by == "name":
if direction == "asc":
query = query.order_by(Prompt.name.asc())
else:
query = query.order_by(Prompt.name.desc())
elif order_by == "created_at":
if direction == "asc":
query = query.order_by(Prompt.created_at.asc())
else:
query = query.order_by(Prompt.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Prompt.updated_at.asc())
else:
query = query.order_by(Prompt.updated_at.desc())
else:
query = query.order_by(Prompt.updated_at.desc())
else:
query = query.order_by(Prompt.updated_at.desc())
# Count BEFORE pagination
total = query.count()
if skip:
query = query.offset(skip)
if limit:
query = query.limit(limit)
items = query.all()
prompts = []
for prompt, user in items:
prompts.append(
PromptUserResponse(
**PromptModel.model_validate(prompt).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
)
)
return PromptListResponse(items=prompts, total=total)
def update_prompt_by_command(
self,
command: str,
form_data: PromptForm,

View File

@@ -5,9 +5,11 @@ from open_webui.models.prompts import (
PromptForm,
PromptUserResponse,
PromptAccessResponse,
PromptAccessListResponse,
PromptModel,
Prompts,
)
from open_webui.models.groups import Groups
from open_webui.models.prompt_history import (
PromptHistories,
PromptHistoryModel,
@@ -34,6 +36,8 @@ class PromptMetadataForm(BaseModel):
router = APIRouter()
PAGE_ITEM_COUNT = 30
############################
# GetPrompts
@@ -67,26 +71,57 @@ async def get_prompt_tags(
return sorted(list(tags))
@router.get("/list", response_model=list[PromptAccessResponse])
@router.get("/list", response_model=PromptAccessListResponse)
async def get_prompt_list(
user=Depends(get_verified_user), db: Session = Depends(get_session)
query: Optional[str] = None,
view_option: Optional[str] = None,
tag: Optional[str] = None,
order_by: Optional[str] = None,
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
prompts = Prompts.get_prompts(db=db)
else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read", db=db)
limit = PAGE_ITEM_COUNT
return [
PromptAccessResponse(
**prompt.model_dump(),
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == prompt.user_id
or has_access(user.id, "write", prompt.access_control, db=db)
),
)
for prompt in prompts
]
page = max(1, page)
skip = (page - 1) * limit
filter = {}
if query:
filter["query"] = query
if view_option:
filter["view_option"] = view_option
if tag:
filter["tag"] = tag
if order_by:
filter["order_by"] = order_by
if direction:
filter["direction"] = direction
if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL):
groups = Groups.get_groups_by_member_id(user.id, db=db)
if groups:
filter["group_ids"] = [group.id for group in groups]
filter["user_id"] = user.id
result = Prompts.search_prompts(user.id, filter=filter, skip=skip, limit=limit, db=db)
return PromptAccessListResponse(
items=[
PromptAccessResponse(
**prompt.model_dump(),
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == prompt.user_id
or has_access(user.id, "write", prompt.access_control, db=db)
),
)
for prompt in result.items
],
total=result.total,
)
############################