diff --git a/Dockerfile b/Dockerfile index 2e898dc89..fb3274a8f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ RUN npm ci COPY . . ENV APP_BUILD_HASH=${BUILD_HASH} +ENV NODE_OPTIONS="--max_old_space_size=8192" RUN npm run build ######## WebUI backend ######## diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 1d12d708e..79697cd4f 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -32,9 +32,13 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, SHOW_ADMIN_DETAILS, USER_PERMISSIONS, WEBHOOK_URL, @@ -93,6 +97,11 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index bfc9a4ded..b720de4bb 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig( ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", + "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) @@ -394,6 +394,29 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_ROLE_MANAGEMENT", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + +OAUTH_ALLOWED_ROLES = PersistentConfig( + "OAUTH_ALLOWED_ROLES", + "oauth.allowed_roles", + [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")], +) + +OAUTH_ADMIN_ROLES = PersistentConfig( + "OAUTH_ADMIN_ROLES", + "oauth.admin_roles", + [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], +) def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7086a3cc9..bc3244cdb 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,4 +1,5 @@ -import base64 +import inspect +import asyncio import inspect import json import logging @@ -7,9 +8,6 @@ import os import shutil import sys import time -import uuid -import asyncio - from contextlib import asynccontextmanager from typing import Optional @@ -67,13 +65,11 @@ from open_webui.config import ( ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT, ENABLE_MODEL_FILTER, - ENABLE_OAUTH_SIGNUP, ENABLE_OLLAMA_API, ENABLE_OPENAI_API, ENV, FRONTEND_BUILD_DIR, MODEL_FILTER_LIST, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, OAUTH_PROVIDERS, ENABLE_SEARCH_QUERY, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -89,7 +85,7 @@ from open_webui.config import ( run_migrations, reset_config, ) -from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES +from open_webui.constants import ERROR_MESSAGES, TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL, @@ -103,34 +99,18 @@ from open_webui.env import ( WEBUI_URL, RESET_CONFIG_ON_START, ) -from fastapi import ( - Depends, - FastAPI, - File, - Form, - HTTPException, - Request, - UploadFile, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel -from sqlalchemy import text -from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import RedirectResponse, Response, StreamingResponse - -from open_webui.utils.security_headers import SecurityHeadersMiddleware - from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, - parse_duration, prepend_to_first_user_message_content, ) +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( moa_response_generation_template, search_query_generation_template, @@ -139,27 +119,17 @@ from open_webui.utils.task import ( ) from open_webui.utils.tools import get_tools from open_webui.utils.utils import ( - create_token, decode_token, get_admin_user, get_current_user, get_http_authorization_cred, - get_password_hash, get_verified_user, ) -from open_webui.utils.webhook import post_webhook - -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, -) if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -216,7 +186,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -231,6 +200,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( app.state.MODELS = {} + + ################################## # # ChatCompletion Middleware @@ -244,14 +215,14 @@ def get_task_model_id(default_model_id): # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL else: if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL @@ -388,7 +359,7 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -689,6 +660,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) + ################################## # # Pipeline Middleware @@ -701,15 +673,15 @@ def get_sorted_filters(model_id): model for model in app.state.MODELS.values() if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) return sorted_filters @@ -823,7 +795,6 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) - app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -869,8 +840,8 @@ async def update_embedding_function(request: Request, call_next): @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" ): upgrade = (request.headers.get("Upgrade") or "").lower() connection = (request.headers.get("Connection") or "").lower().split(",") @@ -939,8 +910,8 @@ async def get_all_models(): if custom_model.base_model_id is None: for model in models: if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] ): model["name"] = custom_model.name model["info"] = custom_model.model_dump() @@ -957,8 +928,8 @@ async def get_all_models(): for model in models: if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] if "pipe" in model: @@ -1756,7 +1727,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)): @app.post("/api/pipelines/upload") async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file @@ -1933,9 +1904,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use @app.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -1971,9 +1942,9 @@ async def get_pipeline_valves( @app.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -2008,10 +1979,10 @@ async def get_pipeline_valves_spec( @app.post("/api/pipelines/{pipeline_id}/valves/update") async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), ): r = None try: @@ -2135,7 +2106,7 @@ class ModelFilterConfigForm(BaseModel): @app.post("/api/config/model/filter") async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models @@ -2185,7 +2156,7 @@ async def get_app_latest_release_version(): timeout = aiohttp.ClientTimeout(total=1) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest" + "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: response.raise_for_status() data = await response.json() @@ -2201,20 +2172,6 @@ async def get_app_latest_release_version(): # OAuth Login & Callback ############################ -oauth = OAuth() - -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) - # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: app.add_middleware( @@ -2228,16 +2185,7 @@ if len(OAUTH_PROVIDERS) > 0: @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider - ) - client = oauth.create_client(provider) - if client is None: - raise HTTPException(404) - return await client.authorize_redirect(request, redirect_uri) + return await oauth_manager.handle_login(provider, request) # OAuth login logic is as follows: @@ -2248,116 +2196,7 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is alreayd taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth.create_client(provider) - try: - token = await client.authorize_access_token(request) - except Exception as e: - log.warning(f"OAuth callback error: {e}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() - # We currently mandate that email addresses are provided - if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, check if merging is enabled - if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: - # Check if the user exists by email - user = Users.get_user_by_email(email) - if user: - # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) - - if not user: - # If the user does not exist, check if signups are enabled - if ENABLE_OAUTH_SIGNUP.value: - # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) - if existing_user: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - - picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") - if picture_url: - # Download the profile image into a base64 string - try: - async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: - picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) - guessed_mime_type = mimetypes.guess_type(picture_url)[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") - picture_url = "" - if not picture_url: - picture_url = "/user.png" - username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = ( - "admin" - if Users.get_num_users() == 0 - else webui_app.state.config.DEFAULT_USER_ROLE - ) - user = Auths.insert_new_auth( - email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get(username_claim, "User"), - profile_image_url=picture_url, - role=role, - oauth_sub=provider_sub, - ) - - if webui_app.state.config.WEBHOOK_URL: - post_webhook( - webui_app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - else: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN), - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=jwt_token, - httponly=True, # Ensures the cookie is not accessible via JavaScript - ) - - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return await oauth_manager.handle_callback(provider, request, response) @app.get("/manifest.json") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py new file mode 100644 index 000000000..e15edc0a6 --- /dev/null +++ b/backend/open_webui/utils/oauth.py @@ -0,0 +1,243 @@ +import base64 +import logging +import mimetypes +import uuid + +import aiohttp +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo +from fastapi import ( + HTTPException, + status, +) +from starlette.responses import RedirectResponse + +from open_webui.apps.webui.models.auths import Auths +from open_webui.apps.webui.models.users import Users +from open_webui.config import ( + DEFAULT_USER_ROLE, + ENABLE_OAUTH_SIGNUP, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PROVIDERS, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, +) +from open_webui.constants import ERROR_MESSAGES +from open_webui.utils.misc import parse_duration +from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.webhook import post_webhook + +log = logging.getLogger(__name__) + +auth_manager_config = AppConfig() +auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP +auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL +auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.WEBHOOK_URL = WEBHOOK_URL +auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + + +class OAuthManager: + def __init__(self): + self.oauth = OAuth() + for provider_name, provider_config in OAUTH_PROVIDERS.items(): + self.oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + redirect_uri=provider_config["redirect_uri"], + ) + + def get_client(self, provider_name): + return self.oauth.create_client(provider_name) + + def get_user_role(self, user, user_data): + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" + + if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM + oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES + oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = auth_manager_config.DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + + async def handle_login(self, provider, request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + # If the provider has a custom redirect URL, use that, otherwise automatically generate one + redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( + "oauth_callback", provider=provider + ) + client = self.get_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) + + async def handle_callback(self, provider, request, response): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = self.get_client(provider) + try: + token = await client.authorize_access_token(request) + except Exception as e: + log.warning(f"OAuth callback error: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + log.warning(f"OAuth callback failed, sub is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM + email = user_data.get(email_claim, "").lower() + # We currently mandate that email addresses are provided + if not email: + log.warning(f"OAuth callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, check if merging is enabled + if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: + # Check if the user exists by email + user = Users.get_user_by_email(email) + if user: + # Update the user with the new oauth sub + Users.update_user_oauth_sub_by_id(user.id, provider_sub) + + if user: + determined_role = self.get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) + + if not user: + # If the user does not exist, check if signups are enabled + if auth_manager_config.ENABLE_OAUTH_SIGNUP.value: + # Check if an existing user with the same email already exists + existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + if existing_user: + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) + + picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") + if picture_url: + # Download the profile image into a base64 string + try: + async with aiohttp.ClientSession() as session: + async with session.get(picture_url) as resp: + picture = await resp.read() + base64_encoded_picture = base64.b64encode(picture).decode( + "utf-8" + ) + guessed_mime_type = mimetypes.guess_type(picture_url)[0] + if guessed_mime_type is None: + # assume JPG, browsers are tolerant enough of image formats + guessed_mime_type = "image/jpeg" + picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" + except Exception as e: + log.error(f"Error downloading profile image '{picture_url}': {e}") + picture_url = "" + if not picture_url: + picture_url = "/user.png" + username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM + + role = self.get_user_role(None, user_data) + + user = Auths.insert_new_auth( + email=email, + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get(username_claim, "User"), + profile_image_url=picture_url, + role=role, + oauth_sub=provider_sub, + ) + + if auth_manager_config.WEBHOOK_URL: + post_webhook( + auth_manager_config.WEBHOOK_URL, + auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), + ) + + # Set the cookie token + response.set_cookie( + key="token", + value=jwt_token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) + +oauth_manager = OAuthManager() \ No newline at end of file