From 0210a105bfb4527f31e2bf1fb5086b0341ffe156 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah <git@jscheah.me> Date: Sun, 26 May 2024 08:37:09 +0100 Subject: [PATCH] feat: experimental SSO support for Google, Microsoft, and OIDC --- .../migrations/011_add_user_oauth_sub.py | 49 +++++++ backend/apps/webui/main.py | 10 ++ backend/apps/webui/models/auths.py | 5 +- backend/apps/webui/models/users.py | 13 ++ backend/apps/webui/routers/auths.py | 89 ++++++++++++- backend/config.py | 46 +++++++ backend/main.py | 8 ++ src/lib/stores/index.ts | 7 +- src/routes/+layout.svelte | 7 +- src/routes/auth/+page.svelte | 123 +++++++++++++++++- 10 files changed, 351 insertions(+), 6 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py diff --git a/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py new file mode 100644 index 000000000..70dfeccf0 --- /dev/null +++ b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py @@ -0,0 +1,49 @@ +"""Peewee migrations -- 011_add_user_oauth_sub.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index d736cef9a..6da18f9f0 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,6 +1,8 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.sessions import SessionMiddleware + from apps.webui.routers import ( auths, users, @@ -24,6 +26,8 @@ from config import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, AppConfig, + WEBUI_SECRET_KEY, + OAUTH_PROVIDERS, ) app = FastAPI() @@ -54,6 +58,12 @@ app.add_middleware( allow_headers=["*"], ) +# SessionMiddleware is used by authlib for oauth +if len(OAUTH_PROVIDERS) > 0: + app.add_middleware( + SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session" + ) + app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index e3b659e43..9ea38abcb 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -105,6 +105,7 @@ class AuthsTable: name: str, profile_image_url: str = "/user.png", role: str = "pending", + oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: log.info("insert_new_auth") @@ -115,7 +116,9 @@ class AuthsTable: ) result = Auth.create(**auth.model_dump()) - user = Users.insert_new_user(id, name, email, profile_image_url, role) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) if result and user: return user diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 8f600c6d5..b9b144b48 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -26,6 +26,8 @@ class User(Model): api_key = CharField(null=True, unique=True) + oauth_sub = TextField(null=True, unique=True) + class Meta: database = DB @@ -43,6 +45,8 @@ class UserModel(BaseModel): api_key: Optional[str] = None + oauth_sub: Optional[str] = None + #################### # Forms @@ -73,6 +77,7 @@ class UsersTable: email: str, profile_image_url: str = "/user.png", role: str = "pending", + oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: user = UserModel( **{ @@ -84,6 +89,7 @@ class UsersTable: "last_active_at": int(time.time()), "created_at": int(time.time()), "updated_at": int(time.time()), + "oauth_sub": oauth_sub, } ) result = User.create(**user.model_dump()) @@ -113,6 +119,13 @@ class UsersTable: except: return None + def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: + try: + user = User.get(User.oauth_sub == sub) + return UserModel(**model_to_dict(user)) + except: + return None + def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: return [ UserModel(**model_to_dict(user)) diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index ce9b92061..bc8ce301a 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -1,5 +1,7 @@ import logging +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo from fastapi import Request, UploadFile, File from fastapi import Depends, HTTPException, status @@ -9,6 +11,7 @@ import re import uuid import csv +from starlette.responses import RedirectResponse from apps.webui.models.auths import ( SigninForm, @@ -33,7 +36,12 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER +from config import ( + WEBUI_AUTH, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + OAUTH_PROVIDERS, + ENABLE_OAUTH_SIGNUP, +) router = APIRouter() @@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)): } else: raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + +############################ +# OAuth Login & Callback +############################ + +oauth = OAuth() + +for provider_name, provider_config in OAUTH_PROVIDERS.items(): + oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + ) + + +@router.get("/oauth/{provider}/login") +async def oauth_login(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + redirect_uri = request.url_for("oauth_callback", provider=provider) + return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + + +@router.get("/oauth/{provider}/callback") +async def oauth_callback(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = oauth.create_client(provider) + token = await client.authorize_access_token(request) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, create a new user if signup is enabled + if ENABLE_OAUTH_SIGNUP.value: + user = Auths.insert_new_auth( + email=user_data.get("email", "").lower(), + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get("name", "User"), + profile_image_url=user_data.get("picture", "/user.png"), + role=request.app.state.config.DEFAULT_USER_ROLE, + oauth_sub=provider_sub, + ) + + if request.app.state.config.WEBHOOK_URL: + post_webhook( + request.app.state.config.WEBHOOK_URL, + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) diff --git a/backend/config.py b/backend/config.py index daa89de57..35e332bd8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) +#################################### +# OAuth config +#################################### + +ENABLE_OAUTH_SIGNUP = PersistentConfig( + "ENABLE_OAUTH_SIGNUP", + "oauth.enable_signup", + os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", +) + +OAUTH_PROVIDERS = {} + +if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"): + OAUTH_PROVIDERS["google"] = { + "client_id": os.environ.get("GOOGLE_CLIENT_ID"), + "client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"), + "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", + "scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"), + } + +if ( + os.environ.get("MICROSOFT_CLIENT_ID") + and os.environ.get("MICROSOFT_CLIENT_SECRET") + and os.environ.get("MICROSOFT_CLIENT_TENANT_ID") +): + OAUTH_PROVIDERS["microsoft"] = { + "client_id": os.environ.get("MICROSOFT_CLIENT_ID"), + "client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"), + "server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration", + "scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"), + } + +if ( + os.environ.get("OPENID_CLIENT_ID") + and os.environ.get("OPENID_CLIENT_SECRET") + and os.environ.get("OPENID_PROVIDER_URL") +): + OAUTH_PROVIDERS["oidc"] = { + "client_id": os.environ.get("OPENID_CLIENT_ID"), + "client_secret": os.environ.get("OPENID_CLIENT_SECRET"), + "server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"), + "scope": os.environ.get("OPENID_SCOPE", "openid email profile"), + "name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"), + } + + #################################### # Static DIR #################################### diff --git a/backend/main.py b/backend/main.py index 3d0e4fd4a..95a62adb2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -55,6 +55,7 @@ from config import ( WEBHOOK_URL, ENABLE_ADMIN_EXPORT, AppConfig, + OAUTH_PROVIDERS, ) from constants import ERROR_MESSAGES @@ -364,6 +365,13 @@ async def get_app_config(): "default_locale": default_locale, "default_models": webui_app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "oauth": { + "providers": { + name: config.get("name", name) + for name, config in OAUTH_PROVIDERS.items() + } + }, } diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index 8f4cf16a7..933097948 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -134,7 +134,12 @@ type Config = { default_models?: string[]; default_prompt_suggestions?: PromptSuggestion[]; auth_trusted_header?: boolean; - model_config?: GlobalModelConfig; + auth: boolean; + oauth: { + providers: { + [key: string]: string; + }; + }; }; type PromptSuggestion = { diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index c0ede634f..0825ae36e 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -2,6 +2,7 @@ import { onMount, tick, setContext } from 'svelte'; import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores'; import { goto } from '$app/navigation'; + import { page } from '$app/stores'; import { Toaster, toast } from 'svelte-sonner'; import { getBackendConfig } from '$lib/apis'; @@ -75,7 +76,11 @@ await goto('/auth'); } } else { - await goto('/auth'); + // Don't redirect if we're already on the auth page + // Needed because we pass in tokens from OAuth logins via URL fragments + if ($page.url.pathname !== '/auth') { + await goto('/auth'); + } } } } else { diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte index f13cbe4db..e5a40e6b7 100644 --- a/src/routes/auth/+page.svelte +++ b/src/routes/auth/+page.svelte @@ -1,12 +1,13 @@ <script> import { goto } from '$app/navigation'; - import { userSignIn, userSignUp } from '$lib/apis/auths'; + import { getSessionUser, userSignIn, userSignUp } from '$lib/apis/auths'; import Spinner from '$lib/components/common/Spinner.svelte'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_NAME, config, user } from '$lib/stores'; import { onMount, getContext } from 'svelte'; import { toast } from 'svelte-sonner'; import { generateInitialsImage, canvasPixelTest } from '$lib/utils'; + import { page } from '$app/stores'; const i18n = getContext('i18n'); @@ -21,7 +22,9 @@ if (sessionUser) { console.log(sessionUser); toast.success($i18n.t(`You're now logged in.`)); - localStorage.token = sessionUser.token; + if (sessionUser.token) { + localStorage.token = sessionUser.token; + } await user.set(sessionUser); goto('/'); } @@ -55,10 +58,35 @@ } }; + const checkOauthCallback = async () => { + if (!$page.url.hash) { + return; + } + const hash = $page.url.hash.substring(1); + if (!hash) { + return; + } + const params = new URLSearchParams(hash); + const token = params.get('token'); + if (!token) { + return; + } + const sessionUser = await getSessionUser(token).catch((error) => { + toast.error(error); + return null; + }); + if (!sessionUser) { + return; + } + localStorage.token = token; + await setSessionUser(sessionUser); + }; + onMount(async () => { if ($user !== undefined) { await goto('/'); } + await checkOauthCallback(); loaded = true; if (($config?.auth_trusted_header ?? false) || $config?.auth === false) { await signInHandler(); @@ -217,6 +245,97 @@ {/if} </div> </form> + + {#if Object.keys($config?.oauth?.providers ?? {}).length > 0 } + <div class="inline-flex items-center justify-center w-full"> + <hr class="w-64 h-px my-8 bg-gray-200 border-0 dark:bg-gray-700" /> + <span + class="absolute px-3 font-medium text-gray-900 -translate-x-1/2 bg-white left-1/2 dark:text-white dark:bg-gray-950" + >{$i18n.t('or')}</span + > + </div> + <div class="flex flex-col space-y-2"> + {#if $config?.oauth?.providers?.google } + <button + class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" + on:click={() => { + window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`; + }} + > + <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3"> + <path + fill="#EA4335" + d="M24 9.5c3.54 0 6.71 1.22 9.21 3.6l6.85-6.85C35.9 2.38 30.47 0 24 0 14.62 0 6.51 5.38 2.56 13.22l7.98 6.19C12.43 13.72 17.74 9.5 24 9.5z" + /><path + fill="#4285F4" + d="M46.98 24.55c0-1.57-.15-3.09-.38-4.55H24v9.02h12.94c-.58 2.96-2.26 5.48-4.78 7.18l7.73 6c4.51-4.18 7.09-10.36 7.09-17.65z" + /><path + fill="#FBBC05" + d="M10.53 28.59c-.48-1.45-.76-2.99-.76-4.59s.27-3.14.76-4.59l-7.98-6.19C.92 16.46 0 20.12 0 24c0 3.88.92 7.54 2.56 10.78l7.97-6.19z" + /><path + fill="#34A853" + d="M24 48c6.48 0 11.93-2.13 15.89-5.81l-7.73-6c-2.15 1.45-4.92 2.3-8.16 2.3-6.26 0-11.57-4.22-13.47-9.91l-7.98 6.19C6.51 42.62 14.62 48 24 48z" + /><path fill="none" d="M0 0h48v48H0z" /> + </svg> + <span>{$i18n.t('Continue with {{provider}}', { provider: 'Google' })}</span> + </button> + {/if} + {#if $config?.oauth?.providers?.microsoft } + <button + class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" + on:click={() => { + window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`; + }} + > + <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3"> + <rect x="1" y="1" width="9" height="9" fill="#f25022" /><rect + x="1" + y="11" + width="9" + height="9" + fill="#00a4ef" + /><rect x="11" y="1" width="9" height="9" fill="#7fba00" /><rect + x="11" + y="11" + width="9" + height="9" + fill="#ffb900" + /> + </svg> + <span>{$i18n.t('Continue with {{provider}}', { provider: 'Microsoft' })}</span> + </button> + {/if} + {#if $config?.oauth?.providers?.oidc } + <button + class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" + on:click={() => { + window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`; + }} + > + <svg + xmlns="http://www.w3.org/2000/svg" + fill="none" + viewBox="0 0 24 24" + stroke-width="1.5" + stroke="currentColor" + class="size-6 mr-3" + > + <path + stroke-linecap="round" + stroke-linejoin="round" + d="M15.75 5.25a3 3 0 0 1 3 3m3 0a6 6 0 0 1-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1 1 21.75 8.25Z" + /> + </svg> + + <span + >{$i18n.t('Continue with {{provider}}', { + provider: $config?.oauth?.providers?.oidc ?? 'SSO' + })}</span + > + </button> + {/if} + </div> + {/if} </div> {/if} </div>