mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
feat: experimental SSO support for Google, Microsoft, and OIDC
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
"""Peewee migrations -- 011_add_user_oauth_sub.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"user",
|
||||
oauth_sub=pw.TextField(null=True, unique=True),
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("user", "oauth_sub")
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from apps.webui.routers import (
|
||||
auths,
|
||||
users,
|
||||
@@ -24,6 +26,8 @@ from config import (
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
JWT_EXPIRES_IN,
|
||||
AppConfig,
|
||||
WEBUI_SECRET_KEY,
|
||||
OAUTH_PROVIDERS,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
@@ -54,6 +58,12 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# SessionMiddleware is used by authlib for oauth
|
||||
if len(OAUTH_PROVIDERS) > 0:
|
||||
app.add_middleware(
|
||||
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
|
||||
)
|
||||
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
|
||||
@@ -105,6 +105,7 @@ class AuthsTable:
|
||||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
log.info("insert_new_auth")
|
||||
|
||||
@@ -115,7 +116,9 @@ class AuthsTable:
|
||||
)
|
||||
result = Auth.create(**auth.model_dump())
|
||||
|
||||
user = Users.insert_new_user(id, name, email, profile_image_url, role)
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth_sub
|
||||
)
|
||||
|
||||
if result and user:
|
||||
return user
|
||||
|
||||
@@ -26,6 +26,8 @@ class User(Model):
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
|
||||
oauth_sub = TextField(null=True, unique=True)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
@@ -43,6 +45,8 @@ class UserModel(BaseModel):
|
||||
|
||||
api_key: Optional[str] = None
|
||||
|
||||
oauth_sub: Optional[str] = None
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
@@ -73,6 +77,7 @@ class UsersTable:
|
||||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
oauth_sub: Optional[str] = None,
|
||||
) -> Optional[UserModel]:
|
||||
user = UserModel(
|
||||
**{
|
||||
@@ -84,6 +89,7 @@ class UsersTable:
|
||||
"last_active_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"oauth_sub": oauth_sub,
|
||||
}
|
||||
)
|
||||
result = User.create(**user.model_dump())
|
||||
@@ -113,6 +119,13 @@ class UsersTable:
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
||||
try:
|
||||
user = User.get(User.oauth_sub == sub)
|
||||
return UserModel(**model_to_dict(user))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
|
||||
return [
|
||||
UserModel(**model_to_dict(user))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
from authlib.oidc.core import UserInfo
|
||||
from fastapi import Request, UploadFile, File
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
@@ -9,6 +11,7 @@ import re
|
||||
import uuid
|
||||
import csv
|
||||
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from apps.webui.models.auths import (
|
||||
SigninForm,
|
||||
@@ -33,7 +36,12 @@ from utils.utils import (
|
||||
from utils.misc import parse_duration, validate_email_format
|
||||
from utils.webhook import post_webhook
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
from config import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
|
||||
}
|
||||
else:
|
||||
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||
|
||||
|
||||
############################
|
||||
# OAuth Login & Callback
|
||||
############################
|
||||
|
||||
oauth = OAuth()
|
||||
|
||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||
oauth.register(
|
||||
name=provider_name,
|
||||
client_id=provider_config["client_id"],
|
||||
client_secret=provider_config["client_secret"],
|
||||
server_metadata_url=provider_config["server_metadata_url"],
|
||||
client_kwargs={
|
||||
"scope": provider_config["scope"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}/login")
|
||||
async def oauth_login(provider: str, request: Request):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
redirect_uri = request.url_for("oauth_callback", provider=provider)
|
||||
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}/callback")
|
||||
async def oauth_callback(provider: str, request: Request):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
client = oauth.create_client(provider)
|
||||
token = await client.authorize_access_token(request)
|
||||
user_data: UserInfo = token["userinfo"]
|
||||
|
||||
sub = user_data.get("sub")
|
||||
if not sub:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||
|
||||
if not user:
|
||||
# If the user does not exist, create a new user if signup is enabled
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
user = Auths.insert_new_auth(
|
||||
email=user_data.get("email", "").lower(),
|
||||
password=get_password_hash(
|
||||
str(uuid.uuid4())
|
||||
), # Random password, not used
|
||||
name=user_data.get("name", "User"),
|
||||
profile_image_url=user_data.get("picture", "/user.png"),
|
||||
role=request.app.state.config.DEFAULT_USER_ROLE,
|
||||
oauth_sub=provider_sub,
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
"action": "signup",
|
||||
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
"user": user.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
jwt_token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
)
|
||||
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
|
||||
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
||||
)
|
||||
|
||||
####################################
|
||||
# OAuth config
|
||||
####################################
|
||||
|
||||
ENABLE_OAUTH_SIGNUP = PersistentConfig(
|
||||
"ENABLE_OAUTH_SIGNUP",
|
||||
"oauth.enable_signup",
|
||||
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
|
||||
)
|
||||
|
||||
OAUTH_PROVIDERS = {}
|
||||
|
||||
if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
|
||||
OAUTH_PROVIDERS["google"] = {
|
||||
"client_id": os.environ.get("GOOGLE_CLIENT_ID"),
|
||||
"client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"),
|
||||
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
|
||||
}
|
||||
|
||||
if (
|
||||
os.environ.get("MICROSOFT_CLIENT_ID")
|
||||
and os.environ.get("MICROSOFT_CLIENT_SECRET")
|
||||
and os.environ.get("MICROSOFT_CLIENT_TENANT_ID")
|
||||
):
|
||||
OAUTH_PROVIDERS["microsoft"] = {
|
||||
"client_id": os.environ.get("MICROSOFT_CLIENT_ID"),
|
||||
"client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"),
|
||||
"server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration",
|
||||
"scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
|
||||
}
|
||||
|
||||
if (
|
||||
os.environ.get("OPENID_CLIENT_ID")
|
||||
and os.environ.get("OPENID_CLIENT_SECRET")
|
||||
and os.environ.get("OPENID_PROVIDER_URL")
|
||||
):
|
||||
OAUTH_PROVIDERS["oidc"] = {
|
||||
"client_id": os.environ.get("OPENID_CLIENT_ID"),
|
||||
"client_secret": os.environ.get("OPENID_CLIENT_SECRET"),
|
||||
"server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"),
|
||||
"scope": os.environ.get("OPENID_SCOPE", "openid email profile"),
|
||||
"name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"),
|
||||
}
|
||||
|
||||
|
||||
####################################
|
||||
# Static DIR
|
||||
####################################
|
||||
|
||||
@@ -55,6 +55,7 @@ from config import (
|
||||
WEBHOOK_URL,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
AppConfig,
|
||||
OAUTH_PROVIDERS,
|
||||
)
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
@@ -364,6 +365,13 @@ async def get_app_config():
|
||||
"default_locale": default_locale,
|
||||
"default_models": webui_app.state.config.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
|
||||
"oauth": {
|
||||
"providers": {
|
||||
name: config.get("name", name)
|
||||
for name, config in OAUTH_PROVIDERS.items()
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user