diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 59557349e..fcfccaedf 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -349,6 +349,10 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None +) + BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 86eec77e9..3ad88bc11 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -159,8 +159,8 @@ class AuthsTable: except Exception: return False - def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: - log.info(f"authenticate_user_by_trusted_header: {email}") + def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_email: {email}") try: with get_db() as db: auth = db.query(Auth).filter_by(email=email, active=True).first() diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 763340fbc..df79284cf 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -207,5 +207,43 @@ class GroupTable: except Exception: return False + def sync_user_groups_by_group_names( + self, user_id: str, group_names: list[str] + ) -> bool: + with get_db() as db: + try: + groups = db.query(Group).filter(Group.name.in_(group_names)).all() + group_ids = [group.id for group in groups] + + # Remove user from groups not in the new list + existing_groups = self.get_groups_by_member_id(user_id) + + for group in existing_groups: + if group.id not in group_ids: + group.user_ids.remove(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + # Add user to new groups + for group in groups: + if user_id not in group.user_ids: + group.user_ids.append(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + db.commit() + return True + except Exception as e: + log.exception(e) + return False + Groups = GroupTable() diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 793bdfd30..06e506228 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -19,12 +19,14 @@ from open_webui.models.auths import ( UserResponse, ) from open_webui.models.users import Users +from open_webui.models.groups import Groups from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, @@ -299,7 +301,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user creation." ) - user = Auths.authenticate_user_by_trusted_header(email) + user = Auths.authenticate_user_by_email(email) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -363,21 +365,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) - trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() - trusted_name = trusted_email + email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() + name = email + if WEBUI_AUTH_TRUSTED_NAME_HEADER: - trusted_name = request.headers.get( - WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email - ) - if not Users.get_user_by_email(trusted_email.lower()): + name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) + + if not Users.get_user_by_email(email.lower()): await signup( request, response, - SignupForm( - email=trusted_email, password=str(uuid.uuid4()), name=trusted_name - ), + SignupForm(email=email, password=str(uuid.uuid4()), name=name), ) - user = Auths.authenticate_user_by_trusted_header(trusted_email) + + user = Auths.authenticate_user_by_email(email) + if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": + group_names = request.headers.get( + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" + ).split(",") + group_names = [name.strip() for name in group_names if name.strip()] + + if group_names: + Groups.sync_user_groups_by_group_names(user.id, group_names) + elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin"