wip: access control backend

This commit is contained in:
Timothy Jaeryang Baek 2024-11-15 01:29:07 -08:00
parent b80ec76435
commit 2ab5b2fd71
8 changed files with 282 additions and 52 deletions

View File

@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
admin_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@ -119,6 +118,16 @@ class GroupTable:
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(Group.user_ids.contains([user_id]))
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:

View File

@ -4,9 +4,20 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON
from open_webui.utils.utils import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -112,8 +123,14 @@ class ModelModel(BaseModel):
class ModelResponse(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@ -157,6 +174,24 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id != None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelModel]:
models = self.get_all_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:

View File

@ -2,6 +2,8 @@ import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@ -100,6 +102,64 @@ class PromptsTable:
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptModel]:
prompts = self.get_prompts()
groups = Groups.get_groups_by_member_id(user_id)
group_ids = [group.id for group in groups]
if permission == "write":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or (
prompt.access_control
and (
any(
group_id
in prompt.access_control.get(permission, {}).get(
"group_ids", []
)
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
elif permission == "read":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or prompt.access_control is None
or (
prompt.access_control
and (
any(
prompt.access_control.get(permission, {}).get(
"group_ids", []
)
in group_id
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:

View File

@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import (
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
router = APIRouter()
###########################
# getModels
# GetModels
###########################
@router.get("/", response_model=list[ModelResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if id:
model = Models.get_model_by_id(id)
if model:
return [model]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if user.role == "admin":
return Models.get_models()
else:
return Models.get_all_models()
return Models.get_models_by_user_id(user.id)
############################
# AddNewModel
# CreateNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request,
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
if form_data.id in request.app.state.MODELS:
model = Models.get_model_by_id(form_data.id)
if model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
@ -60,37 +57,49 @@ async def add_new_model(
)
###########################
# GetModelById
###########################
@router.get("/id/{id}", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "read", model.access_control)
):
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelById
############################
@router.post("/update", response_model=Optional[ModelModel])
@router.post("/id/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
model = Models.update_model_by_id(id, form_data)
return model
############################
@ -98,7 +107,20 @@ async def update_model_by_id(
############################
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if model.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Models.delete_model_by_id(id)
return result

View File

@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
############################
@router.get("/permissions/user")
class WorkspacePermissions(BaseModel):
models: bool
knowledge: bool
prompts: bool
tools: bool
class ChatPermissions(BaseModel):
delete: bool
edit: bool
temporary: bool
class UserPermissions(BaseModel):
workspace: WorkspacePermissions
chat: ChatPermissions
@router.get("/permissions")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user")
@router.post("/permissions")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
return request.app.state.config.USER_PERMISSIONS

View File

@ -739,6 +739,26 @@ DEFAULT_USER_ROLE = PersistentConfig(
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS", "False").lower()
== "true"
)
USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS = (
os.environ.get("USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS", "False").lower() == "true"
)
USER_PERMISSIONS_CHAT_DELETE = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true"
)
@ -755,11 +775,17 @@ USER_PERMISSIONS = PersistentConfig(
"USER_PERMISSIONS",
"user.permissions",
{
"workspace": {
"models": USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS,
"knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ACCESS,
"prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ACCESS,
"tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ACCESS,
},
"chat": {
"deletion": USER_PERMISSIONS_CHAT_DELETE,
"editing": USER_PERMISSIONS_CHAT_EDIT,
"temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
}
},
},
)

View File

@ -993,7 +993,7 @@ async def get_all_models():
models.append(
{
"id": custom_model.id,
"id": f"open-webui-{custom_model.id}",
"name": custom_model.name,
"object": "model",
"created": custom_model.created_at,

View File

@ -1,12 +1,18 @@
import logging
import uuid
from datetime import UTC, datetime, timedelta
from typing import Optional, Union
import jwt
from datetime import UTC, datetime, timedelta
from typing import Optional, Union, List, Dict
from open_webui.apps.webui.models.users import Users
from open_webui.apps.webui.models.groups import Groups
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SECRET_KEY
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from passlib.context import CryptContext
@ -147,3 +153,57 @@ def get_admin_user(user=Depends(get_current_user)):
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return user
def has_permission(
user_id: str,
permission_key: str,
default_permissions: Dict[str, bool] = {},
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
and falls back to default permissions if not found in any group.
Permission keys can be hierarchical and separated by dots ('.').
"""
def get_permission(permissions: Dict[str, bool], keys: List[str]) -> bool:
"""Traverse permissions dict using a list of keys (from dot-split permission_key)."""
for key in keys:
if key not in permissions:
return False # If any part of the hierarchy is missing, deny access
permissions = permissions[key] # Go one level deeper
return bool(permissions) # Return the boolean at the final level
permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id)
for group in user_groups:
group_permissions = group.permissions
if get_permission(group_permissions, permission_hierarchy):
return True
# Check default permissions afterwards if the group permissions don't allow it
return get_permission(default_permissions, permission_hierarchy)
def has_access(
user_id: str,
action: str = "write",
access_control: Optional[dict] = None,
) -> bool:
if access_control is None:
return action == "read"
user_groups = Groups.get_groups_by_member_id(user_id)
user_group_ids = [group.id for group in user_groups]
permission_access = access_control.get(action, {})
permitted_group_ids = permission_access.get("group_ids", [])
permitted_user_ids = permission_access.get("user_ids", [])
return user_id in permitted_user_ids or any(
group_id in permitted_group_ids for group_id in user_group_ids
)