From 73fe77c2dadba5aaed8025f51f57ddeb7caeec1e Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 16 Nov 2024 17:09:15 -0800 Subject: [PATCH] enh: access control --- .../open_webui/apps/webui/models/prompts.py | 60 +----- backend/open_webui/apps/webui/models/tools.py | 15 ++ .../open_webui/apps/webui/routers/prompts.py | 23 +- .../open_webui/apps/webui/routers/tools.py | 71 +++--- src/lib/apis/prompts/index.ts | 33 +++ src/lib/components/workspace/Prompts.svelte | 204 +++++++++--------- src/lib/components/workspace/Tools.svelte | 145 +++++++------ .../(app)/workspace/prompts/+page.svelte | 16 +- src/routes/(app)/workspace/tools/+page.svelte | 14 +- 9 files changed, 304 insertions(+), 277 deletions(-) diff --git a/backend/open_webui/apps/webui/models/prompts.py b/backend/open_webui/apps/webui/models/prompts.py index 7f96d5374..4b9953674 100644 --- a/backend/open_webui/apps/webui/models/prompts.py +++ b/backend/open_webui/apps/webui/models/prompts.py @@ -7,6 +7,8 @@ from open_webui.apps.webui.models.groups import Groups from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from open_webui.utils.access_control import has_access + #################### # Prompts DB Schema #################### @@ -107,58 +109,12 @@ class PromptsTable: ) -> list[PromptModel]: prompts = self.get_prompts() - groups = Groups.get_groups_by_member_id(user_id) - group_ids = [group.id for group in groups] - - if permission == "write": - return [ - prompt - for prompt in prompts - if prompt.user_id == user_id - or ( - prompt.access_control - and ( - any( - group_id - in prompt.access_control.get(permission, {}).get( - "group_ids", [] - ) - for group_id in group_ids - ) - or ( - user_id - in prompt.access_control.get(permission, {}).get( - "user_ids", [] - ) - ) - ) - ) - ] - elif permission == "read": - return [ - prompt - for prompt in prompts - if prompt.user_id == user_id - or prompt.access_control is None - or ( - prompt.access_control - and ( - any( - prompt.access_control.get(permission, {}).get( - "group_ids", [] - ) - in group_id - for group_id in group_ids - ) - or ( - user_id - in prompt.access_control.get(permission, {}).get( - "user_ids", [] - ) - ) - ) - ) - ] + return [ + prompt + for prompt in prompts + if prompt.user_id == user_id + or has_access(user_id, permission, prompt.access_control) + ] def update_prompt_by_command( self, command: str, form_data: PromptForm diff --git a/backend/open_webui/apps/webui/models/tools.py b/backend/open_webui/apps/webui/models/tools.py index 1c7089348..ee2de91e7 100644 --- a/backend/open_webui/apps/webui/models/tools.py +++ b/backend/open_webui/apps/webui/models/tools.py @@ -8,6 +8,9 @@ from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON +from open_webui.utils.access_control import has_access + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -133,6 +136,18 @@ class ToolsTable: with get_db() as db: return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + def get_tools_by_user_id( + self, user_id: str, permission: str = "write" + ) -> list[ToolModel]: + tools = self.get_tools() + + return [ + tool + for tool in tools + if tool.user_id == user_id + or has_access(tool.access_control, user_id, permission) + ] + def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: with get_db() as db: diff --git a/backend/open_webui/apps/webui/routers/prompts.py b/backend/open_webui/apps/webui/routers/prompts.py index 593c643b9..fe146d49c 100644 --- a/backend/open_webui/apps/webui/routers/prompts.py +++ b/backend/open_webui/apps/webui/routers/prompts.py @@ -14,7 +14,22 @@ router = APIRouter() @router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): - return Prompts.get_prompts() + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "read") + + return prompts + + +@router.get("/list", response_model=list[PromptModel]) +async def get_prompt_list(user=Depends(get_verified_user)): + if user.role == "admin": + prompts = Prompts.get_prompts() + else: + prompts = Prompts.get_prompts_by_user_id(user.id, "write") + + return prompts ############################ @@ -23,7 +38,7 @@ async def get_prompts(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): +async def create_new_prompt(form_data: PromptForm, user=Depends(get_verified_user)): prompt = Prompts.get_prompt_by_command(form_data.command) if prompt is None: prompt = Prompts.insert_new_prompt(user.id, form_data) @@ -67,7 +82,7 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): async def update_prompt_by_command( command: str, form_data: PromptForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: @@ -85,6 +100,6 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): +async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/open_webui/apps/webui/routers/tools.py b/backend/open_webui/apps/webui/routers/tools.py index d1ad89dea..2db982d88 100644 --- a/backend/open_webui/apps/webui/routers/tools.py +++ b/backend/open_webui/apps/webui/routers/tools.py @@ -14,37 +14,54 @@ from open_webui.utils.utils import get_admin_user, get_verified_user router = APIRouter() ############################ -# GetToolkits +# GetTools ############################ @router.get("/", response_model=list[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def get_tools(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "read") + return tools ############################ -# ExportToolKits +# GetToolList +############################ + + +@router.get("/list", response_model=list[ToolResponse]) +async def get_tool_list(user=Depends(get_verified_user)): + if user.role == "admin": + tools = Tools.get_tools() + else: + tools = Tools.get_tools_by_user_id(user.id, "write") + return tools + + +############################ +# ExportTools ############################ @router.get("/export", response_model=list[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] - return toolkits +async def export_tools(user=Depends(get_admin_user)): + tools = Tools.get_tools() + return tools ############################ -# CreateNewToolKit +# CreateNewTools ############################ @router.post("/create", response_model=Optional[ToolResponse]) -async def create_new_toolkit( +async def create_new_tools( request: Request, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -93,12 +110,12 @@ async def create_new_toolkit( ############################ -# GetToolkitById +# GetToolsById ############################ @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): +async def get_tools_by_id(id: str, user=Depends(get_verified_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: @@ -111,16 +128,16 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ############################ -# UpdateToolkitById +# UpdateToolsById ############################ @router.post("/id/{id}/update", response_model=Optional[ToolModel]) -async def update_toolkit_by_id( +async def update_tools_by_id( request: Request, id: str, form_data: ToolForm, - user=Depends(get_admin_user), + user=Depends(get_verified_user), ): try: form_data.content = replace_imports(form_data.content) @@ -158,12 +175,14 @@ async def update_toolkit_by_id( ############################ -# DeleteToolkitById +# DeleteToolsById ############################ @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): +async def delete_tools_by_id( + request: Request, id: str, user=Depends(get_verified_user) +): result = Tools.delete_tool_by_id(id) if result: @@ -180,7 +199,7 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin @router.get("/id/{id}/valves", response_model=Optional[dict]) -async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): +async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: try: @@ -204,8 +223,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)): @router.get("/id/{id}/valves/spec", response_model=Optional[dict]) -async def get_toolkit_valves_spec_by_id( - request: Request, id: str, user=Depends(get_admin_user) +async def get_tools_valves_spec_by_id( + request: Request, id: str, user=Depends(get_verified_user) ): toolkit = Tools.get_tool_by_id(id) if toolkit: @@ -232,8 +251,8 @@ async def get_toolkit_valves_spec_by_id( @router.post("/id/{id}/valves/update", response_model=Optional[dict]) -async def update_toolkit_valves_by_id( - request: Request, id: str, form_data: dict, user=Depends(get_admin_user) +async def update_tools_valves_by_id( + request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): toolkit = Tools.get_tool_by_id(id) if toolkit: @@ -276,7 +295,7 @@ async def update_toolkit_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) -async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)): +async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): toolkit = Tools.get_tool_by_id(id) if toolkit: try: @@ -295,7 +314,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user) @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict]) -async def get_toolkit_user_valves_spec_by_id( +async def get_tools_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): toolkit = Tools.get_tool_by_id(id) @@ -318,7 +337,7 @@ async def get_toolkit_user_valves_spec_by_id( @router.post("/id/{id}/valves/user/update", response_model=Optional[dict]) -async def update_toolkit_user_valves_by_id( +async def update_tools_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): toolkit = Tools.get_tool_by_id(id) diff --git a/src/lib/apis/prompts/index.ts b/src/lib/apis/prompts/index.ts index ca9c7d543..e762d9230 100644 --- a/src/lib/apis/prompts/index.ts +++ b/src/lib/apis/prompts/index.ts @@ -69,6 +69,39 @@ export const getPrompts = async (token: string = '') => { return res; }; + +export const getPromptList = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/prompts/list`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + export const getPromptByCommand = async (token: string, command: string) => { let error = null; diff --git a/src/lib/components/workspace/Prompts.svelte b/src/lib/components/workspace/Prompts.svelte index 1fed0b7b8..f7bc08f76 100644 --- a/src/lib/components/workspace/Prompts.svelte +++ b/src/lib/components/workspace/Prompts.svelte @@ -3,11 +3,17 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; - import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, config, prompts } from '$lib/stores'; - import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; - import { error } from '@sveltejs/kit'; import { goto } from '$app/navigation'; + import { onMount, getContext } from 'svelte'; + import { WEBUI_NAME, config, prompts as _prompts, user } from '$lib/stores'; + + import { + createNewPrompt, + deletePromptByCommand, + getPrompts, + getPromptList + } from '$lib/apis/prompts'; + import PromptMenu from './Prompts/PromptMenu.svelte'; import EllipsisHorizontal from '../icons/EllipsisHorizontal.svelte'; import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; @@ -16,16 +22,18 @@ import ChevronRight from '../icons/ChevronRight.svelte'; const i18n = getContext('i18n'); + let promptsImportInputElement: HTMLInputElement; let importFiles = ''; let query = ''; - let promptsImportInputElement: HTMLInputElement; + + let prompts = []; let showDeleteConfirm = false; let deletePrompt = null; let filteredItems = []; - $: filteredItems = $prompts.filter((p) => query === '' || p.command.includes(query)); + $: filteredItems = prompts.filter((p) => query === '' || p.command.includes(query)); const shareHandler = async (prompt) => { toast.success($i18n.t('Redirecting you to OpenWebUI Community')); @@ -60,8 +68,17 @@ const deleteHandler = async (prompt) => { const command = prompt.command; await deletePromptByCommand(localStorage.token, command); - await prompts.set(await getPrompts(localStorage.token)); + await init(); }; + + const init = async () => { + prompts = await getPromptList(localStorage.token); + await _prompts.set(await getPrompts(localStorage.token)); + }; + + onMount(async () => { + await init(); + }); @@ -181,103 +198,98 @@ {/each} -
-
- { - console.log(importFiles); +{#if $user?.role === 'admin'} +
+
+ { + console.log(importFiles); - const reader = new FileReader(); - reader.onload = async (event) => { - const savedPrompts = JSON.parse(event.target.result); - console.log(savedPrompts); + const reader = new FileReader(); + reader.onload = async (event) => { + const savedPrompts = JSON.parse(event.target.result); + console.log(savedPrompts); - for (const prompt of savedPrompts) { - await createNewPrompt( - localStorage.token, - prompt.command.charAt(0) === '/' ? prompt.command.slice(1) : prompt.command, - prompt.title, - prompt.content - ).catch((error) => { - toast.error(error); - return null; - }); - } + for (const prompt of savedPrompts) { + await createNewPrompt( + localStorage.token, + prompt.command.charAt(0) === '/' ? prompt.command.slice(1) : prompt.command, + prompt.title, + prompt.content + ).catch((error) => { + toast.error(error); + return null; + }); + } - await prompts.set(await getPrompts(localStorage.token)); - }; + prompts = await getPromptList(localStorage.token); + await _prompts.set(await getPrompts(localStorage.token)); + }; - reader.readAsText(importFiles[0]); - }} - /> + reader.readAsText(importFiles[0]); + }} + /> - - - - - + + +
+ + + +
-
+{/if} {#if $config?.features.enable_community_sharing}
diff --git a/src/lib/components/workspace/Tools.svelte b/src/lib/components/workspace/Tools.svelte index 193addaa1..674911a53 100644 --- a/src/lib/components/workspace/Tools.svelte +++ b/src/lib/components/workspace/Tools.svelte @@ -4,7 +4,7 @@ const { saveAs } = fileSaver; import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, config, prompts, tools } from '$lib/stores'; + import { WEBUI_NAME, config, prompts, tools as _tools, user } from '$lib/stores'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; @@ -45,8 +45,9 @@ let showDeleteConfirm = false; + let tools = []; let filteredItems = []; - $: filteredItems = $tools.filter( + $: filteredItems = tools.filter( (t) => query === '' || t.name.toLowerCase().includes(query.toLowerCase()) || @@ -118,7 +119,7 @@ if (res) { toast.success($i18n.t('Tool deleted successfully')); - tools.set(await getTools(localStorage.token)); + _tools.set(await getTools(localStorage.token)); } }; @@ -324,80 +325,82 @@ {/each}
-
-
- { - console.log(importFiles); - showConfirm = true; - }} - /> +{#if $user?.role === 'admin'} +
+
+ { + console.log(importFiles); + showConfirm = true; + }} + /> - +
+ + + +
+ - + if (_tools) { + let blob = new Blob([JSON.stringify(_tools)], { + type: 'application/json' + }); + saveAs(blob, `tools-export-${Date.now()}.json`); + } + }} + > +
{$i18n.t('Export Tools')}
+ +
+ + + +
+ +
-
+{/if} {#if $config?.features.enable_community_sharing}
diff --git a/src/routes/(app)/workspace/prompts/+page.svelte b/src/routes/(app)/workspace/prompts/+page.svelte index fbe9918b7..48c6e65c6 100644 --- a/src/routes/(app)/workspace/prompts/+page.svelte +++ b/src/routes/(app)/workspace/prompts/+page.svelte @@ -1,19 +1,5 @@ -{#if $prompts !== null} - -{/if} + diff --git a/src/routes/(app)/workspace/tools/+page.svelte b/src/routes/(app)/workspace/tools/+page.svelte index d87cac6c8..86b1b2b7c 100644 --- a/src/routes/(app)/workspace/tools/+page.svelte +++ b/src/routes/(app)/workspace/tools/+page.svelte @@ -1,19 +1,7 @@ -{#if $tools !== null} - -{/if} +