mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
wip: access control backend
This commit is contained in:
@@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
|
||||
permissions: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
user_ids: list[str] = []
|
||||
admin_ids: list[str] = []
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
@@ -119,6 +118,16 @@ class GroupTable:
|
||||
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group)
|
||||
.filter(Group.user_ids.contains([user_id]))
|
||||
.order_by(Group.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -4,9 +4,20 @@ from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
from open_webui.apps.webui.models.groups import Groups
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from sqlalchemy import or_, and_, func
|
||||
from sqlalchemy.dialects import postgresql, sqlite
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON
|
||||
|
||||
|
||||
from open_webui.utils.utils import has_access
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
@@ -112,8 +123,14 @@ class ModelModel(BaseModel):
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
base_model_id: Optional[str] = None
|
||||
|
||||
name: str
|
||||
params: ModelParams
|
||||
meta: ModelMeta
|
||||
|
||||
access_control: Optional[dict] = None
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
@@ -157,6 +174,24 @@ class ModelsTable:
|
||||
with get_db() as db:
|
||||
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
|
||||
|
||||
def get_models(self) -> list[ModelModel]:
|
||||
with get_db() as db:
|
||||
return [
|
||||
ModelModel.model_validate(model)
|
||||
for model in db.query(Model).filter(Model.base_model_id != None).all()
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ModelModel]:
|
||||
models = self.get_all_models()
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
if model.user_id == user_id
|
||||
or has_access(user_id, permission, model.access_control)
|
||||
]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db() as db:
|
||||
|
||||
@@ -2,6 +2,8 @@ import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.apps.webui.internal.db import Base, get_db
|
||||
from open_webui.apps.webui.models.groups import Groups
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
|
||||
@@ -100,6 +102,64 @@ class PromptsTable:
|
||||
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
|
||||
]
|
||||
|
||||
def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[PromptModel]:
|
||||
prompts = self.get_prompts()
|
||||
|
||||
groups = Groups.get_groups_by_member_id(user_id)
|
||||
group_ids = [group.id for group in groups]
|
||||
|
||||
if permission == "write":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
group_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
elif permission == "read":
|
||||
return [
|
||||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or prompt.access_control is None
|
||||
or (
|
||||
prompt.access_control
|
||||
and (
|
||||
any(
|
||||
prompt.access_control.get(permission, {}).get(
|
||||
"group_ids", []
|
||||
)
|
||||
in group_id
|
||||
for group_id in group_ids
|
||||
)
|
||||
or (
|
||||
user_id
|
||||
in prompt.access_control.get(permission, {}).get(
|
||||
"user_ids", []
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
|
||||
@@ -8,49 +8,46 @@ from open_webui.apps.webui.models.models import (
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
|
||||
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
###########################
|
||||
# getModels
|
||||
# GetModels
|
||||
###########################
|
||||
|
||||
|
||||
@router.get("/", response_model=list[ModelResponse])
|
||||
async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)):
|
||||
if id:
|
||||
model = Models.get_model_by_id(id)
|
||||
if model:
|
||||
return [model]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
if user.role == "admin":
|
||||
return Models.get_models()
|
||||
else:
|
||||
return Models.get_all_models()
|
||||
return Models.get_models_by_user_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
# AddNewModel
|
||||
# CreateNewModel
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/add", response_model=Optional[ModelModel])
|
||||
async def add_new_model(
|
||||
request: Request,
|
||||
@router.post("/create", response_model=Optional[ModelModel])
|
||||
async def create_new_model(
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_admin_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if form_data.id in request.app.state.MODELS:
|
||||
|
||||
model = Models.get_model_by_id(form_data.id)
|
||||
if model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
||||
)
|
||||
|
||||
else:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
@@ -60,37 +57,49 @@ async def add_new_model(
|
||||
)
|
||||
|
||||
|
||||
###########################
|
||||
# GetModelById
|
||||
###########################
|
||||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[ModelResponse])
|
||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
if model:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "read", model.access_control)
|
||||
):
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateModelById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/update", response_model=Optional[ModelModel])
|
||||
@router.post("/id/{id}/update", response_model=Optional[ModelModel])
|
||||
async def update_model_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_admin_user),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
model = Models.get_model_by_id(id)
|
||||
if model:
|
||||
model = Models.update_model_by_id(id, form_data)
|
||||
return model
|
||||
else:
|
||||
if form_data.id in request.app.state.MODELS:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
model = Models.update_model_by_id(id, form_data)
|
||||
return model
|
||||
|
||||
|
||||
############################
|
||||
@@ -98,7 +107,20 @@ async def update_model_by_id(
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/delete", response_model=bool)
|
||||
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if model.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
result = Models.delete_model_by_id(id)
|
||||
return result
|
||||
|
||||
@@ -36,16 +36,34 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/permissions/user")
|
||||
class WorkspacePermissions(BaseModel):
|
||||
models: bool
|
||||
knowledge: bool
|
||||
prompts: bool
|
||||
tools: bool
|
||||
|
||||
|
||||
class ChatPermissions(BaseModel):
|
||||
delete: bool
|
||||
edit: bool
|
||||
temporary: bool
|
||||
|
||||
|
||||
class UserPermissions(BaseModel):
|
||||
workspace: WorkspacePermissions
|
||||
chat: ChatPermissions
|
||||
|
||||
|
||||
@router.get("/permissions")
|
||||
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
|
||||
return request.app.state.config.USER_PERMISSIONS
|
||||
|
||||
|
||||
@router.post("/permissions/user")
|
||||
@router.post("/permissions")
|
||||
async def update_user_permissions(
|
||||
request: Request, form_data: dict, user=Depends(get_admin_user)
|
||||
request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.USER_PERMISSIONS = form_data
|
||||
request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
|
||||
return request.app.state.config.USER_PERMISSIONS
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user