mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
enh: access control
This commit is contained in:
@@ -7,6 +7,8 @@ from open_webui.apps.webui.models.groups import Groups
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
####################
|
||||
@@ -107,58 +109,12 @@ class PromptsTable:
|
||||
) -> list[PromptModel]:
|
||||
prompts = self.get_prompts()
|
||||
|
||||
groups = Groups.get_groups_by_member_id(user_id)
|
||||
group_ids = [group.id for group in groups]
|
||||
|
||||
if permission == "write":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
group_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
elif permission == "read":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or prompt.access_control is None
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
in group_id
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or has_access(user_id, permission, prompt.access_control)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
|
||||
@@ -8,6 +8,9 @@ from open_webui.env import SRC_LOG_LEVELS
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
@@ -133,6 +136,18 @@ class ToolsTable:
|
||||
with get_db() as db:
|
||||
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
|
||||
|
||||
def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ToolModel]:
|
||||
tools = self.get_tools()
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user_id
|
||||
or has_access(tool.access_control, user_id, permission)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
Reference in New Issue
Block a user