Merge branch 'main' into dev

# Conflicts:
#	backend/open_webui/main.py
This commit is contained in:
Willnow, Patrick 2024-10-16 20:16:05 +02:00
commit b888ee17ff
5 changed files with 396 additions and 278 deletions

View File

@ -31,6 +31,7 @@ RUN npm ci
COPY . . COPY . .
ENV APP_BUILD_HASH=${BUILD_HASH} ENV APP_BUILD_HASH=${BUILD_HASH}
ENV NODE_OPTIONS="--max_old_space_size=8192"
RUN npm run build RUN npm run build
######## WebUI backend ######## ######## WebUI backend ########

View File

@ -32,9 +32,13 @@ from open_webui.config import (
ENABLE_MESSAGE_RATING, ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP, ENABLE_SIGNUP,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM, OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES,
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
@ -93,6 +97,11 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.MODELS = {} app.state.MODELS = {}
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {} app.state.FUNCTIONS = {}

View File

@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig(
) )
OAUTH_PICTURE_CLAIM = PersistentConfig( OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM", "OAUTH_PICTURE_CLAIM",
"oauth.oidc.avatar_claim", "oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
) )
@ -394,6 +394,29 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"), os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
) )
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
"OAUTH_ROLES_CLAIM",
"oauth.roles_claim",
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)
OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles",
[role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")],
)
OAUTH_ADMIN_ROLES = PersistentConfig(
"OAUTH_ADMIN_ROLES",
"oauth.admin_roles",
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
def load_oauth_providers(): def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()

View File

@ -1,4 +1,4 @@
import base64 import asyncio
import inspect import inspect
import json import json
import logging import logging
@ -7,103 +7,11 @@ import os
import shutil import shutil
import sys import sys
import time import time
import uuid
import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Optional from typing import Optional
import aiohttp import aiohttp
import requests import requests
from open_webui.apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OAUTH_SIGNUP,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
run_migrations,
reset_config,
)
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from fastapi import ( from fastapi import (
Depends, Depends,
FastAPI, FastAPI,
@ -122,16 +30,92 @@ from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse, Response, StreamingResponse from starlette.responses import Response, StreamingResponse
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
generate_chat_completion as generate_ollama_chat_completion,
GenerateChatCompletionForm,
)
from open_webui.apps.openai.main import (
app as openai_app,
generate_chat_completion as generate_openai_chat_completion,
get_all_models as get_openai_models,
)
from open_webui.apps.retrieval.main import app as retrieval_app
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
from open_webui.apps.socket.main import (
app as socket_app,
periodic_usage_pool_cleanup,
get_event_call,
get_event_emitter,
)
from open_webui.apps.webui.internal.db import Session
from open_webui.apps.webui.main import (
app as webui_app,
generate_function_chat_completion,
get_pipe_models,
)
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
DEFAULT_LOCALE,
ENABLE_ADMIN_CHAT_ACCESS,
ENABLE_ADMIN_EXPORT,
ENABLE_MODEL_FILTER,
ENABLE_OLLAMA_API,
ENABLE_OPENAI_API,
ENV,
FRONTEND_BUILD_DIR,
MODEL_FILTER_LIST,
OAUTH_PROVIDERS,
ENABLE_SEARCH_QUERY,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
STATIC_DIR,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
WEBHOOK_URL,
WEBUI_AUTH,
WEBUI_NAME,
AppConfig,
reset_config,
)
from open_webui.constants import TASKS
from open_webui.env import (
CHANGELOG,
GLOBAL_LOG_LEVEL,
SAFE_MODE,
SRC_LOG_LEVELS,
VERSION,
WEBUI_BUILD_HASH,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
WEBUI_URL,
RESET_CONFIG_ON_START,
OFFLINE_MODE,
)
from open_webui.utils.misc import ( from open_webui.utils.misc import (
add_or_update_system_message, add_or_update_system_message,
get_last_user_message, get_last_user_message,
parse_duration,
prepend_to_first_user_message_content, prepend_to_first_user_message_content,
) )
from open_webui.utils.oauth import oauth_manager
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.task import ( from open_webui.utils.task import (
moa_response_generation_template, moa_response_generation_template,
search_query_generation_template, search_query_generation_template,
@ -140,27 +124,17 @@ from open_webui.utils.task import (
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.utils import ( from open_webui.utils.utils import (
create_token,
decode_token, decode_token,
get_admin_user, get_admin_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
get_password_hash,
get_verified_user, get_verified_user,
) )
from open_webui.utils.webhook import post_webhook
from open_webui.utils.payload import convert_payload_openai_to_ollama
from open_webui.utils.response import (
convert_response_ollama_to_openai,
convert_streaming_response_ollama_to_openai,
)
if SAFE_MODE: if SAFE_MODE:
print("SAFE MODE ENABLED") print("SAFE MODE ENABLED")
Functions.deactivate_all_functions() Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
@ -217,7 +191,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@ -232,6 +205,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
app.state.MODELS = {} app.state.MODELS = {}
################################## ##################################
# #
# ChatCompletion Middleware # ChatCompletion Middleware
@ -245,14 +220,14 @@ def get_task_model_id(default_model_id):
# Check if the user has a custom task model and use that model # Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if ( if (
app.state.config.TASK_MODEL app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS and app.state.config.TASK_MODEL in app.state.MODELS
): ):
task_model_id = app.state.config.TASK_MODEL task_model_id = app.state.config.TASK_MODEL
else: else:
if ( if (
app.state.config.TASK_MODEL_EXTERNAL app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
): ):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL task_model_id = app.state.config.TASK_MODEL_EXTERNAL
@ -389,7 +364,7 @@ async def get_content_from_response(response) -> Optional[str]:
async def chat_completion_tools_handler( async def chat_completion_tools_handler(
body: dict, user: UserModel, extra_params: dict body: dict, user: UserModel, extra_params: dict
) -> tuple[dict, dict]: ) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
metadata = body.get("metadata", {}) metadata = body.get("metadata", {})
@ -690,6 +665,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware) app.add_middleware(ChatCompletionMiddleware)
################################## ##################################
# #
# Pipeline Middleware # Pipeline Middleware
@ -702,15 +678,15 @@ def get_sorted_filters(model_id):
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
if "pipeline" in model if "pipeline" in model
and "type" in model["pipeline"] and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter" and model["pipeline"]["type"] == "filter"
and ( and (
model["pipeline"]["pipelines"] == ["*"] model["pipeline"]["pipelines"] == ["*"]
or any( or any(
model_id == target_model_id model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"] for target_model_id in model["pipeline"]["pipelines"]
) )
) )
] ]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters return sorted_filters
@ -896,8 +872,8 @@ async def update_embedding_function(request: Request, call_next):
@app.middleware("http") @app.middleware("http")
async def inspect_websocket(request: Request, call_next): async def inspect_websocket(request: Request, call_next):
if ( if (
"/ws/socket.io" in request.url.path "/ws/socket.io" in request.url.path
and request.query_params.get("transport") == "websocket" and request.query_params.get("transport") == "websocket"
): ):
upgrade = (request.headers.get("Upgrade") or "").lower() upgrade = (request.headers.get("Upgrade") or "").lower()
connection = (request.headers.get("Connection") or "").lower().split(",") connection = (request.headers.get("Connection") or "").lower().split(",")
@ -966,8 +942,8 @@ async def get_all_models():
if custom_model.base_model_id is None: if custom_model.base_model_id is None:
for model in models: for model in models:
if ( if (
custom_model.id == model["id"] custom_model.id == model["id"]
or custom_model.id == model["id"].split(":")[0] or custom_model.id == model["id"].split(":")[0]
): ):
model["name"] = custom_model.name model["name"] = custom_model.name
model["info"] = custom_model.model_dump() model["info"] = custom_model.model_dump()
@ -984,8 +960,8 @@ async def get_all_models():
for model in models: for model in models:
if ( if (
custom_model.base_model_id == model["id"] custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0] or custom_model.base_model_id == model["id"].split(":")[0]
): ):
owned_by = model["owned_by"] owned_by = model["owned_by"]
if "pipe" in model: if "pipe" in model:
@ -1785,7 +1761,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
@app.post("/api/pipelines/upload") @app.post("/api/pipelines/upload")
async def upload_pipeline( async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
): ):
print("upload_pipeline", urlIdx, file.filename) print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file # Check if the uploaded file is a python file
@ -1962,9 +1938,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves") @app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves( async def get_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@ -2000,9 +1976,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec") @app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec( async def get_pipeline_valves_spec(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@ -2037,10 +2013,10 @@ async def get_pipeline_valves_spec(
@app.post("/api/pipelines/{pipeline_id}/valves/update") @app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves( async def update_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
form_data: dict, form_data: dict,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
r = None r = None
try: try:
@ -2164,7 +2140,7 @@ class ModelFilterConfigForm(BaseModel):
@app.post("/api/config/model/filter") @app.post("/api/config/model/filter")
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models app.state.config.MODEL_FILTER_LIST = form_data.models
@ -2219,7 +2195,7 @@ async def get_app_latest_release_version():
timeout = aiohttp.ClientTimeout(total=1) timeout = aiohttp.ClientTimeout(total=1)
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get( async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest" "https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response: ) as response:
response.raise_for_status() response.raise_for_status()
data = await response.json() data = await response.json()
@ -2235,20 +2211,6 @@ async def get_app_latest_release_version():
# OAuth Login & Callback # OAuth Login & Callback
############################ ############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
# SessionMiddleware is used by authlib for oauth # SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0: if len(OAUTH_PROVIDERS) > 0:
app.add_middleware( app.add_middleware(
@ -2262,16 +2224,7 @@ if len(OAUTH_PROVIDERS) > 0:
@app.get("/oauth/{provider}/login") @app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request): async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS: return await oauth_manager.handle_login(provider, request)
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = oauth.create_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows: # OAuth login logic is as follows:
@ -2282,118 +2235,7 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken # - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback") @app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response): async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS: return await oauth_manager.handle_callback(provider, request, response)
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
else webui_app.state.config.DEFAULT_USER_ROLE
)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
secure=WEBUI_SESSION_COOKIE_SECURE,
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json") @app.get("/manifest.json")

View File

@ -0,0 +1,243 @@
import base64
import logging
import mimetypes
import uuid
import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
HTTPException,
status,
)
from starlette.responses import RedirectResponse
from open_webui.apps.webui.models.auths import Auths
from open_webui.apps.webui.models.users import Users
from open_webui.config import (
DEFAULT_USER_ROLE,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
OAUTH_PROVIDERS,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_USERNAME_CLAIM,
OAUTH_ALLOWED_ROLES,
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.misc import parse_duration
from open_webui.utils.utils import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook
log = logging.getLogger(__name__)
auth_manager_config = AppConfig()
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
class OAuthManager:
def __init__(self):
self.oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
self.oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
redirect_uri=provider_config["redirect_uri"],
)
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
role = "admin"
break
else:
if not user:
# If role management is disabled, use the default role for new users
role = auth_manager_config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
async def handle_login(self, provider, request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
client = self.get_client(provider)
if client is None:
raise HTTPException(404)
return await client.authorize_redirect(request, redirect_uri)
async def handle_callback(self, provider, request, response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = self.get_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
email = user_data.get(email_claim, "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
role = self.get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
)
if auth_manager_config.WEBHOOK_URL:
post_webhook(
auth_manager_config.WEBHOOK_URL,
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=jwt_token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
oauth_manager = OAuthManager()