enh: user permissions

This commit is contained in:
Timothy Jaeryang Baek 2024-11-16 21:07:56 -08:00
parent 79fbab7341
commit c0371f6525
3 changed files with 129 additions and 31 deletions

View File

@ -110,6 +110,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS app.state.config.BANNERS = WEBUI_BANNERS

View File

@ -40,10 +40,12 @@ from open_webui.utils.utils import (
get_password_hash, get_password_hash,
) )
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from open_webui.utils.access_control import get_permissions
from typing import Optional, List from typing import Optional, List
from ldap3 import Server, Connection, ALL, Tls
from ssl import CERT_REQUIRED, PROTOCOL_TLS from ssl import CERT_REQUIRED, PROTOCOL_TLS
from ldap3 import Server, Connection, ALL, Tls
from ldap3.utils.conv import escape_filter_chars from ldap3.utils.conv import escape_filter_chars
router = APIRouter() router = APIRouter()
@ -58,6 +60,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
class SessionUserResponse(Token, UserResponse): class SessionUserResponse(Token, UserResponse):
expires_at: Optional[int] = None expires_at: Optional[int] = None
permissions: Optional[dict] = None
@router.get("/", response_model=SessionUserResponse) @router.get("/", response_model=SessionUserResponse)
@ -90,6 +93,10 @@ async def get_session_user(
secure=WEBUI_SESSION_COOKIE_SECURE, secure=WEBUI_SESSION_COOKIE_SECURE,
) )
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -99,6 +106,7 @@ async def get_session_user(
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
"permissions": user_permissions,
} }
@ -163,40 +171,67 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL' LDAP_CIPHERS = (
request.app.state.config.LDAP_CIPHERS
if request.app.state.config.LDAP_CIPHERS
else "ALL"
)
if not ENABLE_LDAP: if not ENABLE_LDAP:
raise HTTPException(400, detail="LDAP authentication is not enabled") raise HTTPException(400, detail="LDAP authentication is not enabled")
try: try:
tls = Tls(validate=CERT_REQUIRED, version=PROTOCOL_TLS, ca_certs_file=LDAP_CA_CERT_FILE, ciphers=LDAP_CIPHERS) tls = Tls(
validate=CERT_REQUIRED,
version=PROTOCOL_TLS,
ca_certs_file=LDAP_CA_CERT_FILE,
ciphers=LDAP_CIPHERS,
)
except Exception as e: except Exception as e:
log.error(f"An error occurred on TLS: {str(e)}") log.error(f"An error occurred on TLS: {str(e)}")
raise HTTPException(400, detail=str(e)) raise HTTPException(400, detail=str(e))
try: try:
server = Server(host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, get_info=ALL, use_ssl=LDAP_USE_TLS, tls=tls) server = Server(
connection_app = Connection(server, LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind='NONE', authentication='SIMPLE') host=LDAP_SERVER_HOST,
port=LDAP_SERVER_PORT,
get_info=ALL,
use_ssl=LDAP_USE_TLS,
tls=tls,
)
connection_app = Connection(
server,
LDAP_APP_DN,
LDAP_APP_PASSWORD,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_app.bind(): if not connection_app.bind():
raise HTTPException(400, detail="Application account bind failed") raise HTTPException(400, detail="Application account bind failed")
search_success = connection_app.search( search_success = connection_app.search(
search_base=LDAP_SEARCH_BASE, search_base=LDAP_SEARCH_BASE,
search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})', search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
attributes=[f'{LDAP_ATTRIBUTE_FOR_USERNAME}', 'mail', 'cn'] attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
) )
if not search_success: if not search_success:
raise HTTPException(400, detail="User not found in the LDAP server") raise HTTPException(400, detail="User not found in the LDAP server")
entry = connection_app.entries[0] entry = connection_app.entries[0]
username = str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}']).lower() username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
mail = str(entry['mail']) mail = str(entry["mail"])
cn = str(entry['cn']) cn = str(entry["cn"])
user_dn = entry.entry_dn user_dn = entry.entry_dn
if username == form_data.user.lower(): if username == form_data.user.lower():
connection_user = Connection(server, user_dn, form_data.password, auto_bind='NONE', authentication='SIMPLE') connection_user = Connection(
server,
user_dn,
form_data.password,
auto_bind="NONE",
authentication="SIMPLE",
)
if not connection_user.bind(): if not connection_user.bind():
raise HTTPException(400, f"Authentication failed for {form_data.user}") raise HTTPException(400, f"Authentication failed for {form_data.user}")
@ -205,14 +240,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
try: try:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(mail, hashed, cn)
mail,
hashed,
cn
)
if not user: if not user:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
)
except HTTPException: except HTTPException:
raise raise
@ -224,7 +257,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), expires_delta=parse_duration(
request.app.state.config.JWT_EXPIRES_IN
),
) )
# Set the cookie token # Set the cookie token
@ -246,7 +281,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
else: else:
raise HTTPException(400, f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}") raise HTTPException(
400,
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
)
except Exception as e: except Exception as e:
raise HTTPException(400, detail=str(e)) raise HTTPException(400, detail=str(e))
@ -325,6 +363,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
secure=WEBUI_SESSION_COOKIE_SECURE, secure=WEBUI_SESSION_COOKIE_SECURE,
) )
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -334,6 +376,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
"permissions": user_permissions,
} }
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
@ -426,6 +469,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
}, },
) )
user_permissions = get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
return { return {
"token": token, "token": token,
"token_type": "Bearer", "token_type": "Bearer",
@ -435,6 +482,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
"permissions": user_permissions,
} }
else: else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
@ -583,19 +631,18 @@ class LdapServerConfig(BaseModel):
label: str label: str
host: str host: str
port: Optional[int] = None port: Optional[int] = None
attribute_for_username: str = 'uid' attribute_for_username: str = "uid"
app_dn: str app_dn: str
app_dn_password: str app_dn_password: str
search_base: str search_base: str
search_filters: str = '' search_filters: str = ""
use_tls: bool = True use_tls: bool = True
certificate_path: Optional[str] = None certificate_path: Optional[str] = None
ciphers: Optional[str] = 'ALL' ciphers: Optional[str] = "ALL"
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig) @router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
async def get_ldap_server( async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
request: Request, user=Depends(get_admin_user)
):
return { return {
"label": request.app.state.config.LDAP_SERVER_LABEL, "label": request.app.state.config.LDAP_SERVER_LABEL,
"host": request.app.state.config.LDAP_SERVER_HOST, "host": request.app.state.config.LDAP_SERVER_HOST,
@ -607,26 +654,38 @@ async def get_ldap_server(
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS, "use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS "ciphers": request.app.state.config.LDAP_CIPHERS,
} }
@router.post("/admin/config/ldap/server") @router.post("/admin/config/ldap/server")
async def update_ldap_server( async def update_ldap_server(
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
): ):
required_fields = ['label', 'host', 'attribute_for_username', 'app_dn', 'app_dn_password', 'search_base'] required_fields = [
"label",
"host",
"attribute_for_username",
"app_dn",
"app_dn_password",
"search_base",
]
for key in required_fields: for key in required_fields:
value = getattr(form_data, key) value = getattr(form_data, key)
if not value: if not value:
raise HTTPException(400, detail=f"Required field {key} is empty") raise HTTPException(400, detail=f"Required field {key} is empty")
if form_data.use_tls and not form_data.certificate_path: if form_data.use_tls and not form_data.certificate_path:
raise HTTPException(400, detail="TLS is enabled but certificate file path is missing") raise HTTPException(
400, detail="TLS is enabled but certificate file path is missing"
)
request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_LABEL = form_data.label
request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_HOST = form_data.host
request.app.state.config.LDAP_SERVER_PORT = form_data.port request.app.state.config.LDAP_SERVER_PORT = form_data.port
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
form_data.attribute_for_username
)
request.app.state.config.LDAP_APP_DN = form_data.app_dn request.app.state.config.LDAP_APP_DN = form_data.app_dn
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
@ -646,18 +705,23 @@ async def update_ldap_server(
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
"use_tls": request.app.state.config.LDAP_USE_TLS, "use_tls": request.app.state.config.LDAP_USE_TLS,
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
"ciphers": request.app.state.config.LDAP_CIPHERS "ciphers": request.app.state.config.LDAP_CIPHERS,
} }
@router.get("/admin/config/ldap") @router.get("/admin/config/ldap")
async def get_ldap_config(request: Request, user=Depends(get_admin_user)): async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
class LdapConfigForm(BaseModel): class LdapConfigForm(BaseModel):
enable_ldap: Optional[bool] = None enable_ldap: Optional[bool] = None
@router.post("/admin/config/ldap") @router.post("/admin/config/ldap")
async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)): async def update_ldap_config(
request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}

View File

@ -2,6 +2,38 @@ from typing import Optional, Union, List, Dict
from open_webui.apps.webui.models.groups import Groups from open_webui.apps.webui.models.groups import Groups
def get_permissions(
user_id: str,
default_permissions: Dict[str, bool] = {},
) -> dict:
"""
Get all permissions for a user by combining the permissions of all groups the user is a member of.
If a permission is defined in multiple groups, the most permissive value is used.
"""
def merge_permissions(
permissions: Dict[str, bool], new_permissions: Dict[str, bool]
) -> Dict[str, bool]:
"""Merge two permission dictionaries, keeping the most permissive value."""
for key, value in new_permissions.items():
if key not in permissions:
permissions[key] = value
else:
permissions[key] = (
permissions[key] or value
) # Use the most permissive value
return permissions
user_groups = Groups.get_groups_by_member_id(user_id)
user_permissions = default_permissions.copy()
for group in user_groups:
user_permissions = merge_permissions(user_permissions, group.permissions)
return user_permissions
def has_permission( def has_permission(
user_id: str, user_id: str,
permission_key: str, permission_key: str,