From 0210a105bfb4527f31e2bf1fb5086b0341ffe156 Mon Sep 17 00:00:00 2001
From: Jun Siang Cheah <git@jscheah.me>
Date: Sun, 26 May 2024 08:37:09 +0100
Subject: [PATCH] feat: experimental SSO support for Google, Microsoft, and
 OIDC

---
 .../migrations/011_add_user_oauth_sub.py      |  49 +++++++
 backend/apps/webui/main.py                    |  10 ++
 backend/apps/webui/models/auths.py            |   5 +-
 backend/apps/webui/models/users.py            |  13 ++
 backend/apps/webui/routers/auths.py           |  89 ++++++++++++-
 backend/config.py                             |  46 +++++++
 backend/main.py                               |   8 ++
 src/lib/stores/index.ts                       |   7 +-
 src/routes/+layout.svelte                     |   7 +-
 src/routes/auth/+page.svelte                  | 123 +++++++++++++++++-
 10 files changed, 351 insertions(+), 6 deletions(-)
 create mode 100644 backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py

diff --git a/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py
new file mode 100644
index 000000000..70dfeccf0
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py
@@ -0,0 +1,49 @@
+"""Peewee migrations -- 011_add_user_oauth_sub.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    migrator.add_fields(
+        "user",
+        oauth_sub=pw.TextField(null=True, unique=True),
+    )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_fields("user", "oauth_sub")
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index d736cef9a..6da18f9f0 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -1,6 +1,8 @@
 from fastapi import FastAPI, Depends
 from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+
 from apps.webui.routers import (
     auths,
     users,
@@ -24,6 +26,8 @@ from config import (
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     JWT_EXPIRES_IN,
     AppConfig,
+    WEBUI_SECRET_KEY,
+    OAUTH_PROVIDERS,
 )
 
 app = FastAPI()
@@ -54,6 +58,12 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# SessionMiddleware is used by authlib for oauth
+if len(OAUTH_PROVIDERS) > 0:
+    app.add_middleware(
+        SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
+    )
+
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py
index e3b659e43..9ea38abcb 100644
--- a/backend/apps/webui/models/auths.py
+++ b/backend/apps/webui/models/auths.py
@@ -105,6 +105,7 @@ class AuthsTable:
         name: str,
         profile_image_url: str = "/user.png",
         role: str = "pending",
+        oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
         log.info("insert_new_auth")
 
@@ -115,7 +116,9 @@ class AuthsTable:
         )
         result = Auth.create(**auth.model_dump())
 
-        user = Users.insert_new_user(id, name, email, profile_image_url, role)
+        user = Users.insert_new_user(
+            id, name, email, profile_image_url, role, oauth_sub
+        )
 
         if result and user:
             return user
diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py
index 8f600c6d5..b9b144b48 100644
--- a/backend/apps/webui/models/users.py
+++ b/backend/apps/webui/models/users.py
@@ -26,6 +26,8 @@ class User(Model):
 
     api_key = CharField(null=True, unique=True)
 
+    oauth_sub = TextField(null=True, unique=True)
+
     class Meta:
         database = DB
 
@@ -43,6 +45,8 @@ class UserModel(BaseModel):
 
     api_key: Optional[str] = None
 
+    oauth_sub: Optional[str] = None
+
 
 ####################
 # Forms
@@ -73,6 +77,7 @@ class UsersTable:
         email: str,
         profile_image_url: str = "/user.png",
         role: str = "pending",
+        oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
         user = UserModel(
             **{
@@ -84,6 +89,7 @@ class UsersTable:
                 "last_active_at": int(time.time()),
                 "created_at": int(time.time()),
                 "updated_at": int(time.time()),
+                "oauth_sub": oauth_sub,
             }
         )
         result = User.create(**user.model_dump())
@@ -113,6 +119,13 @@ class UsersTable:
         except:
             return None
 
+    def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
+        try:
+            user = User.get(User.oauth_sub == sub)
+            return UserModel(**model_to_dict(user))
+        except:
+            return None
+
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
             UserModel(**model_to_dict(user))
diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py
index ce9b92061..bc8ce301a 100644
--- a/backend/apps/webui/routers/auths.py
+++ b/backend/apps/webui/routers/auths.py
@@ -1,5 +1,7 @@
 import logging
 
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
 from fastapi import Request, UploadFile, File
 from fastapi import Depends, HTTPException, status
 
@@ -9,6 +11,7 @@ import re
 import uuid
 import csv
 
+from starlette.responses import RedirectResponse
 
 from apps.webui.models.auths import (
     SigninForm,
@@ -33,7 +36,12 @@ from utils.utils import (
 from utils.misc import parse_duration, validate_email_format
 from utils.webhook import post_webhook
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
+from config import (
+    WEBUI_AUTH,
+    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
+    OAUTH_PROVIDERS,
+    ENABLE_OAUTH_SIGNUP,
+)
 
 router = APIRouter()
 
@@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)):
         }
     else:
         raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+
+############################
+# OAuth Login & Callback
+############################
+
+oauth = OAuth()
+
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
+    oauth.register(
+        name=provider_name,
+        client_id=provider_config["client_id"],
+        client_secret=provider_config["client_secret"],
+        server_metadata_url=provider_config["server_metadata_url"],
+        client_kwargs={
+            "scope": provider_config["scope"],
+        },
+    )
+
+
+@router.get("/oauth/{provider}/login")
+async def oauth_login(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    redirect_uri = request.url_for("oauth_callback", provider=provider)
+    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+
+
+@router.get("/oauth/{provider}/callback")
+async def oauth_callback(provider: str, request: Request):
+    if provider not in OAUTH_PROVIDERS:
+        raise HTTPException(404)
+    client = oauth.create_client(provider)
+    token = await client.authorize_access_token(request)
+    user_data: UserInfo = token["userinfo"]
+
+    sub = user_data.get("sub")
+    if not sub:
+        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+    provider_sub = f"{provider}@{sub}"
+
+    # Check if the user exists
+    user = Users.get_user_by_oauth_sub(provider_sub)
+
+    if not user:
+        # If the user does not exist, create a new user if signup is enabled
+        if ENABLE_OAUTH_SIGNUP.value:
+            user = Auths.insert_new_auth(
+                email=user_data.get("email", "").lower(),
+                password=get_password_hash(
+                    str(uuid.uuid4())
+                ),  # Random password, not used
+                name=user_data.get("name", "User"),
+                profile_image_url=user_data.get("picture", "/user.png"),
+                role=request.app.state.config.DEFAULT_USER_ROLE,
+                oauth_sub=provider_sub,
+            )
+
+            if request.app.state.config.WEBHOOK_URL:
+                post_webhook(
+                    request.app.state.config.WEBHOOK_URL,
+                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                    {
+                        "action": "signup",
+                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+                        "user": user.model_dump_json(exclude_none=True),
+                    },
+                )
+        else:
+            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+    jwt_token = create_token(
+        data={"id": user.id},
+        expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+    )
+
+    # Redirect back to the frontend with the JWT token
+    redirect_url = f"{request.base_url}auth#token={jwt_token}"
+    return RedirectResponse(url=redirect_url)
diff --git a/backend/config.py b/backend/config.py
index daa89de57..35e332bd8 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig(
     "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
 )
 
+####################################
+# OAuth config
+####################################
+
+ENABLE_OAUTH_SIGNUP = PersistentConfig(
+    "ENABLE_OAUTH_SIGNUP",
+    "oauth.enable_signup",
+    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
+)
+
+OAUTH_PROVIDERS = {}
+
+if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"):
+    OAUTH_PROVIDERS["google"] = {
+        "client_id": os.environ.get("GOOGLE_CLIENT_ID"),
+        "client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"),
+        "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
+        "scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
+    }
+
+if (
+    os.environ.get("MICROSOFT_CLIENT_ID")
+    and os.environ.get("MICROSOFT_CLIENT_SECRET")
+    and os.environ.get("MICROSOFT_CLIENT_TENANT_ID")
+):
+    OAUTH_PROVIDERS["microsoft"] = {
+        "client_id": os.environ.get("MICROSOFT_CLIENT_ID"),
+        "client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"),
+        "server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration",
+        "scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
+    }
+
+if (
+    os.environ.get("OPENID_CLIENT_ID")
+    and os.environ.get("OPENID_CLIENT_SECRET")
+    and os.environ.get("OPENID_PROVIDER_URL")
+):
+    OAUTH_PROVIDERS["oidc"] = {
+        "client_id": os.environ.get("OPENID_CLIENT_ID"),
+        "client_secret": os.environ.get("OPENID_CLIENT_SECRET"),
+        "server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"),
+        "scope": os.environ.get("OPENID_SCOPE", "openid email profile"),
+        "name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"),
+    }
+
+
 ####################################
 # Static DIR
 ####################################
diff --git a/backend/main.py b/backend/main.py
index 3d0e4fd4a..95a62adb2 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -55,6 +55,7 @@ from config import (
     WEBHOOK_URL,
     ENABLE_ADMIN_EXPORT,
     AppConfig,
+    OAUTH_PROVIDERS,
 )
 from constants import ERROR_MESSAGES
 
@@ -364,6 +365,13 @@ async def get_app_config():
         "default_locale": default_locale,
         "default_models": webui_app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
+        "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
+        "oauth": {
+            "providers": {
+                name: config.get("name", name)
+                for name, config in OAUTH_PROVIDERS.items()
+            }
+        },
     }
 
 
diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts
index 8f4cf16a7..933097948 100644
--- a/src/lib/stores/index.ts
+++ b/src/lib/stores/index.ts
@@ -134,7 +134,12 @@ type Config = {
 	default_models?: string[];
 	default_prompt_suggestions?: PromptSuggestion[];
 	auth_trusted_header?: boolean;
-	model_config?: GlobalModelConfig;
+	auth: boolean;
+	oauth: {
+		providers: {
+			[key: string]: string;
+		};
+	};
 };
 
 type PromptSuggestion = {
diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte
index c0ede634f..0825ae36e 100644
--- a/src/routes/+layout.svelte
+++ b/src/routes/+layout.svelte
@@ -2,6 +2,7 @@
 	import { onMount, tick, setContext } from 'svelte';
 	import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores';
 	import { goto } from '$app/navigation';
+	import { page } from '$app/stores';
 	import { Toaster, toast } from 'svelte-sonner';
 
 	import { getBackendConfig } from '$lib/apis';
@@ -75,7 +76,11 @@
 						await goto('/auth');
 					}
 				} else {
-					await goto('/auth');
+					// Don't redirect if we're already on the auth page
+					// Needed because we pass in tokens from OAuth logins via URL fragments
+					if ($page.url.pathname !== '/auth') {
+						await goto('/auth');
+					}
 				}
 			}
 		} else {
diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte
index f13cbe4db..e5a40e6b7 100644
--- a/src/routes/auth/+page.svelte
+++ b/src/routes/auth/+page.svelte
@@ -1,12 +1,13 @@
 <script>
 	import { goto } from '$app/navigation';
-	import { userSignIn, userSignUp } from '$lib/apis/auths';
+	import { getSessionUser, userSignIn, userSignUp } from '$lib/apis/auths';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, config, user } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	import { generateInitialsImage, canvasPixelTest } from '$lib/utils';
+	import { page } from '$app/stores';
 
 	const i18n = getContext('i18n');
 
@@ -21,7 +22,9 @@
 		if (sessionUser) {
 			console.log(sessionUser);
 			toast.success($i18n.t(`You're now logged in.`));
-			localStorage.token = sessionUser.token;
+			if (sessionUser.token) {
+				localStorage.token = sessionUser.token;
+			}
 			await user.set(sessionUser);
 			goto('/');
 		}
@@ -55,10 +58,35 @@
 		}
 	};
 
+	const checkOauthCallback = async () => {
+		if (!$page.url.hash) {
+			return;
+		}
+		const hash = $page.url.hash.substring(1);
+		if (!hash) {
+			return;
+		}
+		const params = new URLSearchParams(hash);
+		const token = params.get('token');
+		if (!token) {
+			return;
+		}
+		const sessionUser = await getSessionUser(token).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+		if (!sessionUser) {
+			return;
+		}
+		localStorage.token = token;
+		await setSessionUser(sessionUser);
+	};
+
 	onMount(async () => {
 		if ($user !== undefined) {
 			await goto('/');
 		}
+		await checkOauthCallback();
 		loaded = true;
 		if (($config?.auth_trusted_header ?? false) || $config?.auth === false) {
 			await signInHandler();
@@ -217,6 +245,97 @@
 							{/if}
 						</div>
 					</form>
+
+					{#if Object.keys($config?.oauth?.providers ?? {}).length > 0 }
+						<div class="inline-flex items-center justify-center w-full">
+							<hr class="w-64 h-px my-8 bg-gray-200 border-0 dark:bg-gray-700" />
+							<span
+								class="absolute px-3 font-medium text-gray-900 -translate-x-1/2 bg-white left-1/2 dark:text-white dark:bg-gray-950"
+								>{$i18n.t('or')}</span
+							>
+						</div>
+						<div class="flex flex-col space-y-2">
+							{#if $config?.oauth?.providers?.google }
+								<button
+									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
+									on:click={() => {
+										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
+									}}
+								>
+									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
+										<path
+											fill="#EA4335"
+											d="M24 9.5c3.54 0 6.71 1.22 9.21 3.6l6.85-6.85C35.9 2.38 30.47 0 24 0 14.62 0 6.51 5.38 2.56 13.22l7.98 6.19C12.43 13.72 17.74 9.5 24 9.5z"
+										/><path
+											fill="#4285F4"
+											d="M46.98 24.55c0-1.57-.15-3.09-.38-4.55H24v9.02h12.94c-.58 2.96-2.26 5.48-4.78 7.18l7.73 6c4.51-4.18 7.09-10.36 7.09-17.65z"
+										/><path
+											fill="#FBBC05"
+											d="M10.53 28.59c-.48-1.45-.76-2.99-.76-4.59s.27-3.14.76-4.59l-7.98-6.19C.92 16.46 0 20.12 0 24c0 3.88.92 7.54 2.56 10.78l7.97-6.19z"
+										/><path
+											fill="#34A853"
+											d="M24 48c6.48 0 11.93-2.13 15.89-5.81l-7.73-6c-2.15 1.45-4.92 2.3-8.16 2.3-6.26 0-11.57-4.22-13.47-9.91l-7.98 6.19C6.51 42.62 14.62 48 24 48z"
+										/><path fill="none" d="M0 0h48v48H0z" />
+									</svg>
+									<span>{$i18n.t('Continue with {{provider}}', { provider: 'Google' })}</span>
+								</button>
+							{/if}
+							{#if $config?.oauth?.providers?.microsoft }
+								<button
+									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
+									on:click={() => {
+										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
+									}}
+								>
+									<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
+										<rect x="1" y="1" width="9" height="9" fill="#f25022" /><rect
+											x="1"
+											y="11"
+											width="9"
+											height="9"
+											fill="#00a4ef"
+										/><rect x="11" y="1" width="9" height="9" fill="#7fba00" /><rect
+											x="11"
+											y="11"
+											width="9"
+											height="9"
+											fill="#ffb900"
+										/>
+									</svg>
+									<span>{$i18n.t('Continue with {{provider}}', { provider: 'Microsoft' })}</span>
+								</button>
+							{/if}
+							{#if $config?.oauth?.providers?.oidc }
+								<button
+									class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
+									on:click={() => {
+										window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
+									}}
+								>
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										fill="none"
+										viewBox="0 0 24 24"
+										stroke-width="1.5"
+										stroke="currentColor"
+										class="size-6 mr-3"
+									>
+										<path
+											stroke-linecap="round"
+											stroke-linejoin="round"
+											d="M15.75 5.25a3 3 0 0 1 3 3m3 0a6 6 0 0 1-7.029 5.912c-.563-.097-1.159.026-1.563.43L10.5 17.25H8.25v2.25H6v2.25H2.25v-2.818c0-.597.237-1.17.659-1.591l6.499-6.499c.404-.404.527-1 .43-1.563A6 6 0 1 1 21.75 8.25Z"
+										/>
+									</svg>
+
+									<span
+										>{$i18n.t('Continue with {{provider}}', {
+											provider: $config?.oauth?.providers?.oidc ?? 'SSO'
+										})}</span
+									>
+								</button>
+							{/if}
+						</div>
+					{/if}
 				</div>
 			{/if}
 		</div>