mirror of
https://github.com/open-webui/open-webui
synced 2025-06-14 10:20:52 +00:00
refac: move things around, uplift oauth endpoints
This commit is contained in:
parent
06dbf59742
commit
985fdca585
@ -58,12 +58,6 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# SessionMiddleware is used by authlib for oauth
|
|
||||||
if len(OAUTH_PROVIDERS) > 0:
|
|
||||||
app.add_middleware(
|
|
||||||
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||||
|
@ -112,9 +112,16 @@ class UsersTable:
|
|||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_email(self, email: str) -> Optional[UserModel]:
|
def get_user_by_email(
|
||||||
|
self, email: str, oauth_user: bool = False
|
||||||
|
) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
user = User.get((User.email == email, User.oauth_sub.is_null()))
|
conditions = (
|
||||||
|
(User.email == email, User.oauth_sub.is_null())
|
||||||
|
if not oauth_user
|
||||||
|
else (User.email == email)
|
||||||
|
)
|
||||||
|
user = User.get(conditions)
|
||||||
return UserModel(**model_to_dict(user))
|
return UserModel(**model_to_dict(user))
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
@ -177,6 +184,18 @@ class UsersTable:
|
|||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def update_user_oauth_sub_by_id(
|
||||||
|
self, id: str, oauth_sub: str
|
||||||
|
) -> Optional[UserModel]:
|
||||||
|
try:
|
||||||
|
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
|
||||||
|
query.execute()
|
||||||
|
|
||||||
|
user = User.get(User.id == id)
|
||||||
|
return UserModel(**model_to_dict(user))
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
query = User.update(**updated).where(User.id == id)
|
query = User.update(**updated).where(User.id == id)
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from authlib.integrations.starlette_client import OAuth
|
|
||||||
from authlib.oidc.core import UserInfo
|
|
||||||
from fastapi import Request, UploadFile, File
|
from fastapi import Request, UploadFile, File
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
|
|
||||||
@ -11,8 +9,6 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
from starlette.responses import RedirectResponse
|
|
||||||
|
|
||||||
from apps.webui.models.auths import (
|
from apps.webui.models.auths import (
|
||||||
SigninForm,
|
SigninForm,
|
||||||
SignupForm,
|
SignupForm,
|
||||||
@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
|||||||
from config import (
|
from config import (
|
||||||
WEBUI_AUTH,
|
WEBUI_AUTH,
|
||||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||||
OAUTH_PROVIDERS,
|
|
||||||
ENABLE_OAUTH_SIGNUP,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
|
|
||||||
############################
|
|
||||||
# OAuth Login & Callback
|
|
||||||
############################
|
|
||||||
|
|
||||||
oauth = OAuth()
|
|
||||||
|
|
||||||
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
|
||||||
oauth.register(
|
|
||||||
name=provider_name,
|
|
||||||
client_id=provider_config["client_id"],
|
|
||||||
client_secret=provider_config["client_secret"],
|
|
||||||
server_metadata_url=provider_config["server_metadata_url"],
|
|
||||||
client_kwargs={
|
|
||||||
"scope": provider_config["scope"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/oauth/{provider}/login")
|
|
||||||
async def oauth_login(provider: str, request: Request):
|
|
||||||
if provider not in OAUTH_PROVIDERS:
|
|
||||||
raise HTTPException(404)
|
|
||||||
redirect_uri = request.url_for("oauth_callback", provider=provider)
|
|
||||||
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/oauth/{provider}/callback")
|
|
||||||
async def oauth_callback(provider: str, request: Request):
|
|
||||||
if provider not in OAUTH_PROVIDERS:
|
|
||||||
raise HTTPException(404)
|
|
||||||
client = oauth.create_client(provider)
|
|
||||||
token = await client.authorize_access_token(request)
|
|
||||||
user_data: UserInfo = token["userinfo"]
|
|
||||||
|
|
||||||
sub = user_data.get("sub")
|
|
||||||
if not sub:
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
||||||
provider_sub = f"{provider}@{sub}"
|
|
||||||
|
|
||||||
# Check if the user exists
|
|
||||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
# If the user does not exist, create a new user if signup is enabled
|
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
|
||||||
user = Auths.insert_new_auth(
|
|
||||||
email=user_data.get("email", "").lower(),
|
|
||||||
password=get_password_hash(
|
|
||||||
str(uuid.uuid4())
|
|
||||||
), # Random password, not used
|
|
||||||
name=user_data.get("name", "User"),
|
|
||||||
profile_image_url=user_data.get("picture", "/user.png"),
|
|
||||||
role=request.app.state.config.DEFAULT_USER_ROLE,
|
|
||||||
oauth_sub=provider_sub,
|
|
||||||
)
|
|
||||||
|
|
||||||
if request.app.state.config.WEBHOOK_URL:
|
|
||||||
post_webhook(
|
|
||||||
request.app.state.config.WEBHOOK_URL,
|
|
||||||
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
||||||
{
|
|
||||||
"action": "signup",
|
|
||||||
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
|
||||||
"user": user.model_dump_json(exclude_none=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
||||||
|
|
||||||
jwt_token = create_token(
|
|
||||||
data={"id": user.id},
|
|
||||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redirect back to the frontend with the JWT token
|
|
||||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
|
||||||
return RedirectResponse(url=redirect_url)
|
|
||||||
|
122
backend/main.py
122
backend/main.py
@ -1,4 +1,8 @@
|
|||||||
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from authlib.integrations.starlette_client import OAuth
|
||||||
|
from authlib.oidc.core import UserInfo
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
import json
|
import json
|
||||||
import markdown
|
import markdown
|
||||||
@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.responses import StreamingResponse, Response
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
from starlette.responses import StreamingResponse, Response, RedirectResponse
|
||||||
|
|
||||||
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
|
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
|
||||||
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
|
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
|
||||||
@ -31,8 +36,16 @@ import asyncio
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from apps.webui.models.models import Models, ModelModel
|
from apps.webui.models.auths import Auths
|
||||||
from utils.utils import get_admin_user, get_verified_user
|
from apps.webui.models.models import Models
|
||||||
|
from apps.webui.models.users import Users
|
||||||
|
from utils.misc import parse_duration
|
||||||
|
from utils.utils import (
|
||||||
|
get_admin_user,
|
||||||
|
get_verified_user,
|
||||||
|
get_password_hash,
|
||||||
|
create_token,
|
||||||
|
)
|
||||||
from apps.rag.utils import rag_messages
|
from apps.rag.utils import rag_messages
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
@ -56,8 +69,12 @@ from config import (
|
|||||||
ENABLE_ADMIN_EXPORT,
|
ENABLE_ADMIN_EXPORT,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
OAUTH_PROVIDERS,
|
OAUTH_PROVIDERS,
|
||||||
|
ENABLE_OAUTH_SIGNUP,
|
||||||
|
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
||||||
|
WEBUI_SECRET_KEY,
|
||||||
)
|
)
|
||||||
from constants import ERROR_MESSAGES
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||||
|
from utils.webhook import post_webhook
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -453,6 +470,103 @@ async def get_app_latest_release_version():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# OAuth Login & Callback
|
||||||
|
############################
|
||||||
|
|
||||||
|
oauth = OAuth()
|
||||||
|
|
||||||
|
for provider_name, provider_config in OAUTH_PROVIDERS.items():
|
||||||
|
oauth.register(
|
||||||
|
name=provider_name,
|
||||||
|
client_id=provider_config["client_id"],
|
||||||
|
client_secret=provider_config["client_secret"],
|
||||||
|
server_metadata_url=provider_config["server_metadata_url"],
|
||||||
|
client_kwargs={
|
||||||
|
"scope": provider_config["scope"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# SessionMiddleware is used by authlib for oauth
|
||||||
|
if len(OAUTH_PROVIDERS) > 0:
|
||||||
|
app.add_middleware(
|
||||||
|
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth/{provider}/login")
|
||||||
|
async def oauth_login(provider: str, request: Request):
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise HTTPException(404)
|
||||||
|
redirect_uri = request.url_for("oauth_callback", provider=provider)
|
||||||
|
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/oauth/{provider}/callback")
|
||||||
|
async def oauth_callback(provider: str, request: Request):
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise HTTPException(404)
|
||||||
|
client = oauth.create_client(provider)
|
||||||
|
token = await client.authorize_access_token(request)
|
||||||
|
user_data: UserInfo = token["userinfo"]
|
||||||
|
|
||||||
|
sub = user_data.get("sub")
|
||||||
|
if not sub:
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
provider_sub = f"{provider}@{sub}"
|
||||||
|
|
||||||
|
# Check if the user exists
|
||||||
|
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
# If the user does not exist, check if merging is enabled
|
||||||
|
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||||
|
# Check if the user exists by email
|
||||||
|
email = user_data.get("email", "").lower()
|
||||||
|
if not email:
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
|
||||||
|
if user:
|
||||||
|
# Update the user with the new oauth sub
|
||||||
|
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
# If the user does not exist, check if signups are enabled
|
||||||
|
if ENABLE_OAUTH_SIGNUP.value:
|
||||||
|
user = Auths.insert_new_auth(
|
||||||
|
email=user_data.get("email", "").lower(),
|
||||||
|
password=get_password_hash(
|
||||||
|
str(uuid.uuid4())
|
||||||
|
), # Random password, not used
|
||||||
|
name=user_data.get("name", "User"),
|
||||||
|
profile_image_url=user_data.get("picture", "/user.png"),
|
||||||
|
role=webui_app.state.config.DEFAULT_USER_ROLE,
|
||||||
|
oauth_sub=provider_sub,
|
||||||
|
)
|
||||||
|
|
||||||
|
if webui_app.state.config.WEBHOOK_URL:
|
||||||
|
post_webhook(
|
||||||
|
webui_app.state.config.WEBHOOK_URL,
|
||||||
|
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||||
|
{
|
||||||
|
"action": "signup",
|
||||||
|
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||||
|
"user": user.model_dump_json(exclude_none=True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
jwt_token = create_token(
|
||||||
|
data={"id": user.id},
|
||||||
|
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect back to the frontend with the JWT token
|
||||||
|
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||||
|
return RedirectResponse(url=redirect_url)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/manifest.json")
|
@app.get("/manifest.json")
|
||||||
async def get_manifest_json():
|
async def get_manifest_json():
|
||||||
return {
|
return {
|
||||||
|
@ -259,7 +259,7 @@
|
|||||||
<button
|
<button
|
||||||
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
||||||
on:click={() => {
|
on:click={() => {
|
||||||
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
|
window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
|
||||||
@ -284,7 +284,7 @@
|
|||||||
<button
|
<button
|
||||||
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
||||||
on:click={() => {
|
on:click={() => {
|
||||||
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
|
window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
|
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
|
||||||
@ -309,7 +309,7 @@
|
|||||||
<button
|
<button
|
||||||
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
|
||||||
on:click={() => {
|
on:click={() => {
|
||||||
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
|
window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<svg
|
<svg
|
||||||
|
Loading…
Reference in New Issue
Block a user