feat: experimental SSO support for Google, Microsoft, and OIDC

This commit is contained in:
Jun Siang Cheah
2024-05-26 08:37:09 +01:00
parent a842d8d62b
commit 0210a105bf
10 changed files with 351 additions and 6 deletions

View File

@@ -0,0 +1,49 @@
"""Peewee migrations -- 011_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")

View File

@@ -1,6 +1,8 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from apps.webui.routers import (
auths,
users,
@@ -24,6 +26,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
AppConfig,
WEBUI_SECRET_KEY,
OAUTH_PROVIDERS,
)
app = FastAPI()
@@ -54,6 +58,12 @@ app.add_middleware(
allow_headers=["*"],
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])

View File

@@ -105,6 +105,7 @@ class AuthsTable:
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
log.info("insert_new_auth")
@@ -115,7 +116,9 @@ class AuthsTable:
)
result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
if result and user:
return user

View File

@@ -26,6 +26,8 @@ class User(Model):
api_key = CharField(null=True, unique=True)
oauth_sub = TextField(null=True, unique=True)
class Meta:
database = DB
@@ -43,6 +45,8 @@ class UserModel(BaseModel):
api_key: Optional[str] = None
oauth_sub: Optional[str] = None
####################
# Forms
@@ -73,6 +77,7 @@ class UsersTable:
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
user = UserModel(
**{
@@ -84,6 +89,7 @@ class UsersTable:
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
@@ -113,6 +119,13 @@ class UsersTable:
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))

View File

@@ -1,5 +1,7 @@
import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status
@@ -9,6 +11,7 @@ import re
import uuid
import csv
from starlette.responses import RedirectResponse
from apps.webui.models.auths import (
SigninForm,
@@ -33,7 +36,12 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
from config import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
)
router = APIRouter()
@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
@router.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@router.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, create a new user if signup is enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=request.app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)

View File

@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP = PersistentConfig(
"ENABLE_OAUTH_SIGNUP",
"oauth.enable_signup",
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)
OAUTH_PROVIDERS = {}
if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
OAUTH_PROVIDERS["google"] = {
"client_id": os.environ.get("GOOGLE_CLIENT_ID"),
"client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"),
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("MICROSOFT_CLIENT_ID")
and os.environ.get("MICROSOFT_CLIENT_SECRET")
and os.environ.get("MICROSOFT_CLIENT_TENANT_ID")
):
OAUTH_PROVIDERS["microsoft"] = {
"client_id": os.environ.get("MICROSOFT_CLIENT_ID"),
"client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"),
"server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration",
"scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("OPENID_CLIENT_ID")
and os.environ.get("OPENID_CLIENT_SECRET")
and os.environ.get("OPENID_PROVIDER_URL")
):
OAUTH_PROVIDERS["oidc"] = {
"client_id": os.environ.get("OPENID_CLIENT_ID"),
"client_secret": os.environ.get("OPENID_CLIENT_SECRET"),
"server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"),
"scope": os.environ.get("OPENID_SCOPE", "openid email profile"),
"name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"),
}
####################################
# Static DIR
####################################

View File

@@ -55,6 +55,7 @@ from config import (
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
AppConfig,
OAUTH_PROVIDERS,
)
from constants import ERROR_MESSAGES
@@ -364,6 +365,13 @@ async def get_app_config():
"default_locale": default_locale,
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"oauth": {
"providers": {
name: config.get("name", name)
for name, config in OAUTH_PROVIDERS.items()
}
},
}