Finish reorganizing oauth code

This commit is contained in:
Willnow, Patrick 2024-10-16 16:32:57 +02:00
parent 08ff494754
commit 8eebd6bce1
2 changed files with 279 additions and 284 deletions

View File

@ -1,4 +1,5 @@
import base64 import inspect
import asyncio
import inspect import inspect
import json import json
import logging import logging
@ -7,89 +8,11 @@ import os
import shutil import shutil
import sys import sys
import time import time
import uuid
import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional from typing import Optional
import aiohttp import aiohttp
import requests import requests
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
)
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.webui.main import (
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
)
from fastapi import ( from fastapi import (
Depends, Depends,
FastAPI, FastAPI,
@ -108,16 +31,88 @@ from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
GenerateChatCompletionForm,
generate_chat_completion as generate_ollama_chat_completion,
)
from open_webui.apps.ollama.main import app as ollama_app
from open_webui.apps.ollama.main import get_all_models as get_ollama_models
from open_webui.apps.openai.main import app as openai_app
from open_webui.apps.openai.main import (
generate_chat_completion as generate_openai_chat_completion,
)
from open_webui.apps.openai.main import get_all_models as get_openai_models
from open_webui.apps.rag.main import app as rag_app
from open_webui.apps.rag.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup
from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import app as webui_app
from open_webui.apps.webui.main import (
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
)
from open_webui.utils.misc import ( from open_webui.utils.misc import (
add_or_update_system_message, add_or_update_system_message,
get_last_user_message, get_last_user_message,
parse_duration,
prepend_to_first_user_message_content, prepend_to_first_user_message_content,
) )
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import ( from open_webui.utils.task import (
moa_response_generation_template, moa_response_generation_template,
search_query_generation_template, search_query_generation_template,
@ -126,23 +121,12 @@ from open_webui.utils.task import (
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.utils import ( from open_webui.utils.utils import (
create_token,
decode_token, decode_token,
get_admin_user, get_admin_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
get_password_hash,
get_verified_user, get_verified_user,
) )
from open_webui.utils.webhook import post_webhook
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.oauth import oauth_manager
if SAFE_MODE: if SAFE_MODE:
print("SAFE MODE ENABLED") print("SAFE MODE ENABLED")
@ -220,6 +204,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {} app.state.MODELS = {}
################################## ##################################
# #
# ChatCompletion Middleware # ChatCompletion Middleware
@ -2181,7 +2167,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login") @app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request): async def oauth_login(provider: str, request: Request):
return oauth_manager.handle_login(provider, request) return await oauth_manager.handle_login(provider, request)
# OAuth login logic is as follows: # OAuth login logic is as follows:
@ -2192,7 +2178,7 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken # - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
@app.get("/oauth/{provider}/callback") @app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response): async def oauth_callback(provider: str, request: Request, response: Response):
return oauth_manager.handle_callback(provider, request, response) return await oauth_manager.handle_callback(provider, request, response)
@app.get("/manifest.json") @app.get("/manifest.json")

View File

@ -1,19 +1,19 @@
import base64 import base64
import logging
import mimetypes import mimetypes
import uuid import uuid
import aiohttp import aiohttp
import logging from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import ( from fastapi import (
HTTPException, HTTPException,
Request,
status, status,
) )
from starlette.responses import RedirectResponse, Response, StreamingResponse from starlette.responses import RedirectResponse
from authlib.oidc.core import UserInfo
from open_webui.apps.webui.models.auths import Auths from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users, UserModel from open_webui.apps.webui.models.users import Users
from open_webui.config import ( from open_webui.config import (
DEFAULT_USER_ROLE, DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP, ENABLE_OAUTH_SIGNUP,
@ -25,23 +25,35 @@ from open_webui.config import (
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES, OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
) )
from open_webui.constants import ERROR_MESSAGES
from authlib.integrations.starlette_client import OAuth
from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from open_webui.utils.misc import parse_duration from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
oauth_manager = {} auth_manager_config = AppConfig()
oauth_manager.oauth = OAuth() auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth_manager.oauth.register( class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name, name=provider_name,
client_id=provider_config["client_id"], client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"], client_secret=provider_config["client_secret"],
@ -52,9 +64,10 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items():
redirect_uri=provider_config["redirect_uri"], redirect_uri=provider_config["redirect_uri"],
) )
oauth_manager.get_client = oauth_manager.oauth.create_client def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(user: UserModel, user_data: UserInfo) -> str: def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1: if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin" return "admin"
@ -62,10 +75,10 @@ def get_user_role(user: UserModel, user_data: UserInfo) -> str:
# If there are no users, assign the role "admin", as the first user will be an admin # If there are no users, assign the role "admin", as the first user will be an admin
return "admin" return "admin"
if ENABLE_OAUTH_ROLE_MANAGEMENT: if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = OAUTH_ROLES_CLAIM oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = OAUTH_ALLOWED_ROLES oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = OAUTH_ADMIN_ROLES oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found role = "pending" # Default/fallback role if no matching roles are found
@ -93,33 +106,29 @@ def get_user_role(user: UserModel, user_data: UserInfo) -> str:
else: else:
if not user: if not user:
# If role management is disabled, use the default role for new users # If role management is disabled, use the default role for new users
role = DEFAULT_USER_ROLE role = auth_manager_config.DEFAULT_USER_ROLE
else: else:
# If role management is disabled, use the existing role for existing users # If role management is disabled, use the existing role for existing users
role = user.role role = user.role
return role return role
oauth_manager.get_user_role = get_user_role async def handle_login(self, provider, request):
async def handle_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one # If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider "oauth_callback", provider=provider
) )
client = oauth_manager.get_client(provider) client = self.get_client(provider)
if client is None: if client is None:
raise HTTPException(404) raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri) return await client.authorize_redirect(request, redirect_uri)
oauth_manager.handle_login = handle_login async def handle_callback(self, provider, request, response):
async def handle_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
client = oauth_manager.get_client(provider) client = self.get_client(provider)
try: try:
token = await client.authorize_access_token(request) token = await client.authorize_access_token(request)
except Exception as e: except Exception as e:
@ -132,7 +141,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
log.warning(f"OAuth callback failed, sub is missing: {user_data}") log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}" provider_sub = f"{provider}@{sub}"
email_claim = OAUTH_EMAIL_CLAIM email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower() email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided # We currently mandate that email addresses are provided
if not email: if not email:
@ -144,7 +153,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
if not user: if not user:
# If the user does not exist, check if merging is enabled # If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email # Check if the user exists by email
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
if user: if user:
@ -152,19 +161,19 @@ async def handle_callback(provider: str, request: Request, response: Response):
Users.update_user_oauth_sub_by_id(user.id, provider_sub) Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user: if user:
determined_role = get_user_role(user, user_data) determined_role = self.get_user_role(user, user_data)
if user.role != determined_role: if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role) Users.update_user_role_by_id(user.id, determined_role)
if not user: if not user:
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value: if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = OAUTH_PICTURE_CLAIM picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "") picture_url = user_data.get(picture_claim, "")
if picture_url: if picture_url:
# Download the profile image into a base64 string # Download the profile image into a base64 string
@ -185,9 +194,9 @@ async def handle_callback(provider: str, request: Request, response: Response):
picture_url = "" picture_url = ""
if not picture_url: if not picture_url:
picture_url = "/user.png" picture_url = "/user.png"
username_claim = OAUTH_USERNAME_CLAIM username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = get_user_role(None, user_data) role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
email=email, email=email,
@ -200,13 +209,13 @@ async def handle_callback(provider: str, request: Request, response: Response):
oauth_sub=provider_sub, oauth_sub=provider_sub,
) )
if WEBHOOK_URL: if auth_manager_config.WEBHOOK_URL:
post_webhook( post_webhook(
WEBHOOK_URL, auth_manager_config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True), "user": user.model_dump_json(exclude_none=True),
}, },
) )
@ -217,7 +226,7 @@ async def handle_callback(provider: str, request: Request, response: Response):
jwt_token = create_token( jwt_token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(JWT_EXPIRES_IN), expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
) )
# Set the cookie token # Set the cookie token
@ -231,4 +240,4 @@ async def handle_callback(provider: str, request: Request, response: Response):
redirect_url = f"{request.base_url}auth#token={jwt_token}" redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url) return RedirectResponse(url=redirect_url)
oauth_manager.handle_callback = handle_callback oauth_manager = OAuthManager()