refac: move things around, uplift oauth endpoints

This commit is contained in:
Jun Siang Cheah 2024-05-27 18:07:38 +01:00
parent 06dbf59742
commit 985fdca585
5 changed files with 142 additions and 100 deletions

View File

@ -58,12 +58,6 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])

View File

@ -112,9 +112,16 @@ class UsersTable:
except: except:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(
self, email: str, oauth_user: bool = False
) -> Optional[UserModel]:
try: try:
user = User.get((User.email == email, User.oauth_sub.is_null())) conditions = (
(User.email == email, User.oauth_sub.is_null())
if not oauth_user
else (User.email == email)
)
user = User.get(conditions)
return UserModel(**model_to_dict(user)) return UserModel(**model_to_dict(user))
except: except:
return None return None
@ -177,6 +184,18 @@ class UsersTable:
except: except:
return None return None
def update_user_oauth_sub_by_id(
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) query = User.update(**updated).where(User.id == id)

View File

@ -1,7 +1,5 @@
import logging import logging
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import Request, UploadFile, File from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
@ -11,8 +9,6 @@ import re
import uuid import uuid
import csv import csv
from starlette.responses import RedirectResponse
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
@ -39,8 +35,6 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import ( from config import (
WEBUI_AUTH, WEBUI_AUTH,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
) )
router = APIRouter() router = APIRouter()
@ -381,82 +375,3 @@ async def get_api_key(user=Depends(get_current_user)):
} }
else: else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
@router.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@router.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, create a new user if signup is enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=request.app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)

View File

@ -1,4 +1,8 @@
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import json import json
import markdown import markdown
@ -17,7 +21,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models from apps.openai.main import app as openai_app, get_all_models as get_openai_models
@ -31,8 +36,16 @@ import asyncio
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from apps.webui.models.models import Models, ModelModel from apps.webui.models.auths import Auths
from utils.utils import get_admin_user, get_verified_user from apps.webui.models.models import Models
from apps.webui.models.users import Users
from utils.misc import parse_duration
from utils.utils import (
get_admin_user,
get_verified_user,
get_password_hash,
create_token,
)
from apps.rag.utils import rag_messages from apps.rag.utils import rag_messages
from config import ( from config import (
@ -56,8 +69,12 @@ from config import (
ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_EXPORT,
AppConfig, AppConfig,
OAUTH_PROVIDERS, OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -453,6 +470,103 @@ async def get_app_latest_release_version():
) )
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session"
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
token = await client.authorize_access_token(request)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email
email = user_data.get("email", "").lower()
if not email:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
user = Auths.insert_new_auth(
email=user_data.get("email", "").lower(),
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=user_data.get("picture", "/user.png"),
role=webui_app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json") @app.get("/manifest.json")
async def get_manifest_json(): async def get_manifest_json():
return { return {

View File

@ -259,7 +259,7 @@
<button <button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => { on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/google/login`; window.location.href = `${WEBUI_BASE_URL}/oauth/google/login`;
}} }}
> >
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 48 48" class="size-6 mr-3">
@ -284,7 +284,7 @@
<button <button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => { on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/microsoft/login`; window.location.href = `${WEBUI_BASE_URL}/oauth/microsoft/login`;
}} }}
> >
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 21 21" class="size-6 mr-3">
@ -309,7 +309,7 @@
<button <button
class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition" class="flex items-center px-6 border-2 dark:border-gray-800 duration-300 dark:bg-gray-900 hover:bg-gray-100 dark:hover:bg-gray-800 w-full rounded-2xl dark:text-white text-sm py-3 transition"
on:click={() => { on:click={() => {
window.location.href = `${WEBUI_API_BASE_URL}/auths/oauth/oidc/login`; window.location.href = `${WEBUI_BASE_URL}/oauth/oidc/login`;
}} }}
> >
<svg <svg