diff --git a/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py new file mode 100644 index 000000000..70dfeccf0 --- /dev/null +++ b/backend/apps/webui/internal/migrations/011_add_user_oauth_sub.py @@ -0,0 +1,49 @@ +"""Peewee migrations -- 011_add_user_oauth_sub.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index d736cef9a..6da18f9f0 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -1,6 +1,8 @@ from fastapi import FastAPI, Depends from fastapi.routing import APIRoute from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.sessions import SessionMiddleware + from apps.webui.routers import ( auths, users, @@ -24,6 +26,8 @@ from config import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, JWT_EXPIRES_IN, AppConfig, + WEBUI_SECRET_KEY, + OAUTH_PROVIDERS, ) app = FastAPI() @@ -54,6 +58,12 @@ app.add_middleware( allow_headers=["*"], ) +# SessionMiddleware is used by authlib for oauth +if len(OAUTH_PROVIDERS) > 0: + app.add_middleware( + SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session" + ) + app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index e3b659e43..9ea38abcb 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -105,6 +105,7 @@ class AuthsTable: name: str, profile_image_url: str = "/user.png", role: str = "pending", + oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: log.info("insert_new_auth") @@ -115,7 +116,9 @@ class AuthsTable: ) result = Auth.create(**auth.model_dump()) - user = Users.insert_new_user(id, name, email, profile_image_url, role) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) if result and user: return user diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 8f600c6d5..b9b144b48 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -26,6 +26,8 @@ class User(Model): api_key = CharField(null=True, unique=True) + oauth_sub = TextField(null=True, unique=True) + class Meta: database = DB @@ -43,6 +45,8 @@ class UserModel(BaseModel): api_key: Optional[str] = None + oauth_sub: Optional[str] = None + #################### # Forms @@ -73,6 +77,7 @@ class UsersTable: email: str, profile_image_url: str = "/user.png", role: str = "pending", + oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: user = UserModel( **{ @@ -84,6 +89,7 @@ class UsersTable: "last_active_at": int(time.time()), "created_at": int(time.time()), "updated_at": int(time.time()), + "oauth_sub": oauth_sub, } ) result = User.create(**user.model_dump()) @@ -113,6 +119,13 @@ class UsersTable: except: return None + def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: + try: + user = User.get(User.oauth_sub == sub) + return UserModel(**model_to_dict(user)) + except: + return None + def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: return [ UserModel(**model_to_dict(user)) diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index ce9b92061..bc8ce301a 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -1,5 +1,7 @@ import logging +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo from fastapi import Request, UploadFile, File from fastapi import Depends, HTTPException, status @@ -9,6 +11,7 @@ import re import uuid import csv +from starlette.responses import RedirectResponse from apps.webui.models.auths import ( SigninForm, @@ -33,7 +36,12 @@ from utils.utils import ( from utils.misc import parse_duration, validate_email_format from utils.webhook import post_webhook from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER +from config import ( + WEBUI_AUTH, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + OAUTH_PROVIDERS, + ENABLE_OAUTH_SIGNUP, +) router = APIRouter() @@ -373,3 +381,82 @@ async def get_api_key(user=Depends(get_current_user)): } else: raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + +############################ +# OAuth Login & Callback +############################ + +oauth = OAuth() + +for provider_name, provider_config in OAUTH_PROVIDERS.items(): + oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + ) + + +@router.get("/oauth/{provider}/login") +async def oauth_login(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + redirect_uri = request.url_for("oauth_callback", provider=provider) + return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + + +@router.get("/oauth/{provider}/callback") +async def oauth_callback(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = oauth.create_client(provider) + token = await client.authorize_access_token(request) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, create a new user if signup is enabled + if ENABLE_OAUTH_SIGNUP.value: + user = Auths.insert_new_auth( + email=user_data.get("email", "").lower(), + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get("name", "User"), + profile_image_url=user_data.get("picture", "/user.png"), + role=request.app.state.config.DEFAULT_USER_ROLE, + oauth_sub=provider_sub, + ) + + if request.app.state.config.WEBHOOK_URL: + post_webhook( + request.app.state.config.WEBHOOK_URL, + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) diff --git a/backend/config.py b/backend/config.py index daa89de57..35e332bd8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -285,6 +285,52 @@ JWT_EXPIRES_IN = PersistentConfig( "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") ) +#################################### +# OAuth config +#################################### + +ENABLE_OAUTH_SIGNUP = PersistentConfig( + "ENABLE_OAUTH_SIGNUP", + "oauth.enable_signup", + os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", +) + +OAUTH_PROVIDERS = {} + +if os.environ.get("GOOGLE_CLIENT_ID") and os.environ.get("GOOGLE_CLIENT_SECRET"): + OAUTH_PROVIDERS["google"] = { + "client_id": os.environ.get("GOOGLE_CLIENT_ID"), + "client_secret": os.environ.get("GOOGLE_CLIENT_SECRET"), + "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", + "scope": os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"), + } + +if ( + os.environ.get("MICROSOFT_CLIENT_ID") + and os.environ.get("MICROSOFT_CLIENT_SECRET") + and os.environ.get("MICROSOFT_CLIENT_TENANT_ID") +): + OAUTH_PROVIDERS["microsoft"] = { + "client_id": os.environ.get("MICROSOFT_CLIENT_ID"), + "client_secret": os.environ.get("MICROSOFT_CLIENT_SECRET"), + "server_metadata_url": f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_CLIENT_TENANT_ID')}/v2.0/.well-known/openid-configuration", + "scope": os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"), + } + +if ( + os.environ.get("OPENID_CLIENT_ID") + and os.environ.get("OPENID_CLIENT_SECRET") + and os.environ.get("OPENID_PROVIDER_URL") +): + OAUTH_PROVIDERS["oidc"] = { + "client_id": os.environ.get("OPENID_CLIENT_ID"), + "client_secret": os.environ.get("OPENID_CLIENT_SECRET"), + "server_metadata_url": os.environ.get("OPENID_PROVIDER_URL"), + "scope": os.environ.get("OPENID_SCOPE", "openid email profile"), + "name": os.environ.get("OPENID_PROVIDER_NAME", "SSO"), + } + + #################################### # Static DIR #################################### diff --git a/backend/main.py b/backend/main.py index 3d0e4fd4a..95a62adb2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -55,6 +55,7 @@ from config import ( WEBHOOK_URL, ENABLE_ADMIN_EXPORT, AppConfig, + OAUTH_PROVIDERS, ) from constants import ERROR_MESSAGES @@ -364,6 +365,13 @@ async def get_app_config(): "default_locale": default_locale, "default_models": webui_app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS, + "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER), + "oauth": { + "providers": { + name: config.get("name", name) + for name, config in OAUTH_PROVIDERS.items() + } + }, } diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index 8f4cf16a7..933097948 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -134,7 +134,12 @@ type Config = { default_models?: string[]; default_prompt_suggestions?: PromptSuggestion[]; auth_trusted_header?: boolean; - model_config?: GlobalModelConfig; + auth: boolean; + oauth: { + providers: { + [key: string]: string; + }; + }; }; type PromptSuggestion = { diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index c0ede634f..0825ae36e 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -2,6 +2,7 @@ import { onMount, tick, setContext } from 'svelte'; import { config, user, theme, WEBUI_NAME, mobile } from '$lib/stores'; import { goto } from '$app/navigation'; + import { page } from '$app/stores'; import { Toaster, toast } from 'svelte-sonner'; import { getBackendConfig } from '$lib/apis'; @@ -75,7 +76,11 @@ await goto('/auth'); } } else { - await goto('/auth'); + // Don't redirect if we're already on the auth page + // Needed because we pass in tokens from OAuth logins via URL fragments + if ($page.url.pathname !== '/auth') { + await goto('/auth'); + } } } } else { diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte index f13cbe4db..e5a40e6b7 100644 --- a/src/routes/auth/+page.svelte +++ b/src/routes/auth/+page.svelte @@ -1,12 +1,13 @@