wip: access control backend

This commit is contained in:
Timothy Jaeryang Baek
2024-11-15 01:29:07 -08:00
parent b80ec76435
commit 2ab5b2fd71
8 changed files with 282 additions and 52 deletions

View File

@@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
admin_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -119,6 +118,16 @@ class GroupTable:
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(Group.user_ids.contains([user_id]))
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:

View File

@@ -4,9 +4,20 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON
from open_webui.utils.utils import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -112,8 +123,14 @@ class ModelModel(BaseModel):
class ModelResponse(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -157,6 +174,24 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id != None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelModel]:
models = self.get_all_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:

View File

@@ -2,6 +2,8 @@ import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -100,6 +102,64 @@ class PromptsTable:
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptModel]:
prompts = self.get_prompts()
groups = Groups.get_groups_by_member_id(user_id)
group_ids = [group.id for group in groups]
if permission == "write":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or (
prompt.access_control
and (
any(
group_id
in prompt.access_control.get(permission, {}).get(
"group_ids", []
)
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
elif permission == "read":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or prompt.access_control is None
or (
prompt.access_control
and (
any(
prompt.access_control.get(permission, {}).get(
"group_ids", []
)
in group_id
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:

View File

@@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import (
)
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
from open_webui.utils.utils import get_admin_user, get_verified_user
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
router = APIRouter()
###########################
# getModels
# GetModels
###########################
@router.get("/", response_model=list[ModelResponse])
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
if id:
model = Models.get_model_by_id(id)
if model:
return [model]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if user.role == "admin":
return Models.get_models()
else:
return Models.get_all_models()
return Models.get_models_by_user_id(user.id)
############################
# AddNewModel
# CreateNewModel
############################
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request,
@router.post("/create", response_model=Optional[ModelModel])
async def create_new_model(
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
if form_data.id in request.app.state.MODELS:
model = Models.get_model_by_id(form_data.id)
if model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
@@ -60,37 +57,49 @@ async def add_new_model(
)
###########################
# GetModelById
###########################
@router.get("/id/{id}", response_model=Optional[ModelResponse])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "read", model.access_control)
):
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelById
############################
@router.post("/update", response_model=Optional[ModelModel])
@router.post("/id/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
user=Depends(get_verified_user),
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
model = Models.update_model_by_id(id, form_data)
return model
############################
@@ -98,7 +107,20 @@ async def update_model_by_id(
############################
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
if model.user_id != user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Models.delete_model_by_id(id)
return result

View File

@@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
############################
@router.get("/permissions/user")
class WorkspacePermissions(BaseModel):
models: bool
knowledge: bool
prompts: bool
tools: bool
class ChatPermissions(BaseModel):
delete: bool
edit: bool
temporary: bool
class UserPermissions(BaseModel):
workspace: WorkspacePermissions
chat: ChatPermissions
@router.get("/permissions")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user")
@router.post("/permissions")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
):
request.app.state.config.USER_PERMISSIONS = form_data
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
return request.app.state.config.USER_PERMISSIONS