Finish reorganizing oauth code

This commit is contained in:
Willnow, Patrick 2024-10-16 16:32:57 +02:00
parent 08ff494754
commit 8eebd6bce1
2 changed files with 279 additions and 284 deletions

View File

@ -1,4 +1,5 @@
import base64
import inspect
import asyncio
import inspect
import json
import logging
@ -7,89 +8,11 @@ import os
import shutil
import sys
import time
import uuid
import asyncio
from contextlib import asynccontextmanager
from typing import Optional
import aiohttp
import requests
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.webui.main import (
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
)
from fastapi import (
Depends,
FastAPI,
@ -108,16 +31,88 @@ from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from starlette.responses import Response, StreamingResponse
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.webui.main import (
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
)
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
parse_duration,
prepend_to_first_user_message_content,
)
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import (
moa_response_generation_template,
search_query_generation_template,
@ -126,23 +121,12 @@ from open_webui.utils.task import (
)
from open_webui.utils.tools import get_tools
from open_webui.utils.utils import (
create_token,
decode_token,
get_admin_user,
get_current_user,
get_http_authorization_cred,
get_password_hash,
get_verified_user,
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.oauth import oauth_manager
if SAFE_MODE:
print("SAFE MODE ENABLED")
@ -220,6 +204,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {}
##################################
#
# ChatCompletion Middleware
@ -2181,7 +2167,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
return oauth_manager.handle_login(provider, request)
return await oauth_manager.handle_login(provider, request)
# OAuth login logic is as follows:
@ -2192,7 +2178,7 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
return oauth_manager.handle_callback(provider, request, response)
return await oauth_manager.handle_callback(provider, request, response)
@app.get("/manifest.json")

View File

@ -1,19 +1,19 @@
import base64
import logging
import mimetypes
import uuid
import aiohttp
import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
Request,
status,
)
from starlette.responses import RedirectResponse, Response, StreamingResponse
from authlib.oidc.core import UserInfo
from starlette.responses import RedirectResponse
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users, UserModel
from open_webui.apps.webui.models.users import Users
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
@ -25,23 +25,35 @@ from open_webui.config import (
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN,
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
)
from authlib.integrations.starlette_client import OAuth
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
oauth_manager = {}
oauth_manager.oauth = OAuth()
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth_manager.oauth.register(
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
@ -52,9 +64,10 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items():
redirect_uri=provider_config["redirect_uri"],
)
oauth_manager.get_client = oauth_manager.oauth.create_client
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(user: UserModel, user_data: UserInfo) -> str:
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
@ -62,10 +75,10 @@ def get_user_role(user: UserModel, user_data: UserInfo) -> str:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = OAUTH_ROLES_CLAIM
oauth_allowed_roles = OAUTH_ALLOWED_ROLES
oauth_admin_roles = OAUTH_ADMIN_ROLES
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
@ -93,33 +106,29 @@ def get_user_role(user: UserModel, user_data: UserInfo) -> str:
else:
if not user:
# If role management is disabled, use the default role for new users
role = DEFAULT_USER_ROLE
role = auth_manager_config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
oauth_manager.get_user_role = get_user_role
async def handle_login(provider: str, request: Request):
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = oauth_manager.get_client(provider)
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
oauth_manager.handle_login = handle_login
async def handle_callback(provider: str, request: Request, response: Response):
async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth_manager.get_client(provider)
client = self.get_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
@ -132,7 +141,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = OAUTH_EMAIL_CLAIM
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
@ -144,7 +153,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
@ -152,19 +161,19 @@ async def handle_callback(provider: str, request: Request, response: Response):
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = get_user_role(user, user_data)
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = OAUTH_PICTURE_CLAIM
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
@ -185,9 +194,9 @@ async def handle_callback(provider: str, request: Request, response: Response):
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = OAUTH_USERNAME_CLAIM
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = get_user_role(None, user_data)
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,
@ -200,13 +209,13 @@ async def handle_callback(provider: str, request: Request, response: Response):
oauth_sub=provider_sub,
)
if WEBHOOK_URL:
if auth_manager_config.WEBHOOK_URL:
post_webhook(
WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
auth_manager_config.WEBHOOK_URL,
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
@ -217,7 +226,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(JWT_EXPIRES_IN),
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
# Set the cookie token
@ -231,4 +240,4 @@ async def handle_callback(provider: str, request: Request, response: Response):
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
oauth_manager.handle_callback = handle_callback
oauth_manager = OAuthManager()