From 6ebf027613fb958f316254577852709e5360906d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 16 Nov 2024 18:00:49 -0800 Subject: [PATCH] refac: prompts access control --- .../open_webui/apps/webui/routers/prompts.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index fe146d49c..ec6593291 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -4,6 +4,7 @@ from open_webui.apps.webui.models.prompts import PromptForm, PromptModel, Prompt from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, status from open_webui.utils.utils import get_admin_user, get_verified_user +from open_webui.utils.access_control import has_access router = APIRouter() @@ -65,7 +66,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: - return prompt + if ( + user.role == "admin" + or prompt.user_id == user.id + or has_access(user.id, "read", prompt.access_control) + ): + return prompt else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -84,6 +90,19 @@ async def update_prompt_by_command( form_data: PromptForm, user=Depends(get_verified_user), ): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt @@ -101,5 +120,18 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): + prompt = Prompts.get_prompt_by_command(f"/{command}") + if not prompt: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if prompt.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + result = Prompts.delete_prompt_by_command(f"/{command}") return result