Merge pull request #6238 from Cyb4Black/dev

feat: oauth based role management
This commit is contained in:
Timothy Jaeryang Baek 2024-10-20 18:40:21 -07:00 committed by GitHub
commit c023afac8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 407 additions and 274 deletions

View File

@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel):
class GenerateEmbedForm(BaseModel):
model: str
input: list[str]|str
input: list[str] | str
truncate: Optional[bool] = None
options: Optional[dict] = None
keep_alive: Optional[Union[int, str]] = None

View File

@ -110,9 +110,8 @@ class ChromaClient:
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
@ -131,9 +130,8 @@ class ChromaClient:
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"}
)
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]

View File

@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI
NO_LIMIT = 999999999
class QdrantClient:
def __init__(self):
self.collection_prefix = "open-webui"
@ -38,15 +39,15 @@ class QdrantClient:
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
vectors_config=models.VectorParams(
size=dimension, distance=models.Distance.COSINE
),
)
print(f"collection {collection_name_with_prefix} successfully created!")
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(
collection_name=collection_name
):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
@ -56,22 +57,23 @@ class QdrantClient:
PointStruct(
id=item["id"],
vector=item["vector"],
payload={
"text": item["text"],
"metadata": item["metadata"]
},
payload={"text": item["text"], "metadata": item["metadata"]},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str):
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:
@ -87,7 +89,7 @@ class QdrantClient:
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]]
distances=[[point.score for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
@ -101,7 +103,10 @@ class QdrantClient:
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
@ -117,7 +122,7 @@ class QdrantClient:
# Get all the items in the collection.
points = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points.points)
@ -134,10 +139,10 @@ class QdrantClient:
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids.
field_conditions = []
@ -162,9 +167,7 @@ class QdrantClient:
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(
must=field_conditions
)
filter=models.Filter(must=field_conditions)
),
)

View File

@ -33,9 +33,13 @@ from open_webui.config import (
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS,
USER_PERMISSIONS,
WEBHOOK_URL,
@ -94,6 +98,11 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

View File

@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig(
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"OAUTH_PICTURE_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
@ -394,6 +394,33 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)
OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles",
[
role.strip()
for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
],
)
OAUTH_ADMIN_ROLES = PersistentConfig(
"OAUTH_ADMIN_ROLES",
"oauth.admin_roles",
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()

View File

@ -1,4 +1,4 @@
import base64
import asyncio
import inspect
import json
import logging
@ -7,104 +7,11 @@ import os
import shutil
import sys
import time
import uuid
import asyncio
from contextlib import asynccontextmanager
from typing import Optional
import aiohttp
import requests
from open_webui.apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from fastapi import (
Depends,
FastAPI,
@ -123,16 +30,93 @@ from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TAGS_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
reset_config,
)
from open_webui.constants import TASKS
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
parse_duration,
prepend_to_first_user_message_content,
)
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
tags_generation_template,
@ -142,27 +126,17 @@ from open_webui.utils.task import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.utils import (
create_token,
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_password_hash,
get_verified_user,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@ -219,7 +193,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@ -693,6 +666,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
@ -2314,20 +2288,6 @@ async def get_app_latest_release_version():
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
@ -2341,16 +2301,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
return await oauth_manager.handle_login(provider, request)
# OAuth login logic is as follows:
@ -2361,118 +2312,7 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
else webui_app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
return await oauth_manager.handle_callback(provider, request, response)
@app.get("/manifest.json")

View File

@ -0,0 +1,256 @@
import base64
import logging
import mimetypes
import uuid
import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
status,
)
from starlette.responses import RedirectResponse
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
WEBHOOK_URL,
JWT_EXPIRES_IN,
AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
role = "admin"
break
else:
if not user:
# If role management is disabled, use the default role for new users
role = auth_manager_config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(
user_data.get("email", "").lower()
)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(
picture
).decode("utf-8")
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(
f"Error downloading profile image '{picture_url}': {e}"
)
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if auth_manager_config.WEBHOOK_URL:
post_webhook(
auth_manager_config.WEBHOOK_URL,
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
user.name
),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
oauth_manager = OAuthManager()