diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py index dc91f6dc5..da3e70011 100644 --- a/backend/open_webui/models/skills.py +++ b/backend/open_webui/models/skills.py @@ -9,7 +9,7 @@ from open_webui.models.groups import Groups from open_webui.models.access_grants import AccessGrantModel, AccessGrants from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import BigInteger, Boolean, Column, String, Text +from sqlalchemy import BigInteger, Boolean, Column, String, Text, or_ log = logging.getLogger(__name__) @@ -95,6 +95,11 @@ class SkillForm(BaseModel): access_grants: Optional[list[dict]] = None +class SkillListResponse(BaseModel): + items: list[SkillAccessResponse] = [] + total: int = 0 + + class SkillsTable: def _get_access_grants( self, skill_id: str, db: Optional[Session] = None @@ -200,6 +205,88 @@ class SkillsTable: ) ] + def search_skills( + self, + user_id: str, + filter: dict, + skip: int = 0, + limit: int = 30, + db: Optional[Session] = None, + ) -> SkillListResponse: + try: + with get_db_context(db) as db: + query = db.query(Skill) + + query_key = filter.get("query") + if query_key: + query = query.filter( + or_( + Skill.name.ilike(f"%{query_key}%"), + Skill.description.ilike(f"%{query_key}%"), + Skill.id.ilike(f"%{query_key}%"), + ) + ) + + # Only active skills + query = query.filter(Skill.is_active == True) + + query = query.order_by(Skill.updated_at.desc()) + + # Apply access control if not admin bypass + if "user_id" in filter: + user_group_ids = { + group.id + for group in Groups.get_groups_by_member_id( + filter["user_id"], db=db + ) + } + all_results = query.all() + accessible = [ + s + for s in all_results + if s.user_id == filter["user_id"] + or AccessGrants.has_access( + user_id=filter["user_id"], + resource_type="skill", + resource_id=s.id, + permission="read", + user_group_ids=user_group_ids, + db=db, + ) + ] + total = len(accessible) + items = accessible[skip : skip + limit] if limit else accessible[skip:] + else: + total = query.count() + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + items = query.all() + + user_ids = list(set(s.user_id for s in items)) + users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] + users_dict = {u.id: u for u in users} + + skill_responses = [] + for skill in items: + user = users_dict.get(skill.user_id) + skill_model = self._to_skill_model(skill, db=db) + skill_responses.append( + SkillAccessResponse( + **SkillUserResponse( + **skill_model.model_dump(), + user=user.model_dump() if user else None, + ).model_dump(), + write_access=False, + ) + ) + + return SkillListResponse(items=skill_responses, total=total) + except Exception as e: + log.exception(f"Error searching skills: {e}") + return SkillListResponse(items=[], total=0) + def update_skill_by_id( self, id: str, updated: dict, db: Optional[Session] = None ) -> Optional[SkillModel]: diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py index 92e560c73..e15226766 100644 --- a/backend/open_webui/routers/skills.py +++ b/backend/open_webui/routers/skills.py @@ -14,6 +14,7 @@ from open_webui.models.skills import ( SkillResponse, SkillUserResponse, SkillAccessResponse, + SkillListResponse, Skills, ) from open_webui.models.access_grants import AccessGrants @@ -26,6 +27,7 @@ from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) +PAGE_ITEM_COUNT = 30 router = APIRouter() @@ -98,6 +100,36 @@ async def get_skill_list( ] +############################ +# SearchSkills +############################ + + +@router.get("/search", response_model=SkillListResponse) +async def search_skills( + query: Optional[str] = None, + page: Optional[int] = 1, + user=Depends(get_verified_user), + db: Session = Depends(get_session), +): + page = max(page, 1) + limit = PAGE_ITEM_COUNT + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + + if not (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL): + filter["user_id"] = user.id + + result = Skills.search_skills( + user.id, filter=filter, skip=skip, limit=limit, db=db + ) + + return result + + ############################ # ExportSkills ############################ diff --git a/src/lib/apis/skills/index.ts b/src/lib/apis/skills/index.ts index 52d14cbbd..8477b4a6f 100644 --- a/src/lib/apis/skills/index.ts +++ b/src/lib/apis/skills/index.ts @@ -93,6 +93,45 @@ export const getSkillList = async (token: string = '') => { return res; }; +export const searchSkills = async ( + token: string = '', + query: string | null = null, + page: number | null = null +) => { + let error = null; + + const searchParams = new URLSearchParams(); + if (query) searchParams.append('query', query); + if (page) searchParams.append('page', page.toString()); + + const res = await fetch(`${WEBUI_API_BASE_URL}/skills/search?${searchParams.toString()}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const exportSkills = async (token: string = '') => { let error = null; diff --git a/src/lib/components/chat/MessageInput/Commands/Skills.svelte b/src/lib/components/chat/MessageInput/Commands/Skills.svelte index 9bac71959..9454b950a 100644 --- a/src/lib/components/chat/MessageInput/Commands/Skills.svelte +++ b/src/lib/components/chat/MessageInput/Commands/Skills.svelte @@ -1,9 +1,6 @@