diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index f835e3175..cb38a53eb 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel): class GenerateEmbedForm(BaseModel): model: str - input: list[str]|str + input: list[str] | str truncate: Optional[bool] = None options: Optional[dict] = None keep_alive: Optional[Union[int, str]] = None diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index c6d95bd52..7782671a2 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -110,9 +110,8 @@ class ChromaClient: def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. collection = self.client.get_or_create_collection( - name=collection_name, - metadata={"hnsw:space": "cosine"} - ) + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] @@ -131,9 +130,8 @@ class ChromaClient: def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. collection = self.client.get_or_create_collection( - name=collection_name, - metadata={"hnsw:space": "cosine"} - ) + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py index 70908dc33..c1e06872f 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI NO_LIMIT = 999999999 + class QdrantClient: def __init__(self): self.collection_prefix = "open-webui" @@ -38,15 +39,15 @@ class QdrantClient: collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" self.client.create_collection( collection_name=collection_name_with_prefix, - vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), + vectors_config=models.VectorParams( + size=dimension, distance=models.Distance.COSINE + ), ) print(f"collection {collection_name_with_prefix} successfully created!") def _create_collection_if_not_exists(self, collection_name, dimension): - if not self.has_collection( - collection_name=collection_name - ): + if not self.has_collection(collection_name=collection_name): self._create_collection( collection_name=collection_name, dimension=dimension ) @@ -56,22 +57,23 @@ class QdrantClient: PointStruct( id=item["id"], vector=item["vector"], - payload={ - "text": item["text"], - "metadata": item["metadata"] - }, + payload={"text": item["text"], "metadata": item["metadata"]}, ) for item in items ] def has_collection(self, collection_name: str) -> bool: - return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}") + return self.client.collection_exists( + f"{self.collection_prefix}_{collection_name}" + ) def delete_collection(self, collection_name: str): - return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}") + return self.client.delete_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. if limit is None: @@ -87,7 +89,7 @@ class QdrantClient: ids=get_result.ids, documents=get_result.documents, metadatas=get_result.metadatas, - distances=[[point.score for point in query_response.points]] + distances=[[point.score for point in query_response.points]], ) def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): @@ -101,7 +103,10 @@ class QdrantClient: field_conditions = [] for key, value in filter.items(): field_conditions.append( - models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) + models.FieldCondition( + key=f"metadata.{key}", match=models.MatchValue(value=value) + ) + ) points = self.client.query_points( collection_name=f"{self.collection_prefix}_{collection_name}", @@ -117,7 +122,7 @@ class QdrantClient: # Get all the items in the collection. points = self.client.query_points( collection_name=f"{self.collection_prefix}_{collection_name}", - limit=NO_LIMIT # otherwise qdrant would set limit to 10! + limit=NO_LIMIT, # otherwise qdrant would set limit to 10! ) return self._result_to_get_result(points.points) @@ -134,10 +139,10 @@ class QdrantClient: return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) def delete( - self, - collection_name: str, - ids: Optional[list[str]] = None, - filter: Optional[dict] = None, + self, + collection_name: str, + ids: Optional[list[str]] = None, + filter: Optional[dict] = None, ): # Delete the items from the collection based on the ids. field_conditions = [] @@ -162,9 +167,7 @@ class QdrantClient: return self.client.delete( collection_name=f"{self.collection_prefix}_{collection_name}", points_selector=models.FilterSelector( - filter=models.Filter( - must=field_conditions - ) + filter=models.Filter(must=field_conditions) ), ) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 815df9777..11346ba55 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -33,9 +33,13 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, SHOW_ADMIN_DETAILS, USER_PERMISSIONS, WEBHOOK_URL, @@ -94,6 +98,11 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 9ec275bac..496f2395f 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig( ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", + "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) @@ -394,6 +394,33 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_ROLE_MANAGEMENT", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + +OAUTH_ALLOWED_ROLES = PersistentConfig( + "OAUTH_ALLOWED_ROLES", + "oauth.allowed_roles", + [ + role.strip() + for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",") + ], +) + +OAUTH_ADMIN_ROLES = PersistentConfig( + "OAUTH_ADMIN_ROLES", + "oauth.admin_roles", + [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6dd7a5079..5b3ca7e64 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,4 +1,4 @@ -import base64 +import asyncio import inspect import json import logging @@ -7,104 +7,11 @@ import os import shutil import sys import time -import uuid -import asyncio - from contextlib import asynccontextmanager from typing import Optional import aiohttp import requests - -from open_webui.apps.ollama.main import ( - app as ollama_app, - get_all_models as get_ollama_models, - generate_chat_completion as generate_ollama_chat_completion, - generate_openai_chat_completion as generate_ollama_openai_chat_completion, - GenerateChatCompletionForm, -) -from open_webui.apps.openai.main import ( - app as openai_app, - generate_chat_completion as generate_openai_chat_completion, - get_all_models as get_openai_models, -) - -from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_rag_context, rag_template - -from open_webui.apps.socket.main import ( - app as socket_app, - periodic_usage_pool_cleanup, - get_event_call, - get_event_emitter, -) - -from open_webui.apps.webui.main import ( - app as webui_app, - generate_function_chat_completion, - get_pipe_models, -) -from open_webui.apps.webui.internal.db import Session - -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users - -from open_webui.apps.webui.utils import load_function_module_by_id - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app - -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo - - -from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_EXPORT, - ENABLE_MODEL_FILTER, - ENABLE_OAUTH_SIGNUP, - ENABLE_OLLAMA_API, - ENABLE_OPENAI_API, - ENV, - FRONTEND_BUILD_DIR, - MODEL_FILTER_LIST, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, - OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - STATIC_DIR, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - TITLE_GENERATION_PROMPT_TEMPLATE, - TAGS_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_NAME, - AppConfig, - run_migrations, - reset_config, -) -from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES -from open_webui.env import ( - CHANGELOG, - GLOBAL_LOG_LEVEL, - SAFE_MODE, - SRC_LOG_LEVELS, - VERSION, - WEBUI_BUILD_HASH, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, - RESET_CONFIG_ON_START, - OFFLINE_MODE, -) from fastapi import ( Depends, FastAPI, @@ -123,16 +30,93 @@ from sqlalchemy import text from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import RedirectResponse, Response, StreamingResponse - -from open_webui.utils.security_headers import SecurityHeadersMiddleware +from starlette.responses import Response, StreamingResponse +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app +from open_webui.apps.ollama.main import ( + app as ollama_app, + get_all_models as get_ollama_models, + generate_chat_completion as generate_ollama_chat_completion, + GenerateChatCompletionForm, +) +from open_webui.apps.openai.main import ( + app as openai_app, + generate_chat_completion as generate_openai_chat_completion, + get_all_models as get_openai_models, +) +from open_webui.apps.retrieval.main import app as retrieval_app +from open_webui.apps.retrieval.utils import get_rag_context, rag_template +from open_webui.apps.socket.main import ( + app as socket_app, + periodic_usage_pool_cleanup, + get_event_call, + get_event_emitter, +) +from open_webui.apps.webui.internal.db import Session +from open_webui.apps.webui.main import ( + app as webui_app, + generate_function_chat_completion, + get_pipe_models, +) +from open_webui.apps.webui.models.functions import Functions +from open_webui.apps.webui.models.models import Models +from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.apps.webui.utils import load_function_module_by_id +from open_webui.config import ( + CACHE_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + ENABLE_MODEL_FILTER, + ENABLE_OLLAMA_API, + ENABLE_OPENAI_API, + ENV, + FRONTEND_BUILD_DIR, + MODEL_FILTER_LIST, + OAUTH_PROVIDERS, + ENABLE_SEARCH_QUERY, + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + STATIC_DIR, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TITLE_GENERATION_PROMPT_TEMPLATE, + TAGS_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + WEBHOOK_URL, + WEBUI_AUTH, + WEBUI_NAME, + AppConfig, + reset_config, +) +from open_webui.constants import TASKS +from open_webui.env import ( + CHANGELOG, + GLOBAL_LOG_LEVEL, + SAFE_MODE, + SRC_LOG_LEVELS, + VERSION, + WEBUI_BUILD_HASH, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, + WEBUI_URL, + RESET_CONFIG_ON_START, + OFFLINE_MODE, +) from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, - parse_duration, prepend_to_first_user_message_content, ) +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( moa_response_generation_template, tags_generation_template, @@ -142,27 +126,17 @@ from open_webui.utils.task import ( ) from open_webui.utils.tools import get_tools from open_webui.utils.utils import ( - create_token, decode_token, get_admin_user, get_current_user, get_http_authorization_cred, - get_password_hash, get_verified_user, ) -from open_webui.utils.webhook import post_webhook - -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, -) if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -219,7 +193,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -693,6 +666,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) + ################################## # # Pipeline Middleware @@ -2314,20 +2288,6 @@ async def get_app_latest_release_version(): # OAuth Login & Callback ############################ -oauth = OAuth() - -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) - # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: app.add_middleware( @@ -2341,16 +2301,7 @@ if len(OAUTH_PROVIDERS) > 0: @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider - ) - client = oauth.create_client(provider) - if client is None: - raise HTTPException(404) - return await client.authorize_redirect(request, redirect_uri) + return await oauth_manager.handle_login(provider, request) # OAuth login logic is as follows: @@ -2361,118 +2312,7 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth.create_client(provider) - try: - token = await client.authorize_access_token(request) - except Exception as e: - log.warning(f"OAuth callback error: {e}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() - # We currently mandate that email addresses are provided - if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, check if merging is enabled - if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: - # Check if the user exists by email - user = Users.get_user_by_email(email) - if user: - # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) - - if not user: - # If the user does not exist, check if signups are enabled - if ENABLE_OAUTH_SIGNUP.value: - # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) - if existing_user: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - - picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") - if picture_url: - # Download the profile image into a base64 string - try: - async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: - picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) - guessed_mime_type = mimetypes.guess_type(picture_url)[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") - picture_url = "" - if not picture_url: - picture_url = "/user.png" - username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = ( - "admin" - if Users.get_num_users() == 0 - else webui_app.state.config.DEFAULT_USER_ROLE - ) - user = Auths.insert_new_auth( - email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get(username_claim, "User"), - profile_image_url=picture_url, - role=role, - oauth_sub=provider_sub, - ) - - if webui_app.state.config.WEBHOOK_URL: - post_webhook( - webui_app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - else: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN), - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=jwt_token, - httponly=True, # Ensures the cookie is not accessible via JavaScript - samesite=WEBUI_SESSION_COOKIE_SAME_SITE, - secure=WEBUI_SESSION_COOKIE_SECURE, - ) - - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return await oauth_manager.handle_callback(provider, request, response) @app.get("/manifest.json") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py new file mode 100644 index 000000000..d59e36733 --- /dev/null +++ b/backend/open_webui/utils/oauth.py @@ -0,0 +1,256 @@ +import base64 +import logging +import mimetypes +import uuid + +import aiohttp +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo +from fastapi import ( + HTTPException, + status, +) +from starlette.responses import RedirectResponse + +from open_webui.apps.webui.models.auths import Auths +from open_webui.apps.webui.models.users import Users +from open_webui.config import ( + DEFAULT_USER_ROLE, + ENABLE_OAUTH_SIGNUP, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PROVIDERS, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, + WEBHOOK_URL, + JWT_EXPIRES_IN, + AppConfig, +) +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE +from open_webui.utils.misc import parse_duration +from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.webhook import post_webhook + +log = logging.getLogger(__name__) + +auth_manager_config = AppConfig() +auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP +auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL +auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.WEBHOOK_URL = WEBHOOK_URL +auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN + + +class OAuthManager: + def __init__(self): + self.oauth = OAuth() + for provider_name, provider_config in OAUTH_PROVIDERS.items(): + self.oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + redirect_uri=provider_config["redirect_uri"], + ) + + def get_client(self, provider_name): + return self.oauth.create_client(provider_name) + + def get_user_role(self, user, user_data): + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" + + if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM + oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES + oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = auth_manager_config.DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + + async def handle_login(self, provider, request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + # If the provider has a custom redirect URL, use that, otherwise automatically generate one + redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( + "oauth_callback", provider=provider + ) + client = self.get_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) + + async def handle_callback(self, provider, request, response): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = self.get_client(provider) + try: + token = await client.authorize_access_token(request) + except Exception as e: + log.warning(f"OAuth callback error: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + log.warning(f"OAuth callback failed, sub is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM + email = user_data.get(email_claim, "").lower() + # We currently mandate that email addresses are provided + if not email: + log.warning(f"OAuth callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, check if merging is enabled + if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: + # Check if the user exists by email + user = Users.get_user_by_email(email) + if user: + # Update the user with the new oauth sub + Users.update_user_oauth_sub_by_id(user.id, provider_sub) + + if user: + determined_role = self.get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) + + if not user: + # If the user does not exist, check if signups are enabled + if auth_manager_config.ENABLE_OAUTH_SIGNUP.value: + # Check if an existing user with the same email already exists + existing_user = Users.get_user_by_email( + user_data.get("email", "").lower() + ) + if existing_user: + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) + + picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") + if picture_url: + # Download the profile image into a base64 string + try: + async with aiohttp.ClientSession() as session: + async with session.get(picture_url) as resp: + picture = await resp.read() + base64_encoded_picture = base64.b64encode( + picture + ).decode("utf-8") + guessed_mime_type = mimetypes.guess_type(picture_url)[0] + if guessed_mime_type is None: + # assume JPG, browsers are tolerant enough of image formats + guessed_mime_type = "image/jpeg" + picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" + except Exception as e: + log.error( + f"Error downloading profile image '{picture_url}': {e}" + ) + picture_url = "" + if not picture_url: + picture_url = "/user.png" + username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM + + role = self.get_user_role(None, user_data) + + user = Auths.insert_new_auth( + email=email, + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get(username_claim, "User"), + profile_image_url=picture_url, + role=role, + oauth_sub=provider_sub, + ) + + if auth_manager_config.WEBHOOK_URL: + post_webhook( + auth_manager_config.WEBHOOK_URL, + auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP( + user.name + ), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), + ) + + # Set the cookie token + response.set_cookie( + key="token", + value=jwt_token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + samesite=WEBUI_SESSION_COOKIE_SAME_SITE, + secure=WEBUI_SESSION_COOKIE_SECURE, + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) + + +oauth_manager = OAuthManager()