mirror of
https://github.com/open-webui/open-webui
synced 2024-11-22 08:07:55 +00:00
enh: user permissions
This commit is contained in:
parent
79fbab7341
commit
c0371f6525
@ -110,6 +110,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
|
|||||||
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
|
||||||
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||||
|
|
||||||
|
|
||||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
app.state.config.BANNERS = WEBUI_BANNERS
|
app.state.config.BANNERS = WEBUI_BANNERS
|
||||||
|
@ -40,10 +40,12 @@ from open_webui.utils.utils import (
|
|||||||
get_password_hash,
|
get_password_hash,
|
||||||
)
|
)
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
|
from open_webui.utils.access_control import get_permissions
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from ldap3 import Server, Connection, ALL, Tls
|
|
||||||
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
from ssl import CERT_REQUIRED, PROTOCOL_TLS
|
||||||
|
from ldap3 import Server, Connection, ALL, Tls
|
||||||
from ldap3.utils.conv import escape_filter_chars
|
from ldap3.utils.conv import escape_filter_chars
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -58,6 +60,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||||||
|
|
||||||
class SessionUserResponse(Token, UserResponse):
|
class SessionUserResponse(Token, UserResponse):
|
||||||
expires_at: Optional[int] = None
|
expires_at: Optional[int] = None
|
||||||
|
permissions: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=SessionUserResponse)
|
@router.get("/", response_model=SessionUserResponse)
|
||||||
@ -90,6 +93,10 @@ async def get_session_user(
|
|||||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_permissions = get_permissions(
|
||||||
|
user.id, request.app.state.config.USER_PERMISSIONS
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
@ -99,6 +106,7 @@ async def get_session_user(
|
|||||||
"name": user.name,
|
"name": user.name,
|
||||||
"role": user.role,
|
"role": user.role,
|
||||||
"profile_image_url": user.profile_image_url,
|
"profile_image_url": user.profile_image_url,
|
||||||
|
"permissions": user_permissions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -163,40 +171,67 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||||||
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
|
LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD
|
||||||
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
|
LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS
|
||||||
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
|
LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE
|
||||||
LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL'
|
LDAP_CIPHERS = (
|
||||||
|
request.app.state.config.LDAP_CIPHERS
|
||||||
|
if request.app.state.config.LDAP_CIPHERS
|
||||||
|
else "ALL"
|
||||||
|
)
|
||||||
|
|
||||||
if not ENABLE_LDAP:
|
if not ENABLE_LDAP:
|
||||||
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
raise HTTPException(400, detail="LDAP authentication is not enabled")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tls = Tls(validate=CERT_REQUIRED, version=PROTOCOL_TLS, ca_certs_file=LDAP_CA_CERT_FILE, ciphers=LDAP_CIPHERS)
|
tls = Tls(
|
||||||
|
validate=CERT_REQUIRED,
|
||||||
|
version=PROTOCOL_TLS,
|
||||||
|
ca_certs_file=LDAP_CA_CERT_FILE,
|
||||||
|
ciphers=LDAP_CIPHERS,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"An error occurred on TLS: {str(e)}")
|
log.error(f"An error occurred on TLS: {str(e)}")
|
||||||
raise HTTPException(400, detail=str(e))
|
raise HTTPException(400, detail=str(e))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server = Server(host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, get_info=ALL, use_ssl=LDAP_USE_TLS, tls=tls)
|
server = Server(
|
||||||
connection_app = Connection(server, LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind='NONE', authentication='SIMPLE')
|
host=LDAP_SERVER_HOST,
|
||||||
|
port=LDAP_SERVER_PORT,
|
||||||
|
get_info=ALL,
|
||||||
|
use_ssl=LDAP_USE_TLS,
|
||||||
|
tls=tls,
|
||||||
|
)
|
||||||
|
connection_app = Connection(
|
||||||
|
server,
|
||||||
|
LDAP_APP_DN,
|
||||||
|
LDAP_APP_PASSWORD,
|
||||||
|
auto_bind="NONE",
|
||||||
|
authentication="SIMPLE",
|
||||||
|
)
|
||||||
if not connection_app.bind():
|
if not connection_app.bind():
|
||||||
raise HTTPException(400, detail="Application account bind failed")
|
raise HTTPException(400, detail="Application account bind failed")
|
||||||
|
|
||||||
search_success = connection_app.search(
|
search_success = connection_app.search(
|
||||||
search_base=LDAP_SEARCH_BASE,
|
search_base=LDAP_SEARCH_BASE,
|
||||||
search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})',
|
search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})",
|
||||||
attributes=[f'{LDAP_ATTRIBUTE_FOR_USERNAME}', 'mail', 'cn']
|
attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if not search_success:
|
if not search_success:
|
||||||
raise HTTPException(400, detail="User not found in the LDAP server")
|
raise HTTPException(400, detail="User not found in the LDAP server")
|
||||||
|
|
||||||
entry = connection_app.entries[0]
|
entry = connection_app.entries[0]
|
||||||
username = str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}']).lower()
|
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
|
||||||
mail = str(entry['mail'])
|
mail = str(entry["mail"])
|
||||||
cn = str(entry['cn'])
|
cn = str(entry["cn"])
|
||||||
user_dn = entry.entry_dn
|
user_dn = entry.entry_dn
|
||||||
|
|
||||||
if username == form_data.user.lower():
|
if username == form_data.user.lower():
|
||||||
connection_user = Connection(server, user_dn, form_data.password, auto_bind='NONE', authentication='SIMPLE')
|
connection_user = Connection(
|
||||||
|
server,
|
||||||
|
user_dn,
|
||||||
|
form_data.password,
|
||||||
|
auto_bind="NONE",
|
||||||
|
authentication="SIMPLE",
|
||||||
|
)
|
||||||
if not connection_user.bind():
|
if not connection_user.bind():
|
||||||
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
raise HTTPException(400, f"Authentication failed for {form_data.user}")
|
||||||
|
|
||||||
@ -205,14 +240,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(mail, hashed, cn)
|
||||||
mail,
|
|
||||||
hashed,
|
|
||||||
cn
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
raise HTTPException(
|
||||||
|
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR
|
||||||
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@ -224,7 +257,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||||||
if user:
|
if user:
|
||||||
token = create_token(
|
token = create_token(
|
||||||
data={"id": user.id},
|
data={"id": user.id},
|
||||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
expires_delta=parse_duration(
|
||||||
|
request.app.state.config.JWT_EXPIRES_IN
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the cookie token
|
# Set the cookie token
|
||||||
@ -246,7 +281,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}")
|
raise HTTPException(
|
||||||
|
400,
|
||||||
|
f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(400, detail=str(e))
|
raise HTTPException(400, detail=str(e))
|
||||||
|
|
||||||
@ -325,6 +363,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
secure=WEBUI_SESSION_COOKIE_SECURE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_permissions = get_permissions(
|
||||||
|
user.id, request.app.state.config.USER_PERMISSIONS
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
@ -334,6 +376,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||||||
"name": user.name,
|
"name": user.name,
|
||||||
"role": user.role,
|
"role": user.role,
|
||||||
"profile_image_url": user.profile_image_url,
|
"profile_image_url": user.profile_image_url,
|
||||||
|
"permissions": user_permissions,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
@ -426,6 +469,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
user_permissions = get_permissions(
|
||||||
|
user.id, request.app.state.config.USER_PERMISSIONS
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
@ -435,6 +482,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||||||
"name": user.name,
|
"name": user.name,
|
||||||
"role": user.role,
|
"role": user.role,
|
||||||
"profile_image_url": user.profile_image_url,
|
"profile_image_url": user.profile_image_url,
|
||||||
|
"permissions": user_permissions,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
||||||
@ -583,19 +631,18 @@ class LdapServerConfig(BaseModel):
|
|||||||
label: str
|
label: str
|
||||||
host: str
|
host: str
|
||||||
port: Optional[int] = None
|
port: Optional[int] = None
|
||||||
attribute_for_username: str = 'uid'
|
attribute_for_username: str = "uid"
|
||||||
app_dn: str
|
app_dn: str
|
||||||
app_dn_password: str
|
app_dn_password: str
|
||||||
search_base: str
|
search_base: str
|
||||||
search_filters: str = ''
|
search_filters: str = ""
|
||||||
use_tls: bool = True
|
use_tls: bool = True
|
||||||
certificate_path: Optional[str] = None
|
certificate_path: Optional[str] = None
|
||||||
ciphers: Optional[str] = 'ALL'
|
ciphers: Optional[str] = "ALL"
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
|
@router.get("/admin/config/ldap/server", response_model=LdapServerConfig)
|
||||||
async def get_ldap_server(
|
async def get_ldap_server(request: Request, user=Depends(get_admin_user)):
|
||||||
request: Request, user=Depends(get_admin_user)
|
|
||||||
):
|
|
||||||
return {
|
return {
|
||||||
"label": request.app.state.config.LDAP_SERVER_LABEL,
|
"label": request.app.state.config.LDAP_SERVER_LABEL,
|
||||||
"host": request.app.state.config.LDAP_SERVER_HOST,
|
"host": request.app.state.config.LDAP_SERVER_HOST,
|
||||||
@ -607,26 +654,38 @@ async def get_ldap_server(
|
|||||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||||
"ciphers": request.app.state.config.LDAP_CIPHERS
|
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/admin/config/ldap/server")
|
@router.post("/admin/config/ldap/server")
|
||||||
async def update_ldap_server(
|
async def update_ldap_server(
|
||||||
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
|
request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
required_fields = ['label', 'host', 'attribute_for_username', 'app_dn', 'app_dn_password', 'search_base']
|
required_fields = [
|
||||||
|
"label",
|
||||||
|
"host",
|
||||||
|
"attribute_for_username",
|
||||||
|
"app_dn",
|
||||||
|
"app_dn_password",
|
||||||
|
"search_base",
|
||||||
|
]
|
||||||
for key in required_fields:
|
for key in required_fields:
|
||||||
value = getattr(form_data, key)
|
value = getattr(form_data, key)
|
||||||
if not value:
|
if not value:
|
||||||
raise HTTPException(400, detail=f"Required field {key} is empty")
|
raise HTTPException(400, detail=f"Required field {key} is empty")
|
||||||
|
|
||||||
if form_data.use_tls and not form_data.certificate_path:
|
if form_data.use_tls and not form_data.certificate_path:
|
||||||
raise HTTPException(400, detail="TLS is enabled but certificate file path is missing")
|
raise HTTPException(
|
||||||
|
400, detail="TLS is enabled but certificate file path is missing"
|
||||||
|
)
|
||||||
|
|
||||||
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
|
request.app.state.config.LDAP_SERVER_LABEL = form_data.label
|
||||||
request.app.state.config.LDAP_SERVER_HOST = form_data.host
|
request.app.state.config.LDAP_SERVER_HOST = form_data.host
|
||||||
request.app.state.config.LDAP_SERVER_PORT = form_data.port
|
request.app.state.config.LDAP_SERVER_PORT = form_data.port
|
||||||
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username
|
request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = (
|
||||||
|
form_data.attribute_for_username
|
||||||
|
)
|
||||||
request.app.state.config.LDAP_APP_DN = form_data.app_dn
|
request.app.state.config.LDAP_APP_DN = form_data.app_dn
|
||||||
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
|
request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password
|
||||||
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
|
request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base
|
||||||
@ -646,18 +705,23 @@ async def update_ldap_server(
|
|||||||
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
"search_filters": request.app.state.config.LDAP_SEARCH_FILTERS,
|
||||||
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
"use_tls": request.app.state.config.LDAP_USE_TLS,
|
||||||
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
"certificate_path": request.app.state.config.LDAP_CA_CERT_FILE,
|
||||||
"ciphers": request.app.state.config.LDAP_CIPHERS
|
"ciphers": request.app.state.config.LDAP_CIPHERS,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/config/ldap")
|
@router.get("/admin/config/ldap")
|
||||||
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
|
async def get_ldap_config(request: Request, user=Depends(get_admin_user)):
|
||||||
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
||||||
|
|
||||||
|
|
||||||
class LdapConfigForm(BaseModel):
|
class LdapConfigForm(BaseModel):
|
||||||
enable_ldap: Optional[bool] = None
|
enable_ldap: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/admin/config/ldap")
|
@router.post("/admin/config/ldap")
|
||||||
async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)):
|
async def update_ldap_config(
|
||||||
|
request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)
|
||||||
|
):
|
||||||
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
|
request.app.state.config.ENABLE_LDAP = form_data.enable_ldap
|
||||||
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP}
|
||||||
|
|
||||||
|
@ -2,6 +2,38 @@ from typing import Optional, Union, List, Dict
|
|||||||
from open_webui.apps.webui.models.groups import Groups
|
from open_webui.apps.webui.models.groups import Groups
|
||||||
|
|
||||||
|
|
||||||
|
def get_permissions(
|
||||||
|
user_id: str,
|
||||||
|
default_permissions: Dict[str, bool] = {},
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
||||||
|
If a permission is defined in multiple groups, the most permissive value is used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def merge_permissions(
|
||||||
|
permissions: Dict[str, bool], new_permissions: Dict[str, bool]
|
||||||
|
) -> Dict[str, bool]:
|
||||||
|
"""Merge two permission dictionaries, keeping the most permissive value."""
|
||||||
|
for key, value in new_permissions.items():
|
||||||
|
if key not in permissions:
|
||||||
|
permissions[key] = value
|
||||||
|
else:
|
||||||
|
permissions[key] = (
|
||||||
|
permissions[key] or value
|
||||||
|
) # Use the most permissive value
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
|
||||||
|
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||||
|
user_permissions = default_permissions.copy()
|
||||||
|
|
||||||
|
for group in user_groups:
|
||||||
|
user_permissions = merge_permissions(user_permissions, group.permissions)
|
||||||
|
|
||||||
|
return user_permissions
|
||||||
|
|
||||||
|
|
||||||
def has_permission(
|
def has_permission(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission_key: str,
|
permission_key: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user