diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 560d9a686..48c8b543a 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -110,14 +110,14 @@ class AuthsTable: **{"id": id, "email": email, "password": password, "active": True} ) result = Auth(**auth.model_dump()) - Session.add(result) + db.add(result) user = Users.insert_new_user( id, name, email, profile_image_url, role, oauth_sub ) - Session.commit() - Session.refresh(result) + db.commit() + db.refresh(result) if result and user: return user @@ -127,7 +127,7 @@ class AuthsTable: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = Session.query(Auth).filter_by(email=email, active=True).first() + auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: if verify_password(password, auth.password): user = Users.get_user_by_id(auth.id) @@ -154,7 +154,7 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = Session.query(Auth).filter(email=email, active=True).first() + auth = db.query(Auth).filter(email=email, active=True).first() if auth: user = Users.get_user_by_id(auth.id) return user @@ -163,16 +163,14 @@ class AuthsTable: def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = ( - Session.query(Auth).filter_by(id=id).update({"password": new_password}) - ) + result = db.query(Auth).filter_by(id=id).update({"password": new_password}) return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = Session.query(Auth).filter_by(id=id).update({"email": email}) + result = db.query(Auth).filter_by(id=id).update({"email": email}) return True if result == 1 else False except: return False @@ -183,7 +181,7 @@ class AuthsTable: result = Users.delete_user_by_id(id) if result: - Session.query(Auth).filter_by(id=id).delete() + db.query(Auth).filter_by(id=id).delete() return True else: diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d6829ee7b..8d2e6b104 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import time from sqlalchemy import Column, String, BigInteger, Boolean, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db #################### @@ -79,87 +79,99 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + with get_db() as db: - result = Chat(**chat.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - return ChatModel.model_validate(result) if result else None + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - chat_obj = Session.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - Session.commit() - Session.refresh(chat_obj) + with get_db() as db: - return ChatModel.model_validate(chat_obj) + chat_obj = db.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + db.commit() + db.refresh(chat_obj) + + return ChatModel.model_validate(chat_obj) except Exception as e: return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - # Get the existing chat to share - chat = Session.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - Session.add(shared_result) - Session.commit() - Session.refresh(shared_result) - # Update the original chat with the share_id - result = ( - Session.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + with get_db() as db: - return shared_chat if (shared_result and result) else None + # Get the existing chat to share + chat = db.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) + # Update the original chat with the share_id + result = ( + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) + + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: - print("update_shared_chat_by_id") - chat = Session.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - Session.commit() - Session.refresh(chat) + with get_db() as db: - return self.get_chat_by_id(chat.share_id) + print("update_shared_chat_by_id") + chat = db.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + db.commit() + db.refresh(chat) + + return self.get_chat_by_id(chat.share_id) except: return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + return True except: return False @@ -167,42 +179,50 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.share_id = share_id - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.share_id = share_id + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.archived = not chat.archived - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.archived = not chat.archived + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + return True except: return False def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, @@ -211,110 +231,136 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - query = Session.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(share_id=id).first() + with get_db() as db: - if chat: - return self.get_chat_by_id(id) - else: - return None + chat = db.query(Chat).filter_by(share_id=id).first() + + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - Session.query(Chat).filter_by(id=id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - Session.query(Chat).filter_by(id=id, user_id=user_id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id, user_id=user_id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - self.delete_shared_chats_by_user_id(user_id) - Session.query(Chat).filter_by(user_id=user_id).delete() - return True + with get_db() as db: + + self.delete_shared_chats_by_user_id(user_id) + + db.query(Chat).filter_by(user_id=user_id).delete() + return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + with get_db() as db: - return True + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + + return True except: return False diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 1b69d44a5..16145c4ac 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -74,51 +74,59 @@ class DocumentsTable: def insert_new_doc( self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) + with get_db() as db: - try: - result = Document(**document.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: + document = DocumentModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "timestamp": int(time.time()), + } + ) + + try: + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None + except: return None - except: - return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - document = Session.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + with get_db() as db: + + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - return [ - DocumentModel.model_validate(doc) for doc in Session.query(Document).all() - ] + with get_db() as db: + + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - Session.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(form_data.name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -131,22 +139,26 @@ class DocumentsTable: doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated} - Session.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - Session.query(Document).filter_by(name=name).delete() - return True + with get_db() as db: + + db.query(Document).filter_by(name=name).delete() + return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index ce904215d..58058e907 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db import json @@ -61,50 +61,62 @@ class FileForm(BaseModel): class FilesTable: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - file = FileModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - } - ) + with get_db() as db: - try: - result = File(**file.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FileModel.model_validate(result) - else: + file = FileModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + } + ) + + try: + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_file_by_id(self, id: str) -> Optional[FileModel]: - try: - file = Session.get(File, id) - return FileModel.model_validate(file) - except: - return None + with get_db() as db: + + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except: + return None def get_files(self) -> List[FileModel]: - return [FileModel.model_validate(file) for file in Session.query(File).all()] + with get_db() as db: + + return [FileModel.model_validate(file) for file in db.query(File).all()] def delete_file_by_id(self, id: str) -> bool: - try: - Session.query(File).filter_by(id=id).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).filter_by(id=id).delete() + return True + except: + return False def delete_all_files(self) -> bool: - try: - Session.query(File).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).delete() + return True + except: + return False Files = FilesTable() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 64ed4f3cc..5718833d3 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db from apps.webui.models.users import Users import json @@ -91,6 +91,7 @@ class FunctionsTable: def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: + function = FunctionModel( **{ **form_data.model_dump(), @@ -102,85 +103,99 @@ class FunctionsTable: ) try: - result = Function(**function.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + with get_db() as db: + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - function = Session.get(Function, id) - return FunctionModel.model_validate(function) + with get_db() as db: + + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(is_active=True).all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(is_active=True).all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(type=type).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(type=type).all() + ] def get_global_filter_functions(self) -> List[FunctionModel]: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() - ] + with get_db() as db: + + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type="filter", is_active=True, is_global=True) + .all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: - try: - function = Session.get(Function, id) - return function.valves if function.valves else {} - except Exception as e: - print(f"An error occurred: {e}") - return None + with get_db() as db: + + try: + function = db.get(Function, id) + return function.valves if function.valves else {} + except Exception as e: + print(f"An error occurred: {e}") + return None def update_function_valves_by_id( self, id: str, valves: dict ) -> Optional[FunctionValves]: - try: - function = Session.get(Function, id) - function.valves = valves - function.updated_at = int(time.time()) - Session.commit() - Session.refresh(function) - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + function = db.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + db.commit() + db.refresh(function) + return self.get_function_by_id(id) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -199,6 +214,7 @@ class FunctionsTable: def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -220,37 +236,43 @@ class FunctionsTable: return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - try: - Session.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - Session.commit() - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_function_by_id(id) + except: + return None def deactivate_all_functions(self) -> Optional[bool]: - try: - Session.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - Session.commit() - return True - except: - return None + with get_db() as db: + + try: + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + db.commit() + return True + except: + return None def delete_function_by_id(self, id: str) -> bool: - try: - Session.query(Function).filter_by(id=id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).delete() + return True + except: + return False Functions = FunctionsTable() diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 1f03318fd..662bbedfe 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,7 +3,7 @@ from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import time import uuid @@ -45,82 +45,98 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - id = str(uuid.uuid4()) - memory = MemoryModel( - **{ - "id": id, - "user_id": user_id, - "content": content, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) - result = Memory(**memory.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + with get_db() as db: + id = str(uuid.uuid4()) + + memory = MemoryModel( + **{ + "id": id, + "user_id": user_id, + "content": content, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + result = Memory(**memory.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, id: str, content: str, ) -> Optional[MemoryModel]: - try: - Session.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_memory_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + db.commit() + return self.get_memory_by_id(id) + except: + return None def get_memories(self) -> List[MemoryModel]: - try: - memories = Session.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: - try: - memories = Session.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - try: - memory = Session.get(Memory, id) - return MemoryModel.model_validate(memory) - except: - return None + with get_db() as db: + + try: + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) + except: + return None def delete_memory_by_id(self, id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id).delete() - return True + with get_db() as db: - except: - return False + try: + db.query(Memory).filter_by(id=id).delete() + return True + + except: + return False def delete_memories_by_user_id(self, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(user_id=user_id).delete() + return True + except: + return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id, user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id, user_id=user_id).delete() + return True + except: + return False Memories = MemoriesTable() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 6543edefc..c95c36c7d 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -126,39 +126,46 @@ class ModelsTable: } ) try: - result = Model(**model.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Model(**model.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - return [ - ModelModel.model_validate(model) for model in Session.query(Model).all() - ] + with get_db() as db: + + return [ModelModel.model_validate(model) for model in db.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - model = Session.get(Model, id) - return ModelModel.model_validate(model) + with get_db() as db: + + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: - # update only the fields that are present in the model - model = Session.query(Model).get(id) - model.update(**model.model_dump()) - Session.commit() - Session.refresh(model) - return ModelModel.model_validate(model) + with get_db() as db: + + # update only the fields that are present in the model + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -166,8 +173,10 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - Session.query(Model).filter_by(id=id).delete() - return True + with get_db() as db: + + db.query(Model).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index ab8cc04ce..2af2ce22c 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -4,7 +4,7 @@ import time from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -60,46 +60,56 @@ class PromptsTable: ) try: - result = Prompt(**prompt.dict()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return PromptModel.model_validate(result) + else: + return None except Exception as e: return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) except: return None def get_prompts(self) -> List[PromptModel]: - return [ - PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() - ] + with get_db() as db: + + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - Session.commit() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + db.commit() + return PromptModel.model_validate(prompt) except: return None def delete_prompt_by_command(self, command: str) -> bool: try: - Session.query(Prompt).filter_by(command=command).delete() - return True + with get_db() as db: + + db.query(Prompt).filter_by(command=command).delete() + return True except: return False diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7b0df6b6b..bbbc95ed2 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -79,26 +79,29 @@ class ChatTagsResponse(BaseModel): class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - id = str(uuid.uuid4()) - tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) - try: - result = Tag(**tag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return TagModel.model_validate(result) - else: + with get_db() as db: + + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + try: + result = Tag(**tag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None + except Exception as e: return None - except Exception as e: - return None def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = Session.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + with get_db() as db: + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -120,98 +123,109 @@ class TagTable: } ) try: - result = ChatIdTag(**chatIdTag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + with get_db() as db: + result = ChatIdTag(**chatIdTag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] + + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - return ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + with get_db() as db: + + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + return True except Exception as e: log.error(f"delete_tag: {e}") return False @@ -220,20 +234,24 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + + return True except Exception as e: log.error(f"delete_tag: {e}") return False diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index f5df10637..dc0fe01c5 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -4,7 +4,7 @@ import time import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users import json @@ -83,54 +83,63 @@ class ToolsTable: def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ToolModel.model_validate(result) - else: + with get_db() as db: + + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool(**tool.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - return ToolModel.model_validate(tool) + with get_db() as db: + + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - tool = Session.get(Tool, id) - return tool.valves if tool.valves else {} + with get_db() as db: + + tool = db.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - Session.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_tool_by_id(id) + with get_db() as db: + + db.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None @@ -177,19 +186,21 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - tool.update(**updated) - tool.updated_at = int(time.time()) - Session.commit() - Session.refresh(tool) - return ToolModel.model_validate(tool) + with get_db() as db: + tool = db.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + db.commit() + db.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - Session.query(Tool).filter_by(id=id).delete() - return True + with get_db() as db: + db.query(Tool).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 9e1e25ac6..8e3b57bba 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, Session, get_db from apps.webui.models.chats import Chats #################### @@ -88,81 +88,92 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return user - else: - return None + with get_db() as db: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except Exception as e: return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) except: return None def get_user_by_email(self, email: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) except: return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) except: return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - users = ( - Session.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + with get_db() as db: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - return Session.query(User).count() + with get_db() as db: + return db.query(User).count() def get_first_user(self) -> UserModel: try: - user = Session.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) except: return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"role": role}) - Session.commit() - - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + db.query(User).filter_by(id=id).update({"role": role}) + db.commit() + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -170,25 +181,28 @@ class UsersTable: self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) - Session.commit() + with get_db() as db: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -196,21 +210,23 @@ class UsersTable: self, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + with get_db() as db: + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update(updated) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update(updated) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) except Exception as e: return None @@ -220,9 +236,10 @@ class UsersTable: result = Chats.delete_chats_by_user_id(id) if result: - # Delete User - Session.query(User).filter_by(id=id).delete() - Session.commit() + with get_db() as db: + # Delete User + db.query(User).filter_by(id=id).delete() + db.commit() return True else: @@ -232,16 +249,18 @@ class UsersTable: def update_user_api_key_by_id(self, id: str, api_key: str) -> str: try: - result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) - Session.commit() - return True if result == 1 else False + with get_db() as db: + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() + return True if result == 1 else False except: return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: - user = Session.query(User).filter_by(id=id).first() - return user.api_key + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return user.api_key except Exception as e: return None