mirror of
https://github.com/open-webui/open-webui
synced 2025-03-23 22:31:38 +00:00
Finish reorganizing oauth code
This commit is contained in:
parent
08ff494754
commit
8eebd6bce1
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import inspect
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@ -7,89 +8,11 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from open_webui.apps.audio.main import app as audio_app
|
||||
from open_webui.apps.images.main import app as images_app
|
||||
from open_webui.apps.ollama.main import app as ollama_app
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateChatCompletionForm,
|
||||
generate_chat_completion as generate_ollama_chat_completion,
|
||||
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
|
||||
)
|
||||
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
|
||||
from open_webui.apps.openai.main import app as openai_app
|
||||
from open_webui.apps.openai.main import (
|
||||
generate_chat_completion as generate_openai_chat_completion,
|
||||
)
|
||||
from open_webui.apps.openai.main import get_all_models as get_openai_models
|
||||
from open_webui.apps.rag.main import app as rag_app
|
||||
from open_webui.apps.rag.utils import get_rag_context, rag_template
|
||||
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
|
||||
from open_webui.apps.socket.main import get_event_call, get_event_emitter
|
||||
from open_webui.apps.webui.internal.db import Session
|
||||
from open_webui.apps.webui.main import app as webui_app
|
||||
from open_webui.apps.webui.main import (
|
||||
generate_function_chat_completion,
|
||||
get_pipe_models,
|
||||
)
|
||||
from open_webui.apps.webui.models.auths import Auths
|
||||
from open_webui.apps.webui.models.functions import Functions
|
||||
from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.apps.webui.models.users import UserModel, Users
|
||||
from open_webui.apps.webui.utils import load_function_module_by_id
|
||||
|
||||
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
CORS_ALLOW_ORIGIN,
|
||||
DEFAULT_LOCALE,
|
||||
ENABLE_ADMIN_CHAT_ACCESS,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
ENABLE_MODEL_FILTER,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
ENABLE_OLLAMA_API,
|
||||
ENABLE_OPENAI_API,
|
||||
ENV,
|
||||
FRONTEND_BUILD_DIR,
|
||||
MODEL_FILTER_LIST,
|
||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_SEARCH_QUERY,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
STATIC_DIR,
|
||||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH,
|
||||
WEBUI_NAME,
|
||||
AppConfig,
|
||||
run_migrations,
|
||||
reset_config,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
|
||||
from open_webui.env import (
|
||||
CHANGELOG,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
WEBUI_BUILD_HASH,
|
||||
WEBUI_SECRET_KEY,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_URL,
|
||||
RESET_CONFIG_ON_START,
|
||||
)
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
@ -108,16 +31,88 @@ from sqlalchemy import text
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.responses import RedirectResponse, Response, StreamingResponse
|
||||
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from open_webui.apps.audio.main import app as audio_app
|
||||
from open_webui.apps.images.main import app as images_app
|
||||
from open_webui.apps.ollama.main import (
|
||||
GenerateChatCompletionForm,
|
||||
generate_chat_completion as generate_ollama_chat_completion,
|
||||
)
|
||||
from open_webui.apps.ollama.main import app as ollama_app
|
||||
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
|
||||
from open_webui.apps.openai.main import app as openai_app
|
||||
from open_webui.apps.openai.main import (
|
||||
generate_chat_completion as generate_openai_chat_completion,
|
||||
)
|
||||
from open_webui.apps.openai.main import get_all_models as get_openai_models
|
||||
from open_webui.apps.rag.main import app as rag_app
|
||||
from open_webui.apps.rag.utils import get_rag_context, rag_template
|
||||
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
|
||||
from open_webui.apps.socket.main import get_event_call, get_event_emitter
|
||||
from open_webui.apps.webui.internal.db import Session
|
||||
from open_webui.apps.webui.main import app as webui_app
|
||||
from open_webui.apps.webui.main import (
|
||||
generate_function_chat_completion,
|
||||
get_pipe_models,
|
||||
)
|
||||
from open_webui.apps.webui.models.functions import Functions
|
||||
from open_webui.apps.webui.models.models import Models
|
||||
from open_webui.apps.webui.models.users import UserModel, Users
|
||||
from open_webui.apps.webui.utils import load_function_module_by_id
|
||||
from open_webui.config import (
|
||||
CACHE_DIR,
|
||||
CORS_ALLOW_ORIGIN,
|
||||
DEFAULT_LOCALE,
|
||||
ENABLE_ADMIN_CHAT_ACCESS,
|
||||
ENABLE_ADMIN_EXPORT,
|
||||
ENABLE_MODEL_FILTER,
|
||||
ENABLE_OLLAMA_API,
|
||||
ENABLE_OPENAI_API,
|
||||
ENV,
|
||||
FRONTEND_BUILD_DIR,
|
||||
MODEL_FILTER_LIST,
|
||||
OAUTH_PROVIDERS,
|
||||
ENABLE_SEARCH_QUERY,
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||
STATIC_DIR,
|
||||
TASK_MODEL,
|
||||
TASK_MODEL_EXTERNAL,
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH,
|
||||
WEBUI_NAME,
|
||||
AppConfig,
|
||||
run_migrations,
|
||||
reset_config,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES, TASKS
|
||||
from open_webui.env import (
|
||||
CHANGELOG,
|
||||
GLOBAL_LOG_LEVEL,
|
||||
SAFE_MODE,
|
||||
SRC_LOG_LEVELS,
|
||||
VERSION,
|
||||
WEBUI_BUILD_HASH,
|
||||
WEBUI_SECRET_KEY,
|
||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||
WEBUI_SESSION_COOKIE_SECURE,
|
||||
WEBUI_URL,
|
||||
RESET_CONFIG_ON_START,
|
||||
)
|
||||
from open_webui.utils.misc import (
|
||||
add_or_update_system_message,
|
||||
get_last_user_message,
|
||||
parse_duration,
|
||||
prepend_to_first_user_message_content,
|
||||
)
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||
from open_webui.utils.response import (
|
||||
convert_response_ollama_to_openai,
|
||||
convert_streaming_response_ollama_to_openai,
|
||||
)
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.task import (
|
||||
moa_response_generation_template,
|
||||
search_query_generation_template,
|
||||
@ -126,23 +121,12 @@ from open_webui.utils.task import (
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.utils import (
|
||||
create_token,
|
||||
decode_token,
|
||||
get_admin_user,
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
get_password_hash,
|
||||
get_verified_user,
|
||||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||
from open_webui.utils.response import (
|
||||
convert_response_ollama_to_openai,
|
||||
convert_streaming_response_ollama_to_openai,
|
||||
)
|
||||
|
||||
from open_webui.utils.oauth import oauth_manager
|
||||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
@ -220,6 +204,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
app.state.MODELS = {}
|
||||
|
||||
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# ChatCompletion Middleware
|
||||
@ -2181,7 +2167,7 @@ if len(OAUTH_PROVIDERS) > 0:
|
||||
|
||||
@app.get("/oauth/{provider}/login")
|
||||
async def oauth_login(provider: str, request: Request):
|
||||
return oauth_manager.handle_login(provider, request)
|
||||
return await oauth_manager.handle_login(provider, request)
|
||||
|
||||
|
||||
# OAuth login logic is as follows:
|
||||
@ -2192,7 +2178,7 @@ async def oauth_login(provider: str, request: Request):
|
||||
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
|
||||
@app.get("/oauth/{provider}/callback")
|
||||
async def oauth_callback(provider: str, request: Request, response: Response):
|
||||
return oauth_manager.handle_callback(provider, request, response)
|
||||
return await oauth_manager.handle_callback(provider, request, response)
|
||||
|
||||
|
||||
@app.get("/manifest.json")
|
||||
|
@ -1,19 +1,19 @@
|
||||
import base64
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import logging
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
from authlib.oidc.core import UserInfo
|
||||
from fastapi import (
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from starlette.responses import RedirectResponse, Response, StreamingResponse
|
||||
from authlib.oidc.core import UserInfo
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from open_webui.apps.webui.models.auths import Auths
|
||||
from open_webui.apps.webui.models.users import Users, UserModel
|
||||
from open_webui.apps.webui.models.users import Users
|
||||
from open_webui.config import (
|
||||
DEFAULT_USER_ROLE,
|
||||
ENABLE_OAUTH_SIGNUP,
|
||||
@ -25,210 +25,219 @@ from open_webui.config import (
|
||||
OAUTH_PICTURE_CLAIM,
|
||||
OAUTH_USERNAME_CLAIM,
|
||||
OAUTH_ALLOWED_ROLES,
|
||||
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN,
|
||||
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
|
||||
)
|
||||
|
||||
from authlib.integrations.starlette_client import OAuth
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.utils.misc import parse_duration
|
||||
from open_webui.utils.utils import get_password_hash, create_token
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
oauth_manager = {}
|
||||
oauth_manager.oauth = OAuth()
|
||||
auth_manager_config = AppConfig()
|
||||
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
|
||||
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
|
||||
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
|
||||
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
||||
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
||||
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
|
||||
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
|
||||
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
|
||||
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
|
||||
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
|
||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||
oauth_manager.oauth.register(
|
||||
name=provider_name,
|
||||
client_id=provider_config["client_id"],
|
||||
client_secret=provider_config["client_secret"],
|
||||
server_metadata_url=provider_config["server_metadata_url"],
|
||||
client_kwargs={
|
||||
"scope": provider_config["scope"],
|
||||
},
|
||||
redirect_uri=provider_config["redirect_uri"],
|
||||
)
|
||||
|
||||
oauth_manager.get_client = oauth_manager.oauth.create_client
|
||||
class OAuthManager:
|
||||
def __init__(self):
|
||||
self.oauth = OAuth()
|
||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||
self.oauth.register(
|
||||
name=provider_name,
|
||||
client_id=provider_config["client_id"],
|
||||
client_secret=provider_config["client_secret"],
|
||||
server_metadata_url=provider_config["server_metadata_url"],
|
||||
client_kwargs={
|
||||
"scope": provider_config["scope"],
|
||||
},
|
||||
redirect_uri=provider_config["redirect_uri"],
|
||||
)
|
||||
|
||||
def get_user_role(user: UserModel, user_data: UserInfo) -> str:
|
||||
if user and Users.get_num_users() == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
return "admin"
|
||||
def get_client(self, provider_name):
|
||||
return self.oauth.create_client(provider_name)
|
||||
|
||||
if ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||
oauth_claim = OAUTH_ROLES_CLAIM
|
||||
oauth_allowed_roles = OAUTH_ALLOWED_ROLES
|
||||
oauth_admin_roles = OAUTH_ADMIN_ROLES
|
||||
oauth_roles = None
|
||||
role = "pending" # Default/fallback role if no matching roles are found
|
||||
def get_user_role(self, user, user_data):
|
||||
if user and Users.get_num_users() == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
return "admin"
|
||||
if not user and Users.get_num_users() == 0:
|
||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||
return "admin"
|
||||
|
||||
# Next block extracts the roles from the user data, accepting nested claims of any depth
|
||||
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
|
||||
claim_data = user_data
|
||||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
|
||||
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
|
||||
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
|
||||
oauth_roles = None
|
||||
role = "pending" # Default/fallback role if no matching roles are found
|
||||
|
||||
# Next block extracts the roles from the user data, accepting nested claims of any depth
|
||||
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
|
||||
claim_data = user_data
|
||||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||
|
||||
# If any roles are found, check if they match the allowed or admin roles
|
||||
if oauth_roles:
|
||||
# If role management is enabled, and matching roles are provided, use the roles
|
||||
for allowed_role in oauth_allowed_roles:
|
||||
# If the user has any of the allowed roles, assign the role "user"
|
||||
if allowed_role in oauth_roles:
|
||||
role = "user"
|
||||
break
|
||||
for admin_role in oauth_admin_roles:
|
||||
# If the user has any of the admin roles, assign the role "admin"
|
||||
if admin_role in oauth_roles:
|
||||
role = "admin"
|
||||
break
|
||||
else:
|
||||
if not user:
|
||||
# If role management is disabled, use the default role for new users
|
||||
role = auth_manager_config.DEFAULT_USER_ROLE
|
||||
else:
|
||||
# If role management is disabled, use the existing role for existing users
|
||||
role = user.role
|
||||
|
||||
return role
|
||||
|
||||
async def handle_login(self, provider, request):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
||||
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
||||
"oauth_callback", provider=provider
|
||||
)
|
||||
client = self.get_client(provider)
|
||||
if client is None:
|
||||
raise HTTPException(404)
|
||||
return await client.authorize_redirect(request, redirect_uri)
|
||||
|
||||
async def handle_callback(self, provider, request, response):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
client = self.get_client(provider)
|
||||
try:
|
||||
token = await client.authorize_access_token(request)
|
||||
except Exception as e:
|
||||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token["userinfo"]
|
||||
|
||||
sub = user_data.get("sub")
|
||||
if not sub:
|
||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||
email = user_data.get(email_claim, "").lower()
|
||||
# We currently mandate that email addresses are provided
|
||||
if not email:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||
|
||||
# If any roles are found, check if they match the allowed or admin roles
|
||||
if oauth_roles:
|
||||
# If role management is enabled, and matching roles are provided, use the roles
|
||||
for allowed_role in oauth_allowed_roles:
|
||||
# If the user has any of the allowed roles, assign the role "user"
|
||||
if allowed_role in oauth_roles:
|
||||
role = "user"
|
||||
break
|
||||
for admin_role in oauth_admin_roles:
|
||||
# If the user has any of the admin roles, assign the role "admin"
|
||||
if admin_role in oauth_roles:
|
||||
role = "admin"
|
||||
break
|
||||
else:
|
||||
if not user:
|
||||
# If role management is disabled, use the default role for new users
|
||||
role = DEFAULT_USER_ROLE
|
||||
else:
|
||||
# If role management is disabled, use the existing role for existing users
|
||||
role = user.role
|
||||
# If the user does not exist, check if merging is enabled
|
||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
|
||||
# Check if the user exists by email
|
||||
user = Users.get_user_by_email(email)
|
||||
if user:
|
||||
# Update the user with the new oauth sub
|
||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||
|
||||
return role
|
||||
if user:
|
||||
determined_role = self.get_user_role(user, user_data)
|
||||
if user.role != determined_role:
|
||||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
|
||||
oauth_manager.get_user_role = get_user_role
|
||||
if not user:
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
async def handle_login(provider: str, request: Request):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
||||
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
||||
"oauth_callback", provider=provider
|
||||
)
|
||||
client = oauth_manager.get_client(provider)
|
||||
if client is None:
|
||||
raise HTTPException(404)
|
||||
return await client.authorize_redirect(request, redirect_uri)
|
||||
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
||||
picture_url = user_data.get(picture_claim, "")
|
||||
if picture_url:
|
||||
# Download the profile image into a base64 string
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||
"utf-8"
|
||||
)
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
except Exception as e:
|
||||
log.error(f"Error downloading profile image '{picture_url}': {e}")
|
||||
picture_url = ""
|
||||
if not picture_url:
|
||||
picture_url = "/user.png"
|
||||
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
||||
|
||||
oauth_manager.handle_login = handle_login
|
||||
role = self.get_user_role(None, user_data)
|
||||
|
||||
async def handle_callback(provider: str, request: Request, response: Response):
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise HTTPException(404)
|
||||
client = oauth_manager.get_client(provider)
|
||||
try:
|
||||
token = await client.authorize_access_token(request)
|
||||
except Exception as e:
|
||||
log.warning(f"OAuth callback error: {e}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
user_data: UserInfo = token["userinfo"]
|
||||
|
||||
sub = user_data.get("sub")
|
||||
if not sub:
|
||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
provider_sub = f"{provider}@{sub}"
|
||||
email_claim = OAUTH_EMAIL_CLAIM
|
||||
email = user_data.get(email_claim, "").lower()
|
||||
# We currently mandate that email addresses are provided
|
||||
if not email:
|
||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||
|
||||
if not user:
|
||||
# If the user does not exist, check if merging is enabled
|
||||
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
|
||||
# Check if the user exists by email
|
||||
user = Users.get_user_by_email(email)
|
||||
if user:
|
||||
# Update the user with the new oauth sub
|
||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||
|
||||
if user:
|
||||
determined_role = get_user_role(user, user_data)
|
||||
if user.role != determined_role:
|
||||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
|
||||
if not user:
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if ENABLE_OAUTH_SIGNUP.value:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
picture_claim = OAUTH_PICTURE_CLAIM
|
||||
picture_url = user_data.get(picture_claim, "")
|
||||
if picture_url:
|
||||
# Download the profile image into a base64 string
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||
"utf-8"
|
||||
)
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
except Exception as e:
|
||||
log.error(f"Error downloading profile image '{picture_url}': {e}")
|
||||
picture_url = ""
|
||||
if not picture_url:
|
||||
picture_url = "/user.png"
|
||||
username_claim = OAUTH_USERNAME_CLAIM
|
||||
|
||||
role = get_user_role(None, user_data)
|
||||
|
||||
user = Auths.insert_new_auth(
|
||||
email=email,
|
||||
password=get_password_hash(
|
||||
str(uuid.uuid4())
|
||||
), # Random password, not used
|
||||
name=user_data.get(username_claim, "User"),
|
||||
profile_image_url=picture_url,
|
||||
role=role,
|
||||
oauth_sub=provider_sub,
|
||||
)
|
||||
|
||||
if WEBHOOK_URL:
|
||||
post_webhook(
|
||||
WEBHOOK_URL,
|
||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
"action": "signup",
|
||||
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
"user": user.model_dump_json(exclude_none=True),
|
||||
},
|
||||
user = Auths.insert_new_auth(
|
||||
email=email,
|
||||
password=get_password_hash(
|
||||
str(uuid.uuid4())
|
||||
), # Random password, not used
|
||||
name=user_data.get(username_claim, "User"),
|
||||
profile_image_url=picture_url,
|
||||
role=role,
|
||||
oauth_sub=provider_sub,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
jwt_token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(JWT_EXPIRES_IN),
|
||||
)
|
||||
if auth_manager_config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
auth_manager_config.WEBHOOK_URL,
|
||||
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
"action": "signup",
|
||||
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
"user": user.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=jwt_token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
jwt_token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
||||
)
|
||||
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url)
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=jwt_token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
oauth_manager.handle_callback = handle_callback
|
||||
# Redirect back to the frontend with the JWT token
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
oauth_manager = OAuthManager()
|
Loading…
Reference in New Issue
Block a user