diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 6da18f9f0..e0e3150b9 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -58,12 +58,6 @@ app.add_middleware( allow_headers=["*"], ) -# SessionMiddleware is used by authlib for oauth -if len(OAUTH_PROVIDERS) > 0: - app.add_middleware( - SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session" - ) - app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 598712b4d..1841369e3 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -112,9 +112,16 @@ class UsersTable: except: return None - def get_user_by_email(self, email: str) -> Optional[UserModel]: + def get_user_by_email( + self, email: str, oauth_user: bool = False + ) -> Optional[UserModel]: try: - user = User.get((User.email == email, User.oauth_sub.is_null())) + conditions = ( + (User.email == email, User.oauth_sub.is_null()) + if not oauth_user + else (User.email == email) + ) + user = User.get(conditions) return UserModel(**model_to_dict(user)) except: return None @@ -177,6 +184,18 @@ class UsersTable: except: return None + def update_user_oauth_sub_by_id( + self, id: str, oauth_sub: str + ) -> Optional[UserModel]: + try: + query = User.update(oauth_sub=oauth_sub).where(User.id == id) + query.execute() + + user = User.get(User.id == id) + return UserModel(**model_to_dict(user)) + except: + return None + def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: query = User.update(**updated).where(User.id == id) diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index bc8ce301a..3c2b25b0e 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -1,7 +1,5 @@ import logging -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo from fastapi import Request, UploadFile, File from fastapi import Depends, HTTPException, status @@ -11,8 +9,6 @@ import re import uuid import csv -from starlette.responses import RedirectResponse - from apps.webui.models.auths import ( SigninForm, SignupForm, @@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from config import ( WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - OAUTH_PROVIDERS, - ENABLE_OAUTH_SIGNUP, ) router = APIRouter() @@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)): } else: raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) - - -############################ -# OAuth Login & Callback -############################ - -oauth = OAuth() - -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - ) - - -@router.get("/oauth/{provider}/login") -async def oauth_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - redirect_uri = request.url_for("oauth_callback", provider=provider) - return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) - - -@router.get("/oauth/{provider}/callback") -async def oauth_callback(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth.create_client(provider) - token = await client.authorize_access_token(request) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, create a new user if signup is enabled - if ENABLE_OAUTH_SIGNUP.value: - user = Auths.insert_new_auth( - email=user_data.get("email", "").lower(), - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get("name", "User"), - profile_image_url=user_data.get("picture", "/user.png"), - role=request.app.state.config.DEFAULT_USER_ROLE, - oauth_sub=provider_sub, - ) - - if request.app.state.config.WEBHOOK_URL: - post_webhook( - request.app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - else: - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), - ) - - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) diff --git a/backend/main.py b/backend/main.py index 95a62adb2..1c7563b34 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,8 @@ +import uuid from contextlib import asynccontextmanager + +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo from bs4 import BeautifulSoup import json import markdown @@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import StreamingResponse, Response +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import StreamingResponse, Response, RedirectResponse from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models from apps.openai.main import app as openai_app, get_all_models as get_openai_models @@ -31,8 +36,16 @@ import asyncio from pydantic import BaseModel from typing import List, Optional -from apps.webui.models.models import Models, ModelModel -from utils.utils import get_admin_user, get_verified_user +from apps.webui.models.auths import Auths +from apps.webui.models.models import Models +from apps.webui.models.users import Users +from utils.misc import parse_duration +from utils.utils import ( + get_admin_user, + get_verified_user, + get_password_hash, + create_token, +) from apps.rag.utils import rag_messages from config import ( @@ -56,8 +69,12 @@ from config import ( ENABLE_ADMIN_EXPORT, AppConfig, OAUTH_PROVIDERS, + ENABLE_OAUTH_SIGNUP, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + WEBUI_SECRET_KEY, ) -from constants import ERROR_MESSAGES +from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from utils.webhook import post_webhook logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) @@ -453,6 +470,103 @@ async def get_app_latest_release_version(): ) +############################ +# OAuth Login & Callback +############################ + +oauth = OAuth() + +for provider_name, provider_config in OAUTH_PROVIDERS.items(): + oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + ) + +# SessionMiddleware is used by authlib for oauth +if len(OAUTH_PROVIDERS) > 0: + app.add_middleware( + SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session" + ) + + +@app.get("/oauth/{provider}/login") +async def oauth_login(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + redirect_uri = request.url_for("oauth_callback", provider=provider) + return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + + +@app.get("/oauth/{provider}/callback") +async def oauth_callback(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = oauth.create_client(provider) + token = await client.authorize_access_token(request) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, check if merging is enabled + if OAUTH_MERGE_ACCOUNTS_BY_EMAIL: + # Check if the user exists by email + email = user_data.get("email", "").lower() + if not email: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + user = Users.get_user_by_email(user_data.get("email", "").lower(), True) + if user: + # Update the user with the new oauth sub + Users.update_user_oauth_sub_by_id(user.id, provider_sub) + + if not user: + # If the user does not exist, check if signups are enabled + if ENABLE_OAUTH_SIGNUP.value: + user = Auths.insert_new_auth( + email=user_data.get("email", "").lower(), + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get("name", "User"), + profile_image_url=user_data.get("picture", "/user.png"), + role=webui_app.state.config.DEFAULT_USER_ROLE, + oauth_sub=provider_sub, + ) + + if webui_app.state.config.WEBHOOK_URL: + post_webhook( + webui_app.state.config.WEBHOOK_URL, + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN), + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) + + @app.get("/manifest.json") async def get_manifest_json(): return { diff --git a/src/routes/auth/+page.svelte b/src/routes/auth/+page.svelte index e5a40e6b7..0fa762d5a 100644 --- a/src/routes/auth/+page.svelte +++ b/src/routes/auth/+page.svelte @@ -259,7 +259,7 @@