mirror of
https://github.com/open-webui/open-webui
synced 2025-04-08 06:35:04 +00:00
Merge branch 'main' into dev
# Conflicts: # backend/open_webui/main.py
This commit is contained in:
commit
b888ee17ff
@ -31,6 +31,7 @@ RUN npm ci
|
|||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
ENV APP_BUILD_HASH=${BUILD_HASH}
|
ENV APP_BUILD_HASH=${BUILD_HASH}
|
||||||
|
ENV NODE_OPTIONS="--max_old_space_size=8192"
|
||||||
RUN npm run build
|
RUN npm run build
|
||||||
|
|
||||||
######## WebUI backend ########
|
######## WebUI backend ########
|
||||||
|
@ -32,9 +32,13 @@ from open_webui.config import (
|
|||||||
ENABLE_MESSAGE_RATING,
|
ENABLE_MESSAGE_RATING,
|
||||||
ENABLE_SIGNUP,
|
ENABLE_SIGNUP,
|
||||||
JWT_EXPIRES_IN,
|
JWT_EXPIRES_IN,
|
||||||
|
ENABLE_OAUTH_ROLE_MANAGEMENT,
|
||||||
|
OAUTH_ROLES_CLAIM,
|
||||||
OAUTH_EMAIL_CLAIM,
|
OAUTH_EMAIL_CLAIM,
|
||||||
OAUTH_PICTURE_CLAIM,
|
OAUTH_PICTURE_CLAIM,
|
||||||
OAUTH_USERNAME_CLAIM,
|
OAUTH_USERNAME_CLAIM,
|
||||||
|
OAUTH_ALLOWED_ROLES,
|
||||||
|
OAUTH_ADMIN_ROLES,
|
||||||
SHOW_ADMIN_DETAILS,
|
SHOW_ADMIN_DETAILS,
|
||||||
USER_PERMISSIONS,
|
USER_PERMISSIONS,
|
||||||
WEBHOOK_URL,
|
WEBHOOK_URL,
|
||||||
@ -93,6 +97,11 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
|
|||||||
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||||
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
||||||
|
|
||||||
|
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
|
||||||
|
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
||||||
|
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
|
||||||
|
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
app.state.TOOLS = {}
|
app.state.TOOLS = {}
|
||||||
app.state.FUNCTIONS = {}
|
app.state.FUNCTIONS = {}
|
||||||
|
@ -383,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
OAUTH_PICTURE_CLAIM = PersistentConfig(
|
OAUTH_PICTURE_CLAIM = PersistentConfig(
|
||||||
"OAUTH_USERNAME_CLAIM",
|
"OAUTH_PICTURE_CLAIM",
|
||||||
"oauth.oidc.avatar_claim",
|
"oauth.oidc.avatar_claim",
|
||||||
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
|
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
|
||||||
)
|
)
|
||||||
@ -394,6 +394,29 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
|
|||||||
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
|
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
|
||||||
|
"ENABLE_OAUTH_ROLE_MANAGEMENT",
|
||||||
|
"oauth.enable_role_mapping",
|
||||||
|
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
|
||||||
|
)
|
||||||
|
|
||||||
|
OAUTH_ROLES_CLAIM = PersistentConfig(
|
||||||
|
"OAUTH_ROLES_CLAIM",
|
||||||
|
"oauth.roles_claim",
|
||||||
|
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
|
||||||
|
)
|
||||||
|
|
||||||
|
OAUTH_ALLOWED_ROLES = PersistentConfig(
|
||||||
|
"OAUTH_ALLOWED_ROLES",
|
||||||
|
"oauth.allowed_roles",
|
||||||
|
[role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")],
|
||||||
|
)
|
||||||
|
|
||||||
|
OAUTH_ADMIN_ROLES = PersistentConfig(
|
||||||
|
"OAUTH_ADMIN_ROLES",
|
||||||
|
"oauth.admin_roles",
|
||||||
|
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
|
||||||
|
)
|
||||||
|
|
||||||
def load_oauth_providers():
|
def load_oauth_providers():
|
||||||
OAUTH_PROVIDERS.clear()
|
OAUTH_PROVIDERS.clear()
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import base64
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -7,103 +7,11 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from open_webui.apps.ollama.main import (
|
|
||||||
app as ollama_app,
|
|
||||||
get_all_models as get_ollama_models,
|
|
||||||
generate_chat_completion as generate_ollama_chat_completion,
|
|
||||||
generate_openai_chat_completion as generate_ollama_openai_chat_completion,
|
|
||||||
GenerateChatCompletionForm,
|
|
||||||
)
|
|
||||||
from open_webui.apps.openai.main import (
|
|
||||||
app as openai_app,
|
|
||||||
generate_chat_completion as generate_openai_chat_completion,
|
|
||||||
get_all_models as get_openai_models,
|
|
||||||
)
|
|
||||||
|
|
||||||
from open_webui.apps.retrieval.main import app as retrieval_app
|
|
||||||
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
|
||||||
|
|
||||||
from open_webui.apps.socket.main import (
|
|
||||||
app as socket_app,
|
|
||||||
periodic_usage_pool_cleanup,
|
|
||||||
get_event_call,
|
|
||||||
get_event_emitter,
|
|
||||||
)
|
|
||||||
|
|
||||||
from open_webui.apps.webui.main import (
|
|
||||||
app as webui_app,
|
|
||||||
generate_function_chat_completion,
|
|
||||||
get_pipe_models,
|
|
||||||
)
|
|
||||||
from open_webui.apps.webui.internal.db import Session
|
|
||||||
|
|
||||||
from open_webui.apps.webui.models.auths import Auths
|
|
||||||
from open_webui.apps.webui.models.functions import Functions
|
|
||||||
from open_webui.apps.webui.models.models import Models
|
|
||||||
from open_webui.apps.webui.models.users import UserModel, Users
|
|
||||||
|
|
||||||
from open_webui.apps.webui.utils import load_function_module_by_id
|
|
||||||
|
|
||||||
from open_webui.apps.audio.main import app as audio_app
|
|
||||||
from open_webui.apps.images.main import app as images_app
|
|
||||||
|
|
||||||
from authlib.integrations.starlette_client import OAuth
|
|
||||||
from authlib.oidc.core import UserInfo
|
|
||||||
|
|
||||||
|
|
||||||
from open_webui.config import (
|
|
||||||
CACHE_DIR,
|
|
||||||
CORS_ALLOW_ORIGIN,
|
|
||||||
DEFAULT_LOCALE,
|
|
||||||
ENABLE_ADMIN_CHAT_ACCESS,
|
|
||||||
ENABLE_ADMIN_EXPORT,
|
|
||||||
ENABLE_MODEL_FILTER,
|
|
||||||
ENABLE_OAUTH_SIGNUP,
|
|
||||||
ENABLE_OLLAMA_API,
|
|
||||||
ENABLE_OPENAI_API,
|
|
||||||
ENV,
|
|
||||||
FRONTEND_BUILD_DIR,
|
|
||||||
MODEL_FILTER_LIST,
|
|
||||||
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
|
||||||
OAUTH_PROVIDERS,
|
|
||||||
ENABLE_SEARCH_QUERY,
|
|
||||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
||||||
STATIC_DIR,
|
|
||||||
TASK_MODEL,
|
|
||||||
TASK_MODEL_EXTERNAL,
|
|
||||||
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
||||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
||||||
WEBHOOK_URL,
|
|
||||||
WEBUI_AUTH,
|
|
||||||
WEBUI_NAME,
|
|
||||||
AppConfig,
|
|
||||||
run_migrations,
|
|
||||||
reset_config,
|
|
||||||
)
|
|
||||||
from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
|
|
||||||
from open_webui.env import (
|
|
||||||
CHANGELOG,
|
|
||||||
GLOBAL_LOG_LEVEL,
|
|
||||||
SAFE_MODE,
|
|
||||||
SRC_LOG_LEVELS,
|
|
||||||
VERSION,
|
|
||||||
WEBUI_BUILD_HASH,
|
|
||||||
WEBUI_SECRET_KEY,
|
|
||||||
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
||||||
WEBUI_SESSION_COOKIE_SECURE,
|
|
||||||
WEBUI_URL,
|
|
||||||
RESET_CONFIG_ON_START,
|
|
||||||
OFFLINE_MODE,
|
|
||||||
)
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
Depends,
|
Depends,
|
||||||
FastAPI,
|
FastAPI,
|
||||||
@ -122,16 +30,92 @@ from sqlalchemy import text
|
|||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.responses import RedirectResponse, Response, StreamingResponse
|
from starlette.responses import Response, StreamingResponse
|
||||||
|
|
||||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
|
||||||
|
|
||||||
|
from open_webui.apps.audio.main import app as audio_app
|
||||||
|
from open_webui.apps.images.main import app as images_app
|
||||||
|
from open_webui.apps.ollama.main import (
|
||||||
|
app as ollama_app,
|
||||||
|
get_all_models as get_ollama_models,
|
||||||
|
generate_chat_completion as generate_ollama_chat_completion,
|
||||||
|
GenerateChatCompletionForm,
|
||||||
|
)
|
||||||
|
from open_webui.apps.openai.main import (
|
||||||
|
app as openai_app,
|
||||||
|
generate_chat_completion as generate_openai_chat_completion,
|
||||||
|
get_all_models as get_openai_models,
|
||||||
|
)
|
||||||
|
from open_webui.apps.retrieval.main import app as retrieval_app
|
||||||
|
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
||||||
|
from open_webui.apps.socket.main import (
|
||||||
|
app as socket_app,
|
||||||
|
periodic_usage_pool_cleanup,
|
||||||
|
get_event_call,
|
||||||
|
get_event_emitter,
|
||||||
|
)
|
||||||
|
from open_webui.apps.webui.internal.db import Session
|
||||||
|
from open_webui.apps.webui.main import (
|
||||||
|
app as webui_app,
|
||||||
|
generate_function_chat_completion,
|
||||||
|
get_pipe_models,
|
||||||
|
)
|
||||||
|
from open_webui.apps.webui.models.functions import Functions
|
||||||
|
from open_webui.apps.webui.models.models import Models
|
||||||
|
from open_webui.apps.webui.models.users import UserModel, Users
|
||||||
|
from open_webui.apps.webui.utils import load_function_module_by_id
|
||||||
|
from open_webui.config import (
|
||||||
|
CACHE_DIR,
|
||||||
|
CORS_ALLOW_ORIGIN,
|
||||||
|
DEFAULT_LOCALE,
|
||||||
|
ENABLE_ADMIN_CHAT_ACCESS,
|
||||||
|
ENABLE_ADMIN_EXPORT,
|
||||||
|
ENABLE_MODEL_FILTER,
|
||||||
|
ENABLE_OLLAMA_API,
|
||||||
|
ENABLE_OPENAI_API,
|
||||||
|
ENV,
|
||||||
|
FRONTEND_BUILD_DIR,
|
||||||
|
MODEL_FILTER_LIST,
|
||||||
|
OAUTH_PROVIDERS,
|
||||||
|
ENABLE_SEARCH_QUERY,
|
||||||
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
STATIC_DIR,
|
||||||
|
TASK_MODEL,
|
||||||
|
TASK_MODEL_EXTERNAL,
|
||||||
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
||||||
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||||
|
WEBHOOK_URL,
|
||||||
|
WEBUI_AUTH,
|
||||||
|
WEBUI_NAME,
|
||||||
|
AppConfig,
|
||||||
|
reset_config,
|
||||||
|
)
|
||||||
|
from open_webui.constants import TASKS
|
||||||
|
from open_webui.env import (
|
||||||
|
CHANGELOG,
|
||||||
|
GLOBAL_LOG_LEVEL,
|
||||||
|
SAFE_MODE,
|
||||||
|
SRC_LOG_LEVELS,
|
||||||
|
VERSION,
|
||||||
|
WEBUI_BUILD_HASH,
|
||||||
|
WEBUI_SECRET_KEY,
|
||||||
|
WEBUI_SESSION_COOKIE_SAME_SITE,
|
||||||
|
WEBUI_SESSION_COOKIE_SECURE,
|
||||||
|
WEBUI_URL,
|
||||||
|
RESET_CONFIG_ON_START,
|
||||||
|
OFFLINE_MODE,
|
||||||
|
)
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
get_last_user_message,
|
get_last_user_message,
|
||||||
parse_duration,
|
|
||||||
prepend_to_first_user_message_content,
|
prepend_to_first_user_message_content,
|
||||||
)
|
)
|
||||||
|
from open_webui.utils.oauth import oauth_manager
|
||||||
|
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
||||||
|
from open_webui.utils.response import (
|
||||||
|
convert_response_ollama_to_openai,
|
||||||
|
convert_streaming_response_ollama_to_openai,
|
||||||
|
)
|
||||||
|
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||||
from open_webui.utils.task import (
|
from open_webui.utils.task import (
|
||||||
moa_response_generation_template,
|
moa_response_generation_template,
|
||||||
search_query_generation_template,
|
search_query_generation_template,
|
||||||
@ -140,27 +124,17 @@ from open_webui.utils.task import (
|
|||||||
)
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
from open_webui.utils.utils import (
|
from open_webui.utils.utils import (
|
||||||
create_token,
|
|
||||||
decode_token,
|
decode_token,
|
||||||
get_admin_user,
|
get_admin_user,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
get_http_authorization_cred,
|
get_http_authorization_cred,
|
||||||
get_password_hash,
|
|
||||||
get_verified_user,
|
get_verified_user,
|
||||||
)
|
)
|
||||||
from open_webui.utils.webhook import post_webhook
|
|
||||||
|
|
||||||
from open_webui.utils.payload import convert_payload_openai_to_ollama
|
|
||||||
from open_webui.utils.response import (
|
|
||||||
convert_response_ollama_to_openai,
|
|
||||||
convert_streaming_response_ollama_to_openai,
|
|
||||||
)
|
|
||||||
|
|
||||||
if SAFE_MODE:
|
if SAFE_MODE:
|
||||||
print("SAFE MODE ENABLED")
|
print("SAFE MODE ENABLED")
|
||||||
Functions.deactivate_all_functions()
|
Functions.deactivate_all_functions()
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
@ -217,7 +191,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|||||||
|
|
||||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
|
|
||||||
|
|
||||||
app.state.config.TASK_MODEL = TASK_MODEL
|
app.state.config.TASK_MODEL = TASK_MODEL
|
||||||
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
||||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||||
@ -232,6 +205,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
#
|
#
|
||||||
# ChatCompletion Middleware
|
# ChatCompletion Middleware
|
||||||
@ -245,14 +220,14 @@ def get_task_model_id(default_model_id):
|
|||||||
# Check if the user has a custom task model and use that model
|
# Check if the user has a custom task model and use that model
|
||||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
||||||
if (
|
if (
|
||||||
app.state.config.TASK_MODEL
|
app.state.config.TASK_MODEL
|
||||||
and app.state.config.TASK_MODEL in app.state.MODELS
|
and app.state.config.TASK_MODEL in app.state.MODELS
|
||||||
):
|
):
|
||||||
task_model_id = app.state.config.TASK_MODEL
|
task_model_id = app.state.config.TASK_MODEL
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
app.state.config.TASK_MODEL_EXTERNAL
|
app.state.config.TASK_MODEL_EXTERNAL
|
||||||
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
||||||
):
|
):
|
||||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||||
|
|
||||||
@ -389,7 +364,7 @@ async def get_content_from_response(response) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: UserModel, extra_params: dict
|
body: dict, user: UserModel, extra_params: dict
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
# If tool_ids field is present, call the functions
|
# If tool_ids field is present, call the functions
|
||||||
metadata = body.get("metadata", {})
|
metadata = body.get("metadata", {})
|
||||||
@ -690,6 +665,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
app.add_middleware(ChatCompletionMiddleware)
|
app.add_middleware(ChatCompletionMiddleware)
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
#
|
#
|
||||||
# Pipeline Middleware
|
# Pipeline Middleware
|
||||||
@ -702,15 +678,15 @@ def get_sorted_filters(model_id):
|
|||||||
model
|
model
|
||||||
for model in app.state.MODELS.values()
|
for model in app.state.MODELS.values()
|
||||||
if "pipeline" in model
|
if "pipeline" in model
|
||||||
and "type" in model["pipeline"]
|
and "type" in model["pipeline"]
|
||||||
and model["pipeline"]["type"] == "filter"
|
and model["pipeline"]["type"] == "filter"
|
||||||
and (
|
and (
|
||||||
model["pipeline"]["pipelines"] == ["*"]
|
model["pipeline"]["pipelines"] == ["*"]
|
||||||
or any(
|
or any(
|
||||||
model_id == target_model_id
|
model_id == target_model_id
|
||||||
for target_model_id in model["pipeline"]["pipelines"]
|
for target_model_id in model["pipeline"]["pipelines"]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||||
return sorted_filters
|
return sorted_filters
|
||||||
@ -896,8 +872,8 @@ async def update_embedding_function(request: Request, call_next):
|
|||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def inspect_websocket(request: Request, call_next):
|
async def inspect_websocket(request: Request, call_next):
|
||||||
if (
|
if (
|
||||||
"/ws/socket.io" in request.url.path
|
"/ws/socket.io" in request.url.path
|
||||||
and request.query_params.get("transport") == "websocket"
|
and request.query_params.get("transport") == "websocket"
|
||||||
):
|
):
|
||||||
upgrade = (request.headers.get("Upgrade") or "").lower()
|
upgrade = (request.headers.get("Upgrade") or "").lower()
|
||||||
connection = (request.headers.get("Connection") or "").lower().split(",")
|
connection = (request.headers.get("Connection") or "").lower().split(",")
|
||||||
@ -966,8 +942,8 @@ async def get_all_models():
|
|||||||
if custom_model.base_model_id is None:
|
if custom_model.base_model_id is None:
|
||||||
for model in models:
|
for model in models:
|
||||||
if (
|
if (
|
||||||
custom_model.id == model["id"]
|
custom_model.id == model["id"]
|
||||||
or custom_model.id == model["id"].split(":")[0]
|
or custom_model.id == model["id"].split(":")[0]
|
||||||
):
|
):
|
||||||
model["name"] = custom_model.name
|
model["name"] = custom_model.name
|
||||||
model["info"] = custom_model.model_dump()
|
model["info"] = custom_model.model_dump()
|
||||||
@ -984,8 +960,8 @@ async def get_all_models():
|
|||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
if (
|
if (
|
||||||
custom_model.base_model_id == model["id"]
|
custom_model.base_model_id == model["id"]
|
||||||
or custom_model.base_model_id == model["id"].split(":")[0]
|
or custom_model.base_model_id == model["id"].split(":")[0]
|
||||||
):
|
):
|
||||||
owned_by = model["owned_by"]
|
owned_by = model["owned_by"]
|
||||||
if "pipe" in model:
|
if "pipe" in model:
|
||||||
@ -1785,7 +1761,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
|
|||||||
|
|
||||||
@app.post("/api/pipelines/upload")
|
@app.post("/api/pipelines/upload")
|
||||||
async def upload_pipeline(
|
async def upload_pipeline(
|
||||||
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
print("upload_pipeline", urlIdx, file.filename)
|
print("upload_pipeline", urlIdx, file.filename)
|
||||||
# Check if the uploaded file is a python file
|
# Check if the uploaded file is a python file
|
||||||
@ -1962,9 +1938,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
|||||||
|
|
||||||
@app.get("/api/pipelines/{pipeline_id}/valves")
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
||||||
async def get_pipeline_valves(
|
async def get_pipeline_valves(
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
pipeline_id: str,
|
pipeline_id: str,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
@ -2000,9 +1976,9 @@ async def get_pipeline_valves(
|
|||||||
|
|
||||||
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
||||||
async def get_pipeline_valves_spec(
|
async def get_pipeline_valves_spec(
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
pipeline_id: str,
|
pipeline_id: str,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
@ -2037,10 +2013,10 @@ async def get_pipeline_valves_spec(
|
|||||||
|
|
||||||
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
||||||
async def update_pipeline_valves(
|
async def update_pipeline_valves(
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
pipeline_id: str,
|
pipeline_id: str,
|
||||||
form_data: dict,
|
form_data: dict,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
r = None
|
r = None
|
||||||
try:
|
try:
|
||||||
@ -2164,7 +2140,7 @@ class ModelFilterConfigForm(BaseModel):
|
|||||||
|
|
||||||
@app.post("/api/config/model/filter")
|
@app.post("/api/config/model/filter")
|
||||||
async def update_model_filter_config(
|
async def update_model_filter_config(
|
||||||
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
||||||
app.state.config.MODEL_FILTER_LIST = form_data.models
|
app.state.config.MODEL_FILTER_LIST = form_data.models
|
||||||
@ -2219,7 +2195,7 @@ async def get_app_latest_release_version():
|
|||||||
timeout = aiohttp.ClientTimeout(total=1)
|
timeout = aiohttp.ClientTimeout(total=1)
|
||||||
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
@ -2235,20 +2211,6 @@ async def get_app_latest_release_version():
|
|||||||
# OAuth Login & Callback
|
# OAuth Login & Callback
|
||||||
############################
|
############################
|
||||||
|
|
||||||
oauth = OAuth()
|
|
||||||
|
|
||||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
|
||||||
oauth.register(
|
|
||||||
name=provider_name,
|
|
||||||
client_id=provider_config["client_id"],
|
|
||||||
client_secret=provider_config["client_secret"],
|
|
||||||
server_metadata_url=provider_config["server_metadata_url"],
|
|
||||||
client_kwargs={
|
|
||||||
"scope": provider_config["scope"],
|
|
||||||
},
|
|
||||||
redirect_uri=provider_config["redirect_uri"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# SessionMiddleware is used by authlib for oauth
|
# SessionMiddleware is used by authlib for oauth
|
||||||
if len(OAUTH_PROVIDERS) > 0:
|
if len(OAUTH_PROVIDERS) > 0:
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@ -2262,16 +2224,7 @@ if len(OAUTH_PROVIDERS) > 0:
|
|||||||
|
|
||||||
@app.get("/oauth/{provider}/login")
|
@app.get("/oauth/{provider}/login")
|
||||||
async def oauth_login(provider: str, request: Request):
|
async def oauth_login(provider: str, request: Request):
|
||||||
if provider not in OAUTH_PROVIDERS:
|
return await oauth_manager.handle_login(provider, request)
|
||||||
raise HTTPException(404)
|
|
||||||
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
|
||||||
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
|
||||||
"oauth_callback", provider=provider
|
|
||||||
)
|
|
||||||
client = oauth.create_client(provider)
|
|
||||||
if client is None:
|
|
||||||
raise HTTPException(404)
|
|
||||||
return await client.authorize_redirect(request, redirect_uri)
|
|
||||||
|
|
||||||
|
|
||||||
# OAuth login logic is as follows:
|
# OAuth login logic is as follows:
|
||||||
@ -2282,118 +2235,7 @@ async def oauth_login(provider: str, request: Request):
|
|||||||
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
# - Email addresses are considered unique, so we fail registration if the email address is already taken
|
||||||
@app.get("/oauth/{provider}/callback")
|
@app.get("/oauth/{provider}/callback")
|
||||||
async def oauth_callback(provider: str, request: Request, response: Response):
|
async def oauth_callback(provider: str, request: Request, response: Response):
|
||||||
if provider not in OAUTH_PROVIDERS:
|
return await oauth_manager.handle_callback(provider, request, response)
|
||||||
raise HTTPException(404)
|
|
||||||
client = oauth.create_client(provider)
|
|
||||||
try:
|
|
||||||
token = await client.authorize_access_token(request)
|
|
||||||
except Exception as e:
|
|
||||||
log.warning(f"OAuth callback error: {e}")
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
||||||
user_data: UserInfo = token["userinfo"]
|
|
||||||
|
|
||||||
sub = user_data.get("sub")
|
|
||||||
if not sub:
|
|
||||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
||||||
provider_sub = f"{provider}@{sub}"
|
|
||||||
email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
|
|
||||||
email = user_data.get(email_claim, "").lower()
|
|
||||||
# We currently mandate that email addresses are provided
|
|
||||||
if not email:
|
|
||||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
||||||
|
|
||||||
# Check if the user exists
|
|
||||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
# If the user does not exist, check if merging is enabled
|
|
||||||
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
|
|
||||||
# Check if the user exists by email
|
|
||||||
user = Users.get_user_by_email(email)
|
|
||||||
if user:
|
|
||||||
# Update the user with the new oauth sub
|
|
||||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
# If the user does not exist, check if signups are enabled
|
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
|
||||||
# Check if an existing user with the same email already exists
|
|
||||||
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
|
||||||
if existing_user:
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
||||||
|
|
||||||
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
|
|
||||||
picture_url = user_data.get(picture_claim, "")
|
|
||||||
if picture_url:
|
|
||||||
# Download the profile image into a base64 string
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(picture_url) as resp:
|
|
||||||
picture = await resp.read()
|
|
||||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
|
||||||
if guessed_mime_type is None:
|
|
||||||
# assume JPG, browsers are tolerant enough of image formats
|
|
||||||
guessed_mime_type = "image/jpeg"
|
|
||||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Error downloading profile image '{picture_url}': {e}")
|
|
||||||
picture_url = ""
|
|
||||||
if not picture_url:
|
|
||||||
picture_url = "/user.png"
|
|
||||||
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
|
|
||||||
role = (
|
|
||||||
"admin"
|
|
||||||
if Users.get_num_users() == 0
|
|
||||||
else webui_app.state.config.DEFAULT_USER_ROLE
|
|
||||||
)
|
|
||||||
user = Auths.insert_new_auth(
|
|
||||||
email=email,
|
|
||||||
password=get_password_hash(
|
|
||||||
str(uuid.uuid4())
|
|
||||||
), # Random password, not used
|
|
||||||
name=user_data.get(username_claim, "User"),
|
|
||||||
profile_image_url=picture_url,
|
|
||||||
role=role,
|
|
||||||
oauth_sub=provider_sub,
|
|
||||||
)
|
|
||||||
|
|
||||||
if webui_app.state.config.WEBHOOK_URL:
|
|
||||||
post_webhook(
|
|
||||||
webui_app.state.config.WEBHOOK_URL,
|
|
||||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
||||||
{
|
|
||||||
"action": "signup",
|
|
||||||
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
||||||
"user": user.model_dump_json(exclude_none=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
|
||||||
)
|
|
||||||
|
|
||||||
jwt_token = create_token(
|
|
||||||
data={"id": user.id},
|
|
||||||
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the cookie token
|
|
||||||
response.set_cookie(
|
|
||||||
key="token",
|
|
||||||
value=jwt_token,
|
|
||||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
|
||||||
samesite=WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
||||||
secure=WEBUI_SESSION_COOKIE_SECURE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redirect back to the frontend with the JWT token
|
|
||||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
|
||||||
return RedirectResponse(url=redirect_url)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/manifest.json")
|
@app.get("/manifest.json")
|
||||||
|
243
backend/open_webui/utils/oauth.py
Normal file
243
backend/open_webui/utils/oauth.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from authlib.integrations.starlette_client import OAuth
|
||||||
|
from authlib.oidc.core import UserInfo
|
||||||
|
from fastapi import (
|
||||||
|
HTTPException,
|
||||||
|
status,
|
||||||
|
)
|
||||||
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
|
from open_webui.apps.webui.models.auths import Auths
|
||||||
|
from open_webui.apps.webui.models.users import Users
|
||||||
|
from open_webui.config import (
|
||||||
|
DEFAULT_USER_ROLE,
|
||||||
|
ENABLE_OAUTH_SIGNUP,
|
||||||
|
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||||
|
OAUTH_PROVIDERS,
|
||||||
|
ENABLE_OAUTH_ROLE_MANAGEMENT,
|
||||||
|
OAUTH_ROLES_CLAIM,
|
||||||
|
OAUTH_EMAIL_CLAIM,
|
||||||
|
OAUTH_PICTURE_CLAIM,
|
||||||
|
OAUTH_USERNAME_CLAIM,
|
||||||
|
OAUTH_ALLOWED_ROLES,
|
||||||
|
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
|
||||||
|
)
|
||||||
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
from open_webui.utils.misc import parse_duration
|
||||||
|
from open_webui.utils.utils import get_password_hash, create_token
|
||||||
|
from open_webui.utils.webhook import post_webhook
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
auth_manager_config = AppConfig()
|
||||||
|
auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||||
|
auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP
|
||||||
|
auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL
|
||||||
|
auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
|
||||||
|
auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
|
||||||
|
auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
|
||||||
|
auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||||
|
auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
|
||||||
|
auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
|
||||||
|
auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
|
||||||
|
auth_manager_config.WEBHOOK_URL = WEBHOOK_URL
|
||||||
|
auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.oauth = OAuth()
|
||||||
|
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||||
|
self.oauth.register(
|
||||||
|
name=provider_name,
|
||||||
|
client_id=provider_config["client_id"],
|
||||||
|
client_secret=provider_config["client_secret"],
|
||||||
|
server_metadata_url=provider_config["server_metadata_url"],
|
||||||
|
client_kwargs={
|
||||||
|
"scope": provider_config["scope"],
|
||||||
|
},
|
||||||
|
redirect_uri=provider_config["redirect_uri"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_client(self, provider_name):
|
||||||
|
return self.oauth.create_client(provider_name)
|
||||||
|
|
||||||
|
def get_user_role(self, user, user_data):
|
||||||
|
if user and Users.get_num_users() == 1:
|
||||||
|
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||||
|
return "admin"
|
||||||
|
if not user and Users.get_num_users() == 0:
|
||||||
|
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||||
|
return "admin"
|
||||||
|
|
||||||
|
if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
||||||
|
oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM
|
||||||
|
oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES
|
||||||
|
oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES
|
||||||
|
oauth_roles = None
|
||||||
|
role = "pending" # Default/fallback role if no matching roles are found
|
||||||
|
|
||||||
|
# Next block extracts the roles from the user data, accepting nested claims of any depth
|
||||||
|
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
|
||||||
|
claim_data = user_data
|
||||||
|
nested_claims = oauth_claim.split(".")
|
||||||
|
for nested_claim in nested_claims:
|
||||||
|
claim_data = claim_data.get(nested_claim, {})
|
||||||
|
oauth_roles = claim_data if isinstance(claim_data, list) else None
|
||||||
|
|
||||||
|
# If any roles are found, check if they match the allowed or admin roles
|
||||||
|
if oauth_roles:
|
||||||
|
# If role management is enabled, and matching roles are provided, use the roles
|
||||||
|
for allowed_role in oauth_allowed_roles:
|
||||||
|
# If the user has any of the allowed roles, assign the role "user"
|
||||||
|
if allowed_role in oauth_roles:
|
||||||
|
role = "user"
|
||||||
|
break
|
||||||
|
for admin_role in oauth_admin_roles:
|
||||||
|
# If the user has any of the admin roles, assign the role "admin"
|
||||||
|
if admin_role in oauth_roles:
|
||||||
|
role = "admin"
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if not user:
|
||||||
|
# If role management is disabled, use the default role for new users
|
||||||
|
role = auth_manager_config.DEFAULT_USER_ROLE
|
||||||
|
else:
|
||||||
|
# If role management is disabled, use the existing role for existing users
|
||||||
|
role = user.role
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
async def handle_login(self, provider, request):
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise HTTPException(404)
|
||||||
|
# If the provider has a custom redirect URL, use that, otherwise automatically generate one
|
||||||
|
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
|
||||||
|
"oauth_callback", provider=provider
|
||||||
|
)
|
||||||
|
client = self.get_client(provider)
|
||||||
|
if client is None:
|
||||||
|
raise HTTPException(404)
|
||||||
|
return await client.authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
async def handle_callback(self, provider, request, response):
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise HTTPException(404)
|
||||||
|
client = self.get_client(provider)
|
||||||
|
try:
|
||||||
|
token = await client.authorize_access_token(request)
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"OAuth callback error: {e}")
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
user_data: UserInfo = token["userinfo"]
|
||||||
|
|
||||||
|
sub = user_data.get("sub")
|
||||||
|
if not sub:
|
||||||
|
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
provider_sub = f"{provider}@{sub}"
|
||||||
|
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||||
|
email = user_data.get(email_claim, "").lower()
|
||||||
|
# We currently mandate that email addresses are provided
|
||||||
|
if not email:
|
||||||
|
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
# Check if the user exists
|
||||||
|
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
# If the user does not exist, check if merging is enabled
|
||||||
|
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
|
||||||
|
# Check if the user exists by email
|
||||||
|
user = Users.get_user_by_email(email)
|
||||||
|
if user:
|
||||||
|
# Update the user with the new oauth sub
|
||||||
|
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||||
|
|
||||||
|
if user:
|
||||||
|
determined_role = self.get_user_role(user, user_data)
|
||||||
|
if user.role != determined_role:
|
||||||
|
Users.update_user_role_by_id(user.id, determined_role)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
# If the user does not exist, check if signups are enabled
|
||||||
|
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
|
||||||
|
# Check if an existing user with the same email already exists
|
||||||
|
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
|
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
||||||
|
picture_url = user_data.get(picture_claim, "")
|
||||||
|
if picture_url:
|
||||||
|
# Download the profile image into a base64 string
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(picture_url) as resp:
|
||||||
|
picture = await resp.read()
|
||||||
|
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||||
|
if guessed_mime_type is None:
|
||||||
|
# assume JPG, browsers are tolerant enough of image formats
|
||||||
|
guessed_mime_type = "image/jpeg"
|
||||||
|
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error downloading profile image '{picture_url}': {e}")
|
||||||
|
picture_url = ""
|
||||||
|
if not picture_url:
|
||||||
|
picture_url = "/user.png"
|
||||||
|
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
||||||
|
|
||||||
|
role = self.get_user_role(None, user_data)
|
||||||
|
|
||||||
|
user = Auths.insert_new_auth(
|
||||||
|
email=email,
|
||||||
|
password=get_password_hash(
|
||||||
|
str(uuid.uuid4())
|
||||||
|
), # Random password, not used
|
||||||
|
name=user_data.get(username_claim, "User"),
|
||||||
|
profile_image_url=picture_url,
|
||||||
|
role=role,
|
||||||
|
oauth_sub=provider_sub,
|
||||||
|
)
|
||||||
|
|
||||||
|
if auth_manager_config.WEBHOOK_URL:
|
||||||
|
post_webhook(
|
||||||
|
auth_manager_config.WEBHOOK_URL,
|
||||||
|
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||||
|
{
|
||||||
|
"action": "signup",
|
||||||
|
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||||
|
"user": user.model_dump_json(exclude_none=True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||||
|
)
|
||||||
|
|
||||||
|
jwt_token = create_token(
|
||||||
|
data={"id": user.id},
|
||||||
|
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the cookie token
|
||||||
|
response.set_cookie(
|
||||||
|
key="token",
|
||||||
|
value=jwt_token,
|
||||||
|
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect back to the frontend with the JWT token
|
||||||
|
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||||
|
return RedirectResponse(url=redirect_url)
|
||||||
|
|
||||||
|
oauth_manager = OAuthManager()
|
Loading…
Reference in New Issue
Block a user