refac: move things around, uplift oauth endpoints

This commit is contained in:
Jun Siang Cheah 2024-05-27 18:07:38 +01:00
parent 06dbf59742
commit 985fdca585
5 changed files with 142 additions and 100 deletions

View File

@ -58,12 +58,6 @@ app.add_middleware(
allow_headers=["*"],
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])

View File

@ -112,9 +112,16 @@ class UsersTable:
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
def get_user_by_email(
self, email: str, oauth_user: bool = False
) -> Optional[UserModel]:
try:
user = User.get((User.email == email, User.oauth_sub.is_null()))
conditions = (
(User.email == email, User.oauth_sub.is_null())
if not oauth_user
else (User.email == email)
)
user = User.get(conditions)
return UserModel(**model_to_dict(user))
except:
return None
@ -177,6 +184,18 @@ class UsersTable:
except:
return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)

View File

@ -1,7 +1,5 @@
import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status
@ -11,8 +9,6 @@ import re
import uuid
import csv
from starlette.responses import RedirectResponse
from apps.webui.models.auths import (
SigninForm,
SignupForm,
@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import (
WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
)
router = APIRouter()
@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
@router.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@router.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, create a new user if signup is enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=request.app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)

View File

@ -1,4 +1,8 @@
import uuid
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
@ -31,8 +36,16 @@ import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.models import Models, ModelModel
from utils.utils import get_admin_user, get_verified_user
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from utils.misc import parse_duration
from utils.utils import (
get_admin_user,
get_verified_user,
get_password_hash,
create_token,
)
from apps.rag.utils import rag_messages
from config import (
@ -56,8 +69,12 @@ from config import (
ENABLE_ADMIN_EXPORT,
AppConfig,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
)
from constants import ERROR_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
@ -453,6 +470,103 @@ async def get_app_latest_release_version():
)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email
email = user_data.get("email", "").lower()
if not email:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=webui_app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json")
async def get_manifest_json():
return {

View File

@ -259,7 +259,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
@ -284,7 +284,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
}}
>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
@ -309,7 +309,7 @@
<button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`;
window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
}}
>
<svg