Merge pull request #2574 from cheahjs/feat/oauth

feat: experimental SSO support for Google, Microsoft, and OIDC
This commit is contained in:
Timothy Jaeryang Baek
2024-06-24 19:05:58 -07:00
committed by GitHub
52 changed files with 633 additions and 13 deletions

View File

@@ -0,0 +1,49 @@
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")

View File

@@ -2,6 +2,8 @@ from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from apps.webui.routers import (
auths,
users,

View File

@@ -105,6 +105,7 @@ class AuthsTable:
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
log.info("insert_new_auth")
@@ -115,7 +116,9 @@ class AuthsTable:
)
result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
if result and user:
return user

View File

@@ -28,6 +28,8 @@ class User(Model):
settings = JSONField(null=True)
info = JSONField(null=True)
oauth_sub = TextField(null=True, unique=True)
class Meta:
database = DB
@@ -53,6 +55,8 @@ class UserModel(BaseModel):
settings: Optional[UserSettings] = None
info: Optional[dict] = None
oauth_sub: Optional[str] = None
####################
# Forms
@@ -83,6 +87,7 @@ class UsersTable:
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
user = UserModel(
**{
@@ -94,6 +99,7 @@ class UsersTable:
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
@@ -123,6 +129,13 @@ class UsersTable:
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
@@ -174,6 +187,18 @@ class UsersTable:
except:
return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)

View File

@@ -10,7 +10,6 @@ import re
import uuid
import csv
from apps.webui.models.auths import (
SigninForm,
SignupForm,

View File

@@ -305,6 +305,135 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP = PersistentConfig(
"ENABLE_OAUTH_SIGNUP",
"oauth.enable_signup",
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)
OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
"OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
"oauth.merge_accounts_by_email",
os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)
OAUTH_PROVIDERS = {}
GOOGLE_CLIENT_ID = PersistentConfig(
"GOOGLE_CLIENT_ID",
"oauth.google.client_id",
os.environ.get("GOOGLE_CLIENT_ID", ""),
)
GOOGLE_CLIENT_SECRET = PersistentConfig(
"GOOGLE_CLIENT_SECRET",
"oauth.google.client_secret",
os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)
GOOGLE_OAUTH_SCOPE = PersistentConfig(
"GOOGLE_OAUTH_SCOPE",
"oauth.google.scope",
os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)
MICROSOFT_CLIENT_ID = PersistentConfig(
"MICROSOFT_CLIENT_ID",
"oauth.microsoft.client_id",
os.environ.get("MICROSOFT_CLIENT_ID", ""),
)
MICROSOFT_CLIENT_SECRET = PersistentConfig(
"MICROSOFT_CLIENT_SECRET",
"oauth.microsoft.client_secret",
os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)
MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
"MICROSOFT_CLIENT_TENANT_ID",
"oauth.microsoft.tenant_id",
os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)
MICROSOFT_OAUTH_SCOPE = PersistentConfig(
"MICROSOFT_OAUTH_SCOPE",
"oauth.microsoft.scope",
os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)
OAUTH_CLIENT_ID = PersistentConfig(
"OAUTH_CLIENT_ID",
"oauth.oidc.client_id",
os.environ.get("OAUTH_CLIENT_ID", ""),
)
OAUTH_CLIENT_SECRET = PersistentConfig(
"OAUTH_CLIENT_SECRET",
"oauth.oidc.client_secret",
os.environ.get("OAUTH_CLIENT_SECRET", ""),
)
OPENID_PROVIDER_URL = PersistentConfig(
"OPENID_PROVIDER_URL",
"oauth.oidc.provider_url",
os.environ.get("OPENID_PROVIDER_URL", ""),
)
OAUTH_SCOPES = PersistentConfig(
"OAUTH_SCOPES",
"oauth.oidc.scopes",
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
OAUTH_PROVIDER_NAME = PersistentConfig(
"OAUTH_PROVIDER_NAME",
"oauth.oidc.provider_name",
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
OAUTH_PROVIDERS["google"] = {
"client_id": GOOGLE_CLIENT_ID.value,
"client_secret": GOOGLE_CLIENT_SECRET.value,
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": GOOGLE_OAUTH_SCOPE.value,
}
if (
MICROSOFT_CLIENT_ID.value
and MICROSOFT_CLIENT_SECRET.value
and MICROSOFT_CLIENT_TENANT_ID.value
):
OAUTH_PROVIDERS["microsoft"] = {
"client_id": MICROSOFT_CLIENT_ID.value,
"client_secret": MICROSOFT_CLIENT_SECRET.value,
"server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
"scope": MICROSOFT_OAUTH_SCOPE.value,
}
if (
OAUTH_CLIENT_ID.value
and OAUTH_CLIENT_SECRET.value
and OPENID_PROVIDER_URL.value
):
OAUTH_PROVIDERS["oidc"] = {
"client_id": OAUTH_CLIENT_ID.value,
"client_secret": OAUTH_CLIENT_SECRET.value,
"server_metadata_url": OPENID_PROVIDER_URL.value,
"scope": OAUTH_SCOPES.value,
"name": OAUTH_PROVIDER_NAME.value,
}
load_oauth_providers()
####################################
# Static DIR
####################################
@@ -733,6 +862,16 @@ WEBUI_SECRET_KEY = os.environ.get(
), # DEPRECATED: remove at next major version
)
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)

View File

@@ -1,4 +1,9 @@
import base64
import uuid
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
@@ -24,7 +29,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
@@ -53,9 +59,11 @@ from apps.webui.main import (
from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
@@ -64,6 +72,8 @@ from utils.utils import (
get_verified_user,
get_current_user,
get_http_authorization_cred,
get_password_hash,
create_token,
)
from utils.task import (
title_generation_template,
@@ -74,6 +84,7 @@ from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
@@ -106,9 +117,16 @@ from config import (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
)
from constants import ERROR_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
if SAFE_MODE:
print("SAFE MODE ENABLED")
@@ -1725,6 +1743,12 @@ async def get_app_config():
"engine": audio_app.state.config.STT_ENGINE,
},
},
"oauth": {
"providers": {
name: config.get("name", name)
for name, config in OAUTH_PROVIDERS.items()
}
},
}
@@ -1806,6 +1830,154 @@ async def get_app_latest_release_version():
)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware,
secret_key=WEBUI_SECRET_KEY,
session_cookie="oui-session",
same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
https_only=WEBUI_SESSION_COOKIE_SECURE,
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows:
# 1. Attempt to find a user with matching subject ID, tied to the provider
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email = user_data.get("email", "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=picture_url,
role=webui_app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json")
async def get_manifest_json():
return {

View File

@@ -58,6 +58,7 @@ rank-bm25==0.2.2
faster-whisper==1.0.2
PyJWT[crypto]==2.8.0
authlib==1.3.0
black==24.4.2
langfuse==2.33.0