This commit is contained in:
Timothy J. Baek 2024-07-03 23:32:39 -07:00
parent 1b65df3acc
commit 864646094e
11 changed files with 789 additions and 616 deletions

View File

@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -110,14 +110,14 @@ class AuthsTable:
**{"id": id, "email": email, "password": password, "active": True} **{"id": id, "email": email, "password": password, "active": True}
) )
result = Auth(**auth.model_dump()) result = Auth(**auth.model_dump())
Session.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub
) )
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result and user: if result and user:
return user return user
@ -127,7 +127,7 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
try: try:
auth = Session.query(Auth).filter_by(email=email, active=True).first() auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth: if auth:
if verify_password(password, auth.password): if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
@ -154,7 +154,7 @@ class AuthsTable:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
try: try:
auth = Session.query(Auth).filter(email=email, active=True).first() auth = db.query(Auth).filter(email=email, active=True).first()
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
return user return user
@ -163,16 +163,14 @@ class AuthsTable:
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
result = ( result = db.query(Auth).filter_by(id=id).update({"password": new_password})
Session.query(Auth).filter_by(id=id).update({"password": new_password})
)
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
try: try:
result = Session.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({"email": email})
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
@ -183,7 +181,7 @@ class AuthsTable:
result = Users.delete_user_by_id(id) result = Users.delete_user_by_id(id)
if result: if result:
Session.query(Auth).filter_by(id=id).delete() db.query(Auth).filter_by(id=id).delete()
return True return True
else: else:

View File

@ -7,7 +7,7 @@ import time
from sqlalchemy import Column, String, BigInteger, Boolean, Text from sqlalchemy import Column, String, BigInteger, Boolean, Text
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
#################### ####################
@ -79,13 +79,17 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": ( "title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat" form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
), ),
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"created_at": int(time.time()), "created_at": int(time.time()),
@ -94,27 +98,31 @@ class ChatTable:
) )
result = Chat(**chat.model_dump()) result = Chat(**chat.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
return ChatModel.model_validate(result) if result else None return ChatModel.model_validate(result) if result else None
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try: try:
chat_obj = Session.get(Chat, id) with get_db() as db:
chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat) chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat" chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.updated_at = int(time.time()) chat_obj.updated_at = int(time.time())
Session.commit() db.commit()
Session.refresh(chat_obj) db.refresh(chat_obj)
return ChatModel.model_validate(chat_obj) return ChatModel.model_validate(chat_obj)
except Exception as e: except Exception as e:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share # Get the existing chat to share
chat = Session.get(Chat, chat_id) chat = db.get(Chat, chat_id)
# Check if the chat is already shared # Check if the chat is already shared
if chat.share_id: if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared") return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
@ -130,12 +138,12 @@ class ChatTable:
} }
) )
shared_result = Chat(**shared_chat.model_dump()) shared_result = Chat(**shared_chat.model_dump())
Session.add(shared_result) db.add(shared_result)
Session.commit() db.commit()
Session.refresh(shared_result) db.refresh(shared_result)
# Update the original chat with the share_id # Update the original chat with the share_id
result = ( result = (
Session.query(Chat) db.query(Chat)
.filter_by(id=chat_id) .filter_by(id=chat_id)
.update({"share_id": shared_chat.id}) .update({"share_id": shared_chat.id})
) )
@ -144,13 +152,15 @@ class ChatTable:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try: try:
with get_db() as db:
print("update_shared_chat_by_id") print("update_shared_chat_by_id")
chat = Session.get(Chat, chat_id) chat = db.get(Chat, chat_id)
print(chat) print(chat)
chat.title = chat.title chat.title = chat.title
chat.chat = chat.chat chat.chat = chat.chat
Session.commit() db.commit()
Session.refresh(chat) db.refresh(chat)
return self.get_chat_by_id(chat.share_id) return self.get_chat_by_id(chat.share_id)
except: except:
@ -158,7 +168,9 @@ class ChatTable:
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True return True
except: except:
return False return False
@ -167,27 +179,33 @@ class ChatTable:
self, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
chat = Session.get(Chat, id) with get_db() as db:
chat = db.get(Chat, id)
chat.share_id = share_id chat.share_id = share_id
Session.commit() db.commit()
Session.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = Session.get(Chat, id) with get_db() as db:
chat = db.get(Chat, id)
chat.archived = not chat.archived chat.archived = not chat.archived
Session.commit() db.commit()
Session.refresh(chat) db.refresh(chat)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True return True
except: except:
return False return False
@ -195,8 +213,10 @@ class ChatTable:
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_db() as db:
all_chats = ( all_chats = (
Session.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
@ -211,7 +231,8 @@ class ChatTable:
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
query = Session.query(Chat).filter_by(user_id=user_id) with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived: if not include_archived:
query = query.filter_by(archived=False) query = query.filter_by(archived=False)
all_chats = ( all_chats = (
@ -224,8 +245,11 @@ class ChatTable:
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
with get_db() as db:
all_chats = ( all_chats = (
Session.query(Chat) db.query(Chat)
.filter(Chat.id.in_(chat_ids)) .filter(Chat.id.in_(chat_ids))
.filter_by(archived=False) .filter_by(archived=False)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
@ -235,14 +259,18 @@ class ChatTable:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = Session.get(Chat, id) with get_db() as db:
chat = db.get(Chat, id)
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = Session.query(Chat).filter_by(share_id=id).first() with get_db() as db:
chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(id) return self.get_chat_by_id(id)
@ -253,30 +281,38 @@ class ChatTable:
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat) return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
with get_db() as db:
all_chats = ( all_chats = (
Session.query(Chat) db.query(Chat)
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_db() as db:
all_chats = ( all_chats = (
Session.query(Chat) db.query(Chat)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
return [ChatModel.model_validate(chat) for chat in all_chats] return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_db() as db:
all_chats = ( all_chats = (
Session.query(Chat) db.query(Chat)
.filter_by(user_id=user_id, archived=True) .filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
) )
@ -284,7 +320,9 @@ class ChatTable:
def delete_chat_by_id(self, id: str) -> bool: def delete_chat_by_id(self, id: str) -> bool:
try: try:
Session.query(Chat).filter_by(id=id).delete() with get_db() as db:
db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
@ -292,7 +330,9 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
Session.query(Chat).filter_by(id=id, user_id=user_id).delete() with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
@ -300,19 +340,25 @@ class ChatTable:
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_db() as db:
self.delete_shared_chats_by_user_id(user_id) self.delete_shared_chats_by_user_id(user_id)
Session.query(Chat).filter_by(user_id=user_id).delete() db.query(Chat).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
return True return True
except: except:

View File

@ -5,7 +5,7 @@ import logging
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
import json import json
@ -74,6 +74,8 @@ class DocumentsTable:
def insert_new_doc( def insert_new_doc(
self, user_id: str, form_data: DocumentForm self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel( document = DocumentModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -84,9 +86,9 @@ class DocumentsTable:
try: try:
result = Document(**document.model_dump()) result = Document(**document.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return DocumentModel.model_validate(result) return DocumentModel.model_validate(result)
else: else:
@ -96,28 +98,34 @@ class DocumentsTable:
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
document = Session.query(Document).filter_by(name=name).first() with get_db() as db:
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None return DocumentModel.model_validate(document) if document else None
except: except:
return None return None
def get_docs(self) -> List[DocumentModel]: def get_docs(self) -> List[DocumentModel]:
with get_db() as db:
return [ return [
DocumentModel.model_validate(doc) for doc in Session.query(Document).all() DocumentModel.model_validate(doc) for doc in db.query(Document).all()
] ]
def update_doc_by_name( def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
Session.query(Document).filter_by(name=name).update( with get_db() as db:
db.query(Document).filter_by(name=name).update(
{ {
"title": form_data.title, "title": form_data.title,
"name": form_data.name, "name": form_data.name,
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
Session.commit() db.commit()
return self.get_doc_by_name(form_data.name) return self.get_doc_by_name(form_data.name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -131,13 +139,15 @@ class DocumentsTable:
doc_content = json.loads(doc.content if doc.content else "{}") doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated} doc_content = {**doc_content, **updated}
Session.query(Document).filter_by(name=name).update( with get_db() as db:
db.query(Document).filter_by(name=name).update(
{ {
"content": json.dumps(doc_content), "content": json.dumps(doc_content),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
Session.commit() db.commit()
return self.get_doc_by_name(name) return self.get_doc_by_name(name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@ -145,7 +155,9 @@ class DocumentsTable:
def delete_doc_by_name(self, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
Session.query(Document).filter_by(name=name).delete() with get_db() as db:
db.query(Document).filter_by(name=name).delete()
return True return True
except: except:
return False return False

View File

@ -5,7 +5,7 @@ import logging
from sqlalchemy import Column, String, BigInteger, Text from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import JSONField, Base, Session from apps.webui.internal.db import JSONField, Base, get_db
import json import json
@ -61,6 +61,8 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -71,9 +73,9 @@ class FilesTable:
try: try:
result = File(**file.model_dump()) result = File(**file.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return FileModel.model_validate(result) return FileModel.model_validate(result)
else: else:
@ -83,25 +85,35 @@ class FilesTable:
return None return None
def get_file_by_id(self, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db:
try: try:
file = Session.get(File, id) file = db.get(File, id)
return FileModel.model_validate(file) return FileModel.model_validate(file)
except: except:
return None return None
def get_files(self) -> List[FileModel]: def get_files(self) -> List[FileModel]:
return [FileModel.model_validate(file) for file in Session.query(File).all()] with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try: try:
Session.query(File).filter_by(id=id).delete() db.query(File).filter_by(id=id).delete()
return True return True
except: except:
return False return False
def delete_all_files(self) -> bool: def delete_all_files(self) -> bool:
with get_db() as db:
try: try:
Session.query(File).delete() db.query(File).delete()
return True return True
except: except:
return False return False

View File

@ -5,7 +5,7 @@ import logging
from sqlalchemy import Column, String, Text, BigInteger, Boolean from sqlalchemy import Column, String, Text, BigInteger, Boolean
from apps.webui.internal.db import JSONField, Base, Session from apps.webui.internal.db import JSONField, Base, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -91,6 +91,7 @@ class FunctionsTable:
def insert_new_function( def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -102,10 +103,11 @@ class FunctionsTable:
) )
try: try:
with get_db() as db:
result = Function(**function.model_dump()) result = Function(**function.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return FunctionModel.model_validate(result) return FunctionModel.model_validate(result)
else: else:
@ -116,50 +118,60 @@ class FunctionsTable:
def get_function_by_id(self, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
function = Session.get(Function, id) with get_db() as db:
function = db.get(Function, id)
return FunctionModel.model_validate(function) return FunctionModel.model_validate(function)
except: except:
return None return None
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in Session.query(Function).filter_by(is_active=True).all() for function in db.query(Function).filter_by(is_active=True).all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in Session.query(Function).all() for function in db.query(Function).all()
] ]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> List[FunctionModel]:
with get_db() as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in Session.query(Function) for function in db.query(Function)
.filter_by(type=type, is_active=True) .filter_by(type=type, is_active=True)
.all() .all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in Session.query(Function).filter_by(type=type).all() for function in db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions(self) -> List[FunctionModel]: def get_global_filter_functions(self) -> List[FunctionModel]:
with get_db() as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in Session.query(Function) for function in db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True) .filter_by(type="filter", is_active=True, is_global=True)
.all() .all()
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
try: try:
function = Session.get(Function, id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
@ -168,12 +180,14 @@ class FunctionsTable:
def update_function_valves_by_id( def update_function_valves_by_id(
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
with get_db() as db:
try: try:
function = Session.get(Function, id) function = db.get(Function, id)
function.valves = valves function.valves = valves
function.updated_at = int(time.time()) function.updated_at = int(time.time())
Session.commit() db.commit()
Session.refresh(function) db.refresh(function)
return self.get_function_by_id(id) return self.get_function_by_id(id)
except: except:
return None return None
@ -181,6 +195,7 @@ class FunctionsTable:
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump()
@ -199,6 +214,7 @@ class FunctionsTable:
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict self, id: str, user_id: str, valves: dict
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump()
@ -220,34 +236,40 @@ class FunctionsTable:
return None return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
with get_db() as db:
try: try:
Session.query(Function).filter_by(id=id).update( db.query(Function).filter_by(id=id).update(
{ {
**updated, **updated,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
Session.commit() db.commit()
return self.get_function_by_id(id) return self.get_function_by_id(id)
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db:
try: try:
Session.query(Function).update( db.query(Function).update(
{ {
"is_active": False, "is_active": False,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
Session.commit() db.commit()
return True return True
except: except:
return None return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
with get_db() as db:
try: try:
Session.query(Function).filter_by(id=id).delete() db.query(Function).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -3,7 +3,7 @@ from typing import List, Union, Optional
from sqlalchemy import Column, String, BigInteger, Text from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
import time import time
import uuid import uuid
@ -45,6 +45,8 @@ class MemoriesTable:
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
memory = MemoryModel( memory = MemoryModel(
@ -57,9 +59,9 @@ class MemoriesTable:
} }
) )
result = Memory(**memory.model_dump()) result = Memory(**memory.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return MemoryModel.model_validate(result) return MemoryModel.model_validate(result)
else: else:
@ -70,54 +72,68 @@ class MemoriesTable:
id: str, id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
with get_db() as db:
try: try:
Session.query(Memory).filter_by(id=id).update( db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())} {"content": content, "updated_at": int(time.time())}
) )
Session.commit() db.commit()
return self.get_memory_by_id(id) return self.get_memory_by_id(id)
except: except:
return None return None
def get_memories(self) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
with get_db() as db:
try: try:
memories = Session.query(Memory).all() memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
with get_db() as db:
try: try:
memories = Session.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db:
try: try:
memory = Session.get(Memory, id) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
except: except:
return None return None
def delete_memory_by_id(self, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db:
try: try:
Session.query(Memory).filter_by(id=id).delete() db.query(Memory).filter_by(id=id).delete()
return True return True
except: except:
return False return False
def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
try: try:
Session.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
try: try:
Session.query(Memory).filter_by(id=id, user_id=user_id).delete() db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True
except: except:
return False return False

View File

@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, Session from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -126,10 +126,13 @@ class ModelsTable:
} }
) )
try: try:
with get_db() as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return ModelModel.model_validate(result) return ModelModel.model_validate(result)
@ -140,24 +143,28 @@ class ModelsTable:
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ with get_db() as db:
ModelModel.model_validate(model) for model in Session.query(Model).all()
] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
model = Session.get(Model, id) with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
with get_db() as db:
# update only the fields that are present in the model # update only the fields that are present in the model
model = Session.query(Model).get(id) model = db.query(Model).get(id)
model.update(**model.model_dump()) model.update(**model.model_dump())
Session.commit() db.commit()
Session.refresh(model) db.refresh(model)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
@ -166,7 +173,9 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
Session.query(Model).filter_by(id=id).delete() with get_db() as db:
db.query(Model).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -4,7 +4,7 @@ import time
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
import json import json
@ -60,10 +60,12 @@ class PromptsTable:
) )
try: try:
with get_db() as db:
result = Prompt(**prompt.dict()) result = Prompt(**prompt.dict())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return PromptModel.model_validate(result) return PromptModel.model_validate(result)
else: else:
@ -73,32 +75,40 @@ class PromptsTable:
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
prompt = Session.query(Prompt).filter_by(command=command).first() with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except: except:
return None return None
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
with get_db() as db:
return [ return [
PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
prompt = Session.query(Prompt).filter_by(command=command).first() with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title prompt.title = form_data.title
prompt.content = form_data.content prompt.content = form_data.content
prompt.timestamp = int(time.time()) prompt.timestamp = int(time.time())
Session.commit() db.commit()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except: except:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
Session.query(Prompt).filter_by(command=command).delete() with get_db() as db:
db.query(Prompt).filter_by(command=command).delete()
return True return True
except: except:
return False return False

View File

@ -8,7 +8,7 @@ import logging
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, Session from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
@ -79,13 +79,15 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
with get_db() as db:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag(**tag.model_dump()) result = Tag(**tag.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return TagModel.model_validate(result) return TagModel.model_validate(result)
else: else:
@ -97,7 +99,8 @@ class TagTable:
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
tag = Session.query(Tag).filter(name=name, user_id=user_id).first() with get_db() as db:
tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
@ -120,10 +123,11 @@ class TagTable:
} }
) )
try: try:
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump()) result = ChatIdTag(**chatIdTag.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return ChatIdTagModel.model_validate(result) return ChatIdTagModel.model_validate(result)
else: else:
@ -132,10 +136,11 @@ class TagTable:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
with get_db() as db:
tag_names = [ tag_names = [
chat_id_tag.tag_name chat_id_tag.tag_name
for chat_id_tag in ( for chat_id_tag in (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all() .all()
@ -145,7 +150,7 @@ class TagTable:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
Session.query(Tag) db.query(Tag)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names)) .filter(Tag.name.in_(tag_names))
.all() .all()
@ -155,10 +160,12 @@ class TagTable:
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
with get_db() as db:
tag_names = [ tag_names = [
chat_id_tag.tag_name chat_id_tag.tag_name
for chat_id_tag in ( for chat_id_tag in (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id) .filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all() .all()
@ -168,7 +175,7 @@ class TagTable:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in ( for tag in (
Session.query(Tag) db.query(Tag)
.filter_by(user_id=user_id) .filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names)) .filter(Tag.name.in_(tag_names))
.all() .all()
@ -178,10 +185,12 @@ class TagTable:
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> List[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
with get_db() as db:
return [ return [
ChatIdTagModel.model_validate(chat_id_tag) ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in ( for chat_id_tag in (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name) .filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc()) .order_by(ChatIdTag.timestamp.desc())
.all() .all()
@ -191,26 +200,31 @@ class TagTable:
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
with get_db() as db:
return ( return (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id) .filter_by(tag_name=tag_name, user_id=user_id)
.count() .count()
) )
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
with get_db() as db:
res = ( res = (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id) .filter_by(tag_name=tag_name, user_id=user_id)
.delete() .delete()
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
Session.commit() db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
@ -220,18 +234,22 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
with get_db() as db:
res = ( res = (
Session.query(ChatIdTag) db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete() .delete()
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
Session.commit() db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0: if tag_count == 0:
# Remove tag item from Tag col as well # Remove tag item from Tag col as well
Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:

View File

@ -4,7 +4,7 @@ import time
import logging import logging
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, Session from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
@ -83,6 +83,9 @@ class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
with get_db() as db:
tool = ToolModel( tool = ToolModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -95,9 +98,9 @@ class ToolsTable:
try: try:
result = Tool(**tool.model_dump()) result = Tool(**tool.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return ToolModel.model_validate(result) return ToolModel.model_validate(result)
else: else:
@ -108,17 +111,21 @@ class ToolsTable:
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
tool = Session.get(Tool, id) with get_db() as db:
tool = db.get(Tool, id)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Session.get(Tool, id) with get_db() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
@ -126,10 +133,12 @@ class ToolsTable:
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
Session.query(Tool).filter_by(id=id).update( with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())} {"valves": valves, "updated_at": int(time.time())}
) )
Session.commit() db.commit()
return self.get_tool_by_id(id) return self.get_tool_by_id(id)
except: except:
return None return None
@ -177,18 +186,20 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
tool = Session.get(Tool, id) with get_db() as db:
tool = db.get(Tool, id)
tool.update(**updated) tool.update(**updated)
tool.updated_at = int(time.time()) tool.updated_at = int(time.time())
Session.commit() db.commit()
Session.refresh(tool) db.refresh(tool)
return ToolModel.model_validate(tool) return ToolModel.model_validate(tool)
except: except:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
Session.query(Tool).filter_by(id=id).delete() with get_db() as db:
db.query(Tool).filter_by(id=id).delete()
return True return True
except: except:
return False return False

View File

@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField, Session from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
@ -88,6 +88,7 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, "id": id,
@ -102,9 +103,9 @@ class UsersTable:
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
Session.add(result) db.add(result)
Session.commit() db.commit()
Session.refresh(result) db.refresh(result)
if result: if result:
return user return user
else: else:
@ -112,56 +113,66 @@ class UsersTable:
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
try: try:
user = Session.query(User).filter_by(id=id).first() with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception as e: except Exception as e:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
user = Session.query(User).filter_by(api_key=api_key).first() with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = Session.query(User).filter_by(email=email).first() with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
user = Session.query(User).filter_by(oauth_sub=sub).first() with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
with get_db() as db:
users = ( users = (
Session.query(User) db.query(User)
# .offset(skip).limit(limit) # .offset(skip).limit(limit)
.all() .all()
) )
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return Session.query(User).count() with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
user = Session.query(User).order_by(User.created_at).first() with get_db() as db:
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
Session.query(User).filter_by(id=id).update({"role": role}) with get_db() as db:
Session.commit() db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = Session.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
@ -170,24 +181,27 @@ class UsersTable:
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
Session.query(User).filter_by(id=id).update( with get_db() as db:
db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url} {"profile_image_url": profile_image_url}
) )
Session.commit() db.commit()
user = Session.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
Session.query(User).filter_by(id=id).update( with get_db() as db:
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())} {"last_active_at": int(time.time())}
) )
Session.commit() db.commit()
user = Session.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
@ -196,19 +210,21 @@ class UsersTable:
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
user = Session.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
Session.query(User).filter_by(id=id).update(updated) with get_db() as db:
Session.commit() db.query(User).filter_by(id=id).update(updated)
db.commit()
user = Session.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
# return UserModel(**user.dict()) # return UserModel(**user.dict())
except Exception as e: except Exception as e:
@ -220,9 +236,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id) result = Chats.delete_chats_by_user_id(id)
if result: if result:
with get_db() as db:
# Delete User # Delete User
Session.query(User).filter_by(id=id).delete() db.query(User).filter_by(id=id).delete()
Session.commit() db.commit()
return True return True
else: else:
@ -232,15 +249,17 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try: try:
result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) with get_db() as db:
Session.commit() result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
user = Session.query(User).filter_by(id=id).first() with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return user.api_key return user.api_key
except Exception as e: except Exception as e:
return None return None