diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 400ddac0d..761a11cc9 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -26,6 +26,8 @@ app = FastAPI() origins = ["*"] app.state.ENABLE_SIGNUP = ENABLE_SIGNUP +app.state.JWT_EXPIRES_IN = "-1" + app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 7ccef6300..3db2d0ad2 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, status from pydantic import BaseModel import time import uuid +import re from apps.web.models.auths import ( SigninForm, @@ -25,7 +26,7 @@ from utils.utils import ( get_admin_user, create_token, ) -from utils.misc import get_gravatar_url, validate_email_format +from utils.misc import parse_duration, validate_email_format from constants import ERROR_MESSAGES router = APIRouter() @@ -95,10 +96,13 @@ async def update_password( @router.post("/signin", response_model=SigninResponse) -async def signin(form_data: SigninForm): +async def signin(request: Request, form_data: SigninForm): user = Auths.authenticate_user(form_data.email.lower(), form_data.password) if user: - token = create_token(data={"id": user.id}) + token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + ) return { "token": token, @@ -145,7 +149,10 @@ async def signup(request: Request, form_data: SignupForm): ) if user: - token = create_token(data={"id": user.id}) + token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), + ) # response.set_cookie(key='token', value=token, httponly=True) return { @@ -200,3 +207,33 @@ async def update_default_user_role( if form_data.role in ["pending", "user", "admin"]: request.app.state.DEFAULT_USER_ROLE = form_data.role return request.app.state.DEFAULT_USER_ROLE + + +############################ +# JWT Expiration +############################ + + +@router.get("/token/expires") +async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): + return request.app.state.JWT_EXPIRES_IN + + +class UpdateJWTExpiresDurationForm(BaseModel): + duration: str + + +@router.post("/token/expires/update") +async def update_token_expires_duration( + request: Request, + form_data: UpdateJWTExpiresDurationForm, + user=Depends(get_admin_user), +): + pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$" + + # Check if the input string matches the pattern + if re.match(pattern, form_data.duration): + request.app.state.JWT_EXPIRES_IN = form_data.duration + return request.app.state.JWT_EXPIRES_IN + else: + return request.app.state.JWT_EXPIRES_IN diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5e9d5876e..98528c400 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,6 +1,8 @@ from pathlib import Path import hashlib import re +from datetime import timedelta +from typing import Optional def get_gravatar_url(email): @@ -76,3 +78,34 @@ def extract_folders_after_data_docs(path): tags.append("/".join(folders[: idx + 1])) return tags + + +def parse_duration(duration: str) -> Optional[timedelta]: + if duration == "-1" or duration == "0": + return None + + # Regular expression to find number and unit pairs + pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)" + matches = re.findall(pattern, duration) + + if not matches: + raise ValueError("Invalid duration string") + + total_duration = timedelta() + + for number, _, unit in matches: + number = float(number) + if unit == "ms": + total_duration += timedelta(milliseconds=number) + elif unit == "s": + total_duration += timedelta(seconds=number) + elif unit == "m": + total_duration += timedelta(minutes=number) + elif unit == "h": + total_duration += timedelta(hours=number) + elif unit == "d": + total_duration += timedelta(days=number) + elif unit == "w": + total_duration += timedelta(weeks=number) + + return total_duration diff --git a/src/lib/apis/auths/index.ts b/src/lib/apis/auths/index.ts index 078589984..169998726 100644 --- a/src/lib/apis/auths/index.ts +++ b/src/lib/apis/auths/index.ts @@ -261,3 +261,60 @@ export const toggleSignUpEnabledStatus = async (token: string) => { return res; }; + +export const getJWTExpiresDuration = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/token/expires`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateJWTExpiresDuration = async (token: string, duration: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/auths/token/expires/update`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + duration: duration + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Settings/General.svelte b/src/lib/components/admin/Settings/General.svelte index 48ed41e74..b8f71d30f 100644 --- a/src/lib/components/admin/Settings/General.svelte +++ b/src/lib/components/admin/Settings/General.svelte @@ -1,15 +1,18 @@ @@ -29,6 +37,7 @@ class="flex flex-col h-full justify-between space-y-3 text-sm" on:submit|preventDefault={() => { // console.log('submit'); + updateJWTExpiresDurationHandler(JWTExpiresIn); saveHandler(); }} > @@ -94,6 +103,29 @@ + +