mirror of
https://github.com/open-webui/open-webui
synced 2024-11-29 23:41:50 +00:00
240 lines
6.6 KiB
Python
240 lines
6.6 KiB
Python
from fastapi import Response, Request
|
|
from fastapi import Depends, FastAPI, HTTPException, status
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Union
|
|
|
|
from fastapi import APIRouter, status
|
|
from pydantic import BaseModel
|
|
import time
|
|
import uuid
|
|
import re
|
|
|
|
from apps.web.models.auths import (
|
|
SigninForm,
|
|
SignupForm,
|
|
UpdateProfileForm,
|
|
UpdatePasswordForm,
|
|
UserResponse,
|
|
SigninResponse,
|
|
Auths,
|
|
)
|
|
from apps.web.models.users import Users
|
|
|
|
from utils.utils import (
|
|
get_password_hash,
|
|
get_current_user,
|
|
get_admin_user,
|
|
create_token,
|
|
)
|
|
from utils.misc import parse_duration, validate_email_format
|
|
from constants import ERROR_MESSAGES
|
|
|
|
router = APIRouter()
|
|
|
|
############################
|
|
# GetSessionUser
|
|
############################
|
|
|
|
|
|
@router.get("/", response_model=UserResponse)
|
|
async def get_session_user(user=Depends(get_current_user)):
|
|
return {
|
|
"id": user.id,
|
|
"email": user.email,
|
|
"name": user.name,
|
|
"role": user.role,
|
|
"profile_image_url": user.profile_image_url,
|
|
}
|
|
|
|
|
|
############################
|
|
# Update Profile
|
|
############################
|
|
|
|
|
|
@router.post("/update/profile", response_model=UserResponse)
|
|
async def update_profile(
|
|
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
|
|
):
|
|
if session_user:
|
|
user = Users.update_user_by_id(
|
|
session_user.id,
|
|
{"profile_image_url": form_data.profile_image_url, "name": form_data.name},
|
|
)
|
|
if user:
|
|
return user
|
|
else:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
|
|
else:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
############################
|
|
# Update Password
|
|
############################
|
|
|
|
|
|
@router.post("/update/password", response_model=bool)
|
|
async def update_password(
|
|
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
|
):
|
|
if session_user:
|
|
user = Auths.authenticate_user(session_user.email, form_data.password)
|
|
|
|
if user:
|
|
hashed = get_password_hash(form_data.new_password)
|
|
return Auths.update_user_password_by_id(user.id, hashed)
|
|
else:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
|
|
else:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
############################
|
|
# SignIn
|
|
############################
|
|
|
|
|
|
@router.post("/signin", response_model=SigninResponse)
|
|
async def signin(request: Request, form_data: SigninForm):
|
|
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
|
|
if user:
|
|
token = create_token(
|
|
data={"id": user.id},
|
|
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
|
|
)
|
|
|
|
return {
|
|
"token": token,
|
|
"token_type": "Bearer",
|
|
"id": user.id,
|
|
"email": user.email,
|
|
"name": user.name,
|
|
"role": user.role,
|
|
"profile_image_url": user.profile_image_url,
|
|
}
|
|
else:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
|
|
|
|
|
############################
|
|
# SignUp
|
|
############################
|
|
|
|
|
|
@router.post("/signup", response_model=SigninResponse)
|
|
async def signup(request: Request, form_data: SignupForm):
|
|
if not request.app.state.ENABLE_SIGNUP:
|
|
raise HTTPException(
|
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
|
)
|
|
|
|
if not validate_email_format(form_data.email.lower()):
|
|
raise HTTPException(
|
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
|
)
|
|
|
|
if Users.get_user_by_email(form_data.email.lower()):
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
|
try:
|
|
role = (
|
|
"admin"
|
|
if Users.get_num_users() == 0
|
|
else request.app.state.DEFAULT_USER_ROLE
|
|
)
|
|
hashed = get_password_hash(form_data.password)
|
|
user = Auths.insert_new_auth(
|
|
form_data.email.lower(), hashed, form_data.name, role
|
|
)
|
|
|
|
if user:
|
|
token = create_token(
|
|
data={"id": user.id},
|
|
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
|
|
)
|
|
# response.set_cookie(key='token', value=token, httponly=True)
|
|
|
|
return {
|
|
"token": token,
|
|
"token_type": "Bearer",
|
|
"id": user.id,
|
|
"email": user.email,
|
|
"name": user.name,
|
|
"role": user.role,
|
|
"profile_image_url": user.profile_image_url,
|
|
}
|
|
else:
|
|
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
|
|
except Exception as err:
|
|
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
|
|
|
|
|
|
############################
|
|
# ToggleSignUp
|
|
############################
|
|
|
|
|
|
@router.get("/signup/enabled", response_model=bool)
|
|
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
|
|
return request.app.state.ENABLE_SIGNUP
|
|
|
|
|
|
@router.get("/signup/enabled/toggle", response_model=bool)
|
|
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
|
|
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
|
|
return request.app.state.ENABLE_SIGNUP
|
|
|
|
|
|
############################
|
|
# Default User Role
|
|
############################
|
|
|
|
|
|
@router.get("/signup/user/role")
|
|
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
|
|
return request.app.state.DEFAULT_USER_ROLE
|
|
|
|
|
|
class UpdateRoleForm(BaseModel):
|
|
role: str
|
|
|
|
|
|
@router.post("/signup/user/role")
|
|
async def update_default_user_role(
|
|
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
|
|
):
|
|
if form_data.role in ["pending", "user", "admin"]:
|
|
request.app.state.DEFAULT_USER_ROLE = form_data.role
|
|
return request.app.state.DEFAULT_USER_ROLE
|
|
|
|
|
|
############################
|
|
# JWT Expiration
|
|
############################
|
|
|
|
|
|
@router.get("/token/expires")
|
|
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
|
|
return request.app.state.JWT_EXPIRES_IN
|
|
|
|
|
|
class UpdateJWTExpiresDurationForm(BaseModel):
|
|
duration: str
|
|
|
|
|
|
@router.post("/token/expires/update")
|
|
async def update_token_expires_duration(
|
|
request: Request,
|
|
form_data: UpdateJWTExpiresDurationForm,
|
|
user=Depends(get_admin_user),
|
|
):
|
|
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
|
|
|
|
# Check if the input string matches the pattern
|
|
if re.match(pattern, form_data.duration):
|
|
request.app.state.JWT_EXPIRES_IN = form_data.duration
|
|
return request.app.state.JWT_EXPIRES_IN
|
|
else:
|
|
return request.app.state.JWT_EXPIRES_IN
|