From af4f8aa589333b7330c31378b821148e6f6b6655 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Wed, 5 Jun 2024 19:21:42 +0100 Subject: [PATCH] feat: add WEBUI_SESSION_COOKIE_SAME_SITE for when open webui is embedded --- backend/config.py | 5 +++++ backend/main.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/backend/config.py b/backend/config.py index 2ae219e39..bd343c06f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -729,6 +729,11 @@ WEBUI_SECRET_KEY = os.environ.get( ), # DEPRECATED: remove at next major version ) +WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get( + "WEBUI_SESSION_COOKIE_SAME_SITE", + os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"), +) + if WEBUI_AUTH and WEBUI_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) diff --git a/backend/main.py b/backend/main.py index ac6b3a010..b0de5a099 100644 --- a/backend/main.py +++ b/backend/main.py @@ -73,6 +73,7 @@ from config import ( ENABLE_OAUTH_SIGNUP, OAUTH_MERGE_ACCOUNTS_BY_EMAIL, WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, ) from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from utils.webhook import post_webhook @@ -507,7 +508,10 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items(): # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: app.add_middleware( - SessionMiddleware, secret_key=WEBUI_SECRET_KEY, session_cookie="oui-session" + SessionMiddleware, + secret_key=WEBUI_SECRET_KEY, + session_cookie="oui-session", + same_site=WEBUI_SESSION_COOKIE_SAME_SITE, ) @@ -524,7 +528,11 @@ async def oauth_callback(provider: str, request: Request): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) client = oauth.create_client(provider) - token = await client.authorize_access_token(request) + try: + token = await client.authorize_access_token(request) + except Exception as e: + log.error(f"OAuth callback error: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) user_data: UserInfo = token["userinfo"] sub = user_data.get("sub")