feat: experimental SSO support for Google, Microsoft, and OIDC

This commit is contained in:
Jun Siang Cheah 2024-05-26 08:37:09 +01:00
parent a842d8d62b
commit 0210a105bf
10 changed files with 351 additions and 6 deletions

View File

@ -0,0 +1,49 @@
"""Peewee migrations -- 011_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
oauth_sub=pw.TextField(null=True, unique=True),
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")

View File

@ -1,6 +1,8 @@
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from apps.webui.routers import ( from apps.webui.routers import (
auths, auths,
users, users,
@ -24,6 +26,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
AppConfig, AppConfig,
WEBUI_SECRET_KEY,
OAUTH_PROVIDERS,
) )
app = FastAPI() app = FastAPI()
@ -54,6 +58,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])

View File

@ -105,6 +105,7 @@ class AuthsTable:
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info("insert_new_auth") log.info("insert_new_auth")
@ -115,7 +116,9 @@ class AuthsTable:
) )
result = Auth.create(**auth.model_dump()) result = Auth.create(**auth.model_dump())
user = Users.insert_new_user(id, name, email, profile_image_url, role) user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
if result and user: if result and user:
return user return user

View File

@ -26,6 +26,8 @@ class User(Model):
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
oauth_sub = TextField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
@ -43,6 +45,8 @@ class UserModel(BaseModel):
api_key: Optional[str] = None api_key: Optional[str] = None
oauth_sub: Optional[str] = None
#################### ####################
# Forms # Forms
@ -73,6 +77,7 @@ class UsersTable:
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( user = UserModel(
**{ **{
@ -84,6 +89,7 @@ class UsersTable:
"last_active_at": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"oauth_sub": oauth_sub,
} }
) )
result = User.create(**user.model_dump()) result = User.create(**user.model_dump())
@ -113,6 +119,13 @@ class UsersTable:
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))

View File

@ -1,5 +1,7 @@
import logging import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import Request, UploadFile, File from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
@ -9,6 +11,7 @@ import re
import uuid import uuid
import csv import csv
from starlette.responses import RedirectResponse
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
@ -33,7 +36,12 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER from config import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
)
router = APIRouter() router = APIRouter()
@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
} }
else: else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
@router.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@router.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, create a new user if signup is enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=request.app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)

View File

@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
) )
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP = PersistentConfig(
"ENABLE_OAUTH_SIGNUP",
"oauth.enable_signup",
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)
OAUTH_PROVIDERS = {}
if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
OAUTH_PROVIDERS["google"] = {
"client_id": os.environ.get("GOOGLE_CLIENT_ID"),
"client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"),
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("MICROSOFT_CLIENT_ID")
and os.environ.get("MICROSOFT_CLIENT_SECRET")
and os.environ.get("MICROSOFT_CLIENT_TENANT_ID")
):
OAUTH_PROVIDERS["microsoft"] = {
"client_id": os.environ.get("MICROSOFT_CLIENT_ID"),
"client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"),
"server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration",
"scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
}
if (
os.environ.get("OPENID_CLIENT_ID")
and os.environ.get("OPENID_CLIENT_SECRET")
and os.environ.get("OPENID_PROVIDER_URL")
):
OAUTH_PROVIDERS["oidc"] = {
"client_id": os.environ.get("OPENID_CLIENT_ID"),
"client_secret": os.environ.get("OPENID_CLIENT_SECRET"),
"server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"),
"scope": os.environ.get("OPENID_SCOPE", "openid email profile"),
"name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"),
}
#################################### ####################################
# Static DIR # Static DIR
#################################### ####################################

View File

@ -55,6 +55,7 @@ from config import (
WEBHOOK_URL, WEBHOOK_URL,
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
AppConfig, AppConfig,
OAUTH_PROVIDERS,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -364,6 +365,13 @@ async def get_app_config():
"default_locale": default_locale, "default_locale": default_locale,
"default_models": webui_app.state.config.DEFAULT_MODELS, "default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"oauth": {
"providers": {
name: config.get("name", name)
for name, config in OAUTH_PROVIDERS.items()
}
},
} }

View File

@ -134,7 +134,12 @@ type Config = {
default_models?: string[]; default_models?: string[];
default_prompt_suggestions?: PromptSuggestion[]; default_prompt_suggestions?: PromptSuggestion[];
auth_trusted_header?: boolean; auth_trusted_header?: boolean;
model_config?: GlobalModelConfig; auth: boolean;
oauth: {
providers: {
[key: string]: string;
};
};
}; };
type PromptSuggestion = { type PromptSuggestion = {

View File

@ -2,6 +2,7 @@
import { onMount, tick, setContext } from 'svelte'; import { onMount, tick, setContext } from 'svelte';
import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores'; import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores';
import { Toaster, toast } from 'svelte-sonner'; import { Toaster, toast } from 'svelte-sonner';
import { getBackendConfig } from '$lib/apis'; import { getBackendConfig } from '$lib/apis';
@ -75,7 +76,11 @@
await goto('/auth'); await goto('/auth');
} }
} else { } else {
await goto('/auth'); // Don't redirect if we're already on the auth page
// Needed because we pass in tokens from OAuth logins via URL fragments
if ($page.url.pathname !== '/auth') {
await goto('/auth');
}
} }
} }
} else { } else {

View File

@ -1,12 +1,13 @@
<script> <script>
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { userSignIn, userSignUp } from '$lib/apis/auths'; import { getSessionUser, userSignIn, userSignUp } from '$lib/apis/auths';
import Spinner from '$lib/components/common/Spinner.svelte'; import Spinner from '$lib/components/common/Spinner.svelte';
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, config, user } from '$lib/stores'; import { WEBUI_NAME, config, user } from '$lib/stores';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { generateInitialsImage, canvasPixelTest } from '$lib/utils'; import { generateInitialsImage, canvasPixelTest } from '$lib/utils';
import { page } from '$app/stores';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -21,7 +22,9 @@
if (sessionUser) { if (sessionUser) {
console.log(sessionUser); console.log(sessionUser);
toast.success($i18n.t(`You're now logged in.`)); toast.success($i18n.t(`You're now logged in.`));
localStorage.token = sessionUser.token; if (sessionUser.token) {
localStorage.token = sessionUser.token;
}
await user.set(sessionUser); await user.set(sessionUser);
goto('/'); goto('/');
} }
@ -55,10 +58,35 @@
} }
}; };
const checkOauthCallback = async () => {
if (!$page.url.hash) {
return;
}
const hash = $page.url.hash.substring(1);
if (!hash) {
return;
}
const params = new URLSearchParams(hash);
const token = params.get('token');
if (!token) {
return;
}
const sessionUser = await getSessionUser(token).catch((error) => {
toast.error(error);
return null;
});
if (!sessionUser) {
return;
}
localStorage.token = token;
await setSessionUser(sessionUser);
};
onMount(async () => { onMount(async () => {
if ($user !== undefined) { if ($user !== undefined) {
await goto('/'); await goto('/');
} }
await checkOauthCallback();
loaded = true; loaded = true;
if (($config?.auth_trusted_header ?? false) || $config?.auth === false) { if (($config?.auth_trusted_header ?? false) || $config?.auth === false) {
await signInHandler(); await signInHandler();
@ -217,6 +245,97 @@
{/if} {/if}
</div> </div>
</form> </form>
{#if Object.keys($config?.oauth?.providers ?? {}).length > 0 }
<div class="inline-flex items-center justify-center w-full">
<hr class="w-64 h-px my-8 bg-gray-200 border-0 dark:bg-gray-700" />
<span
class="absolute px-3 font-medium text-gray-900 -translate-x-1/2 bg-white left-1/2 dark:text-white dark:bg-gray-950"
>{$i18n.t('or')}</span
>
</div>
<div class="flex flex-col space-y-2">
{#if $config?.oauth?.providers?.google }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
<path
fill="#EA4335"
d="M24 9.5c3.54 0 6.71 1.22 9.21 3.6l6.85-6.85C35.9 2.38 30.47 0 24 0 14.62 0 6.51 5.38 2.56 13.22l7.98 6.19C12.43 13.72 17.74 9.5 24 9.5z"
/><path
fill="#4285F4"
d="M46.98 24.55c0-1.57-.15-3.09-.38-4.55H24v9.02h12.94c-.58 2.96-2.26 5.48-4.78 7.18l7.73 6c4.51-4.18 7.09-10.36 7.09-17.65z"
/><path
fill="#FBBC05"
d="M10.53 28.59c-.48-1.45-.76-2.99-.76-4.59s.27-3.14.76-4.59l-7.98-6.19C.92 16.46 0 20.12 0 24c0 3.88.92 7.54 2.56 10.78l7.97-6.19z"
/><path
fill="#34A853"
d="M24 48c6.48 0 11.93-2.13 15.89-5.81l-7.73-6c-2.15 1.45-4.92 2.3-8.16 2.3-6.26 0-11.57-4.22-13.47-9.91l-7.98 6.19C6.51 42.62 14.62 48 24 48z"
/><path fill="none" d="M0 0h48v48H0z" />
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Google' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.microsoft }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
<rect x="1" y="1" width="9" height="9" fill="#f25022" /><rect
x="1"
y="11"
width="9"
height="9"
fill="#00a4ef"
/><rect x="11" y="1" width="9" height="9" fill="#7fba00" /><rect
x="11"
y="11"
width="9"
height="9"
fill="#ffb900"
/>
</svg>
<span>{$i18n.t('Continue with {{provider}}', { provider: 'Microsoft' })}</span>
</button>
{/if}
{#if $config?.oauth?.providers?.oidc }
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
fill="none"
viewBox="0 0 24 24"
stroke-width="1.5"
stroke="currentColor"
class="size-6 mr-3"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M15.75 5.25a3 3 0 0 1 3 3m3 0a6 6 0 0 1-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1 1 21.75 8.25Z"
/>
</svg>
<span
>{$i18n.t('Continue with {{provider}}', {
provider: $config?.oauth?.providers?.oidc ?? 'SSO'
})}</span
>
</button>
{/if}
</div>
{/if}
</div> </div>
{/if} {/if}
</div> </div>