feat(sqlalchemy): remove session reference from router

This commit is contained in:
Jonathan Rohde
2024-06-21 14:58:57 +02:00
parent df09d0830a
commit bee835cb65
34 changed files with 1231 additions and 1211 deletions

View File

@@ -9,7 +9,6 @@ import time
import uuid
import logging
from apps.webui.internal.db import get_db
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
@@ -42,9 +41,9 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
):
return Users.get_users(db, skip, limit)
return Users.get_users(skip, limit)
############################
@@ -72,11 +71,11 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
):
if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
return Users.update_user_role_by_id(db, form_data.id, form_data.role)
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -91,9 +90,9 @@ async def update_user_role(
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
return user.settings
else:
@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
return user.info
else:
@@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(
db, user.id, {"info": {**user.info, **form_data}}
user.id, {"info": {**user.info, **form_data}}
)
if user:
return user.info
@@ -184,14 +183,14 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
user_id: str, user=Depends(get_verified_user)
):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(db, chat_id)
chat = Chats.get_chat_by_id(chat_id)
if chat:
user_id = chat.user_id
else:
@@ -200,7 +199,7 @@ async def get_user_by_id(
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
user = Users.get_user_by_id(db, user_id)
user = Users.get_user_by_id(user_id)
if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
@@ -221,13 +220,12 @@ async def update_user_by_id(
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
db=Depends(get_db),
):
user = Users.get_user_by_id(db, user_id)
user = Users.get_user_by_id(user_id)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(db, form_data.email.lower())
email_user = Users.get_user_by_email(form_data.email.lower())
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -237,11 +235,10 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(db, user_id, hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(db, user_id, form_data.email.lower())
Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
db,
user_id,
{
"name": form_data.name,
@@ -271,10 +268,10 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
user_id: str, user=Depends(get_admin_user)
):
if user.id != user_id:
result = Auths.delete_auth_by_id(db, user_id)
result = Auths.delete_auth_by_id(user_id)
if result:
return True