From da403f3e3cf9ce700da2fdb477e0bdfc4794d37d Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:06:15 +0200 Subject: [PATCH] feat(sqlalchemy): use session factory instead of context manager --- backend/apps/webui/internal/db.py | 12 +- backend/apps/webui/models/auths.py | 131 ++++---- backend/apps/webui/models/chats.py | 292 ++++++++---------- backend/apps/webui/models/documents.py | 77 +++-- backend/apps/webui/models/files.py | 36 +-- backend/apps/webui/models/functions.py | 117 ++++--- backend/apps/webui/models/memories.py | 60 ++-- backend/apps/webui/models/models.py | 42 ++- backend/apps/webui/models/prompts.py | 91 +++--- backend/apps/webui/models/tags.py | 200 ++++++------ backend/apps/webui/models/tools.py | 59 ++-- backend/apps/webui/models/users.py | 245 +++++++-------- backend/main.py | 14 +- backend/test/apps/webui/routers/test_chats.py | 2 + .../test/util/abstract_integration_test.py | 21 +- 15 files changed, 640 insertions(+), 759 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index b9bfc8aff..320ab3e07 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -57,14 +57,4 @@ SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) Base = declarative_base() - - -@contextmanager -def get_session(): - session = scoped_session(SessionLocal) - try: - yield session - session.commit() - except Exception as e: - session.rollback() - raise e +Session = scoped_session(SessionLocal) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 9f10e0fdd..1858b2c0d 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -3,12 +3,11 @@ from typing import Optional import uuid import logging from sqlalchemy import String, Column, Boolean -from sqlalchemy.orm import Session from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session from config import SRC_LOG_LEVELS @@ -103,101 +102,93 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_session() as db: - log.info("insert_new_auth") + log.info("insert_new_auth") - id = str(uuid.uuid4()) + id = str(uuid.uuid4()) - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - db.add(result) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + Session.add(result) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub - ) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub) - db.commit() - db.refresh(result) + Session.commit() + Session.refresh(result) - if result and user: - return user - else: - return None + if result and user: + return user + else: + return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - with get_session() as db: - try: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) - return user - else: - return None + try: + auth = Session.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user else: return None - except: + else: return None + except: + return None def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") - with get_session() as db: - # if no api_key, return None - if not api_key: - return None + # if no api_key, return None + if not api_key: + return None - try: - user = Users.get_user_by_api_key(api_key) - return user if user else None - except: - return False + try: + user = Users.get_user_by_api_key(api_key) + return user if user else None + except: + return False def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") - with get_session() as db: - try: - auth = db.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user - except: - return None + try: + auth = Session.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user + except: + return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: - with get_session() as db: - try: - result = ( - db.query(Auth).filter_by(id=id).update({"password": new_password}) - ) - return True if result == 1 else False - except: - return False + try: + result = ( + Session.query(Auth).filter_by(id=id).update({"password": new_password}) + ) + return True if result == 1 else False + except: + return False def update_email_by_id(self, id: str, email: str) -> bool: - with get_session() as db: - try: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False - except: - return False + try: + result = Session.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False + except: + return False def delete_auth_by_id(self, id: str) -> bool: - with get_session() as db: - try: - # Delete User - result = Users.delete_user_by_id(id) + try: + # Delete User + result = Users.delete_user_by_id(id) - if result: - db.query(Auth).filter_by(id=id).delete() + if result: + Session.query(Auth).filter_by(id=id).delete() - return True - else: - return False - except: + return True + else: return False + except: + return False Auths = AuthsTable() diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index b0c983ade..abf5b544c 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -6,9 +6,8 @@ import uuid import time from sqlalchemy import Column, String, BigInteger, Boolean -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session #################### @@ -80,93 +79,88 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - with get_session() as db: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) - result = Chat(**chat.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - return ChatModel.model_validate(result) if result else None + result = Chat(**chat.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: - with get_session() as db: - try: - chat_obj = db.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - db.commit() - db.refresh(chat_obj) + try: + chat_obj = Session.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + Session.commit() + Session.refresh(chat_obj) - return ChatModel.model_validate(chat_obj) - except Exception as e: - return None + return ChatModel.model_validate(chat_obj) + except Exception as e: + return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_session() as db: - # Get the existing chat to share - chat = db.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - db.add(shared_result) - db.commit() - db.refresh(shared_result) - # Update the original chat with the share_id - result = ( - db.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + # Get the existing chat to share + chat = Session.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + Session.add(shared_result) + Session.commit() + Session.refresh(shared_result) + # Update the original chat with the share_id + result = ( + Session.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) - return shared_chat if (shared_result and result) else None + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_session() as db: - try: - print("update_shared_chat_by_id") - chat = db.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - db.commit() - db.refresh(chat) + try: + print("update_shared_chat_by_id") + chat = Session.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + Session.commit() + Session.refresh(chat) - return self.get_chat_by_id(chat.share_id) - except: - return None + return self.get_chat_by_id(chat.share_id) + except: + return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() return True except: return False @@ -175,30 +169,27 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.get(Chat, id) - chat.share_id = share_id - db.commit() - db.refresh(chat) - return chat + chat = Session.get(Chat, id) + chat.share_id = share_id + Session.commit() + Session.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = self.get_chat_by_id(id) - db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) - - return self.get_chat_by_id(id) + chat = Session.get(Chat, id) + chat.archived = not chat.archived + Session.commit() + Session.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - + Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) return True except: return False @@ -206,9 +197,8 @@ class ChatTable: def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - with get_session() as db: all_chats = ( - db.query(Chat) + Session.query(Chat) .filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc()) # .limit(limit).offset(skip) @@ -223,120 +213,108 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - with get_session() as db: - query = db.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + query = Session.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.get(Chat, id) - return ChatModel.model_validate(chat) + chat = Session.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.query(Chat).filter_by(share_id=id).first() + chat = Session.query(Chat).filter_by(share_id=id).first() - if chat: - return self.get_chat_by_id(id) - else: - return None + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(id=id).delete() + Session.query(Chat).filter_by(id=id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(id=id, user_id=user_id).delete() + Session.query(Chat).filter_by(id=id, user_id=user_id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - self.delete_shared_chats_by_user_id(user_id) + self.delete_shared_chats_by_user_id(user_id) - db.query(Chat).filter_by(user_id=user_id).delete() + Session.query(Chat).filter_by(user_id=user_id).delete() return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() return True except: diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 897f182be..f8e7153c5 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session import json @@ -84,46 +83,42 @@ class DocumentsTable: ) try: - with get_session() as db: - result = Document(**document.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: - return None + result = Document(**document.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None except: return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - with get_session() as db: - document = db.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + document = Session.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - with get_session() as db: - return [ - DocumentModel.model_validate(doc) for doc in db.query(Document).all() - ] + return [ + DocumentModel.model_validate(doc) for doc in Session.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - with get_session() as db: - db.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(form_data.name) + Session.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + Session.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -132,27 +127,25 @@ class DocumentsTable: self, name: str, updated: dict ) -> Optional[DocumentModel]: try: - with get_session() as db: - doc = self.get_doc_by_name(name) - doc_content = json.loads(doc.content if doc.content else "{}") - doc_content = {**doc_content, **updated} + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} - db.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(name) + Session.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + Session.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - with get_session() as db: - db.query(Document).filter_by(name=name).delete() + Session.query(Document).filter_by(name=name).delete() return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index b7196d604..7664bf4f1 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import Column, String, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base, get_session +from apps.webui.internal.db import JSONField, Base, Session import json @@ -71,45 +70,38 @@ class FilesTable: ) try: - with get_session() as db: - result = File(**file.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FileModel.model_validate(result) - else: - return None + result = File(**file.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_file_by_id(self, id: str) -> Optional[FileModel]: try: - with get_session() as db: - file = db.get(File, id) - return FileModel.model_validate(file) + file = Session.get(File, id) + return FileModel.model_validate(file) except: return None def get_files(self) -> List[FileModel]: - with get_session() as db: - return [FileModel.model_validate(file) for file in db.query(File).all()] + return [FileModel.model_validate(file) for file in Session.query(File).all()] def delete_file_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(File).filter_by(id=id).delete() - db.commit() + Session.query(File).filter_by(id=id).delete() return True except: return False def delete_all_files(self) -> bool: try: - with get_session() as db: - db.query(File).delete() - db.commit() + Session.query(File).delete() return True except: return False diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 2343c9139..b78ac9708 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base, get_session +from apps.webui.internal.db import JSONField, Base, Session from apps.webui.models.users import Users import json @@ -100,64 +99,57 @@ class FunctionsTable: ) try: - with get_session() as db: - result = Function(**function.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + result = Function(**function.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - with get_session() as db: - function = db.get(Function, id) - return FunctionModel.model_validate(function) + function = Session.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(is_active=True).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).filter_by(is_active=True).all() + ] else: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: if active_only: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] else: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(type=type).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).filter_by(type=type).all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: try: - with get_session() as db: - function = db.get(Function, id) - return function.valves if function.valves else {} + function = Session.get(Function, id) + return function.valves if function.valves else {} except Exception as e: print(f"An error occurred: {e}") return None @@ -166,12 +158,12 @@ class FunctionsTable: self, id: str, valves: dict ) -> Optional[FunctionValves]: try: - with get_session() as db: - db.query(Function).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - db.commit() - return self.get_function_by_id(id) + function = Session.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + Session.commit() + Session.refresh(function) + return self.get_function_by_id(id) except: return None @@ -219,36 +211,33 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: - with get_session() as db: - db.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - db.commit() - return self.get_function_by_id(id) + Session.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + Session.commit() + return self.get_function_by_id(id) except: return None def deactivate_all_functions(self) -> Optional[bool]: try: - with get_session() as db: - db.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - db.commit() + Session.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + Session.commit() return True except: return None def delete_function_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Function).filter_by(id=id).delete() + Session.query(Function).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 941da5b26..263d1b5ab 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session -from apps.webui.models.chats import Chats +from apps.webui.internal.db import Base, Session import time import uuid @@ -58,15 +56,14 @@ class MemoriesTable: "updated_at": int(time.time()), } ) - with get_session() as db: - result = Memory(**memory.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + result = Memory(**memory.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, @@ -74,62 +71,55 @@ class MemoriesTable: content: str, ) -> Optional[MemoryModel]: try: - with get_session() as db: - db.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - db.commit() - return self.get_memory_by_id(id) + Session.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + Session.commit() + return self.get_memory_by_id(id) except: return None def get_memories(self) -> List[MemoryModel]: try: - with get_session() as db: - memories = db.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] + memories = Session.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: try: - with get_session() as db: - memories = db.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] + memories = Session.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: try: - with get_session() as db: - memory = db.get(Memory, id) - return MemoryModel.model_validate(memory) + memory = Session.get(Memory, id) + return MemoryModel.model_validate(memory) except: return None def delete_memory_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(id=id).delete() + Session.query(Memory).filter_by(id=id).delete() return True except: return False - def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: + def delete_memories_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(user_id=user_id).delete() + Session.query(Memory).filter_by(user_id=user_id).delete() return True except: return False def delete_memory_by_id_and_user_id( - self, db: Session, id: str, user_id: str + self, id: str, user_id: str ) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(id=id, user_id=user_id).delete() + Session.query(Memory).filter_by(id=id, user_id=user_id).delete() return True except: return False diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 86b4fa49b..dd736a73e 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -4,9 +4,8 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -127,41 +126,37 @@ class ModelsTable: } ) try: - with get_session() as db: - result = Model(**model.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + result = Model(**model.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - with get_session() as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + return [ModelModel.model_validate(model) for model in Session.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - with get_session() as db: - model = db.get(Model, id) - return ModelModel.model_validate(model) + model = Session.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: # update only the fields that are present in the model - with get_session() as db: - model = db.query(Model).get(id) - model.update(**model.model_dump()) - db.commit() - db.refresh(model) - return ModelModel.model_validate(model) + model = Session.query(Model).get(id) + model.update(**model.model_dump()) + Session.commit() + Session.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -169,8 +164,7 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Model).filter_by(id=id).delete() + Session.query(Model).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 029fd5e1b..a2fd0366b 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -3,9 +3,8 @@ from typing import List, Optional import time from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session import json @@ -50,65 +49,59 @@ class PromptsTable: def insert_new_prompt( self, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: - with get_session() as db: - prompt = PromptModel( - **{ - "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, - "timestamp": int(time.time()), - } - ) + prompt = PromptModel( + **{ + "user_id": user_id, + "command": form_data.command, + "title": form_data.title, + "content": form_data.content, + "timestamp": int(time.time()), + } + ) - try: - result = Prompt(**prompt.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None - except Exception as e: + try: + result = Prompt(**prompt.dict()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return PromptModel.model_validate(result) + else: return None + except Exception as e: + return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: - with get_session() as db: - try: - prompt = db.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) - except: - return None + try: + prompt = Session.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) + except: + return None def get_prompts(self) -> List[PromptModel]: - with get_session() as db: - return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() - ] + return [ + PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: - with get_session() as db: - try: - prompt = db.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - db.commit() - return prompt - # return self.get_prompt_by_command(command) - except: - return None + try: + prompt = Session.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + Session.commit() + return PromptModel.model_validate(prompt) + except: + return None def delete_prompt_by_command(self, command: str) -> bool: - with get_session() as db: - try: - db.query(Prompt).filter_by(command=command).delete() - return True - except: - return False + try: + Session.query(Prompt).filter_by(command=command).delete() + return True + except: + return False Prompts = PromptsTable() diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index dfe63688e..6cfe39d0c 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -7,9 +7,8 @@ import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session from config import SRC_LOG_LEVELS @@ -83,15 +82,14 @@ class TagTable: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: - with get_session() as db: - result = Tag(**tag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return TagModel.model_validate(result) - else: - return None + result = Tag(**tag.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None except Exception as e: return None @@ -99,9 +97,8 @@ class TagTable: self, name: str, user_id: str ) -> Optional[TagModel]: try: - with get_session() as db: - tag = db.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + tag = Session.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -123,105 +120,99 @@ class TagTable: } ) try: - with get_session() as db: - result = ChatIdTag(**chatIdTag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + result = ChatIdTag(**chatIdTag.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - with get_session() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + Session.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - with get_session() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + Session.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - with get_session() as db: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - with get_session() as db: - return ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + return ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - with get_session() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() + res = ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: log.error(f"delete_tag: {e}") @@ -231,21 +222,20 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - with get_session() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() + res = ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 534a4e3e8..20c608921 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -3,9 +3,8 @@ from typing import List, Optional import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from apps.webui.models.users import Users import json @@ -95,48 +94,43 @@ class ToolsTable: ) try: - with get_session() as db: - result = Tool(**tool.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ToolModel.model_validate(result) - else: - return None + result = Tool(**tool.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - with get_session() as db: - tool = db.get(Tool, id) - return ToolModel.model_validate(tool) + tool = Session.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - with get_session() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - with get_session() as db: - tool = db.get(Tool, id) - return tool.valves if tool.valves else {} + tool = Session.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - db.commit() - return self.get_tool_by_id(id) + Session.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + Session.commit() + return self.get_tool_by_id(id) except: return None @@ -183,19 +177,18 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) - db.commit() - return self.get_tool_by_id(id) + tool = Session.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + Session.commit() + Session.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).delete() + Session.query(Tool).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 796892927..252e3f122 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -3,11 +3,10 @@ from typing import List, Union, Optional import time from sqlalchemy import String, Column, BigInteger, Text -from sqlalchemy.orm import Session from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from apps.webui.models.chats import Chats #################### @@ -89,177 +88,161 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_session() as db: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return user - else: - return None + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except Exception as e: - return None + try: + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except Exception as e: + return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) + except: + return None def get_user_by_email(self, email: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) + except: + return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) + except: + return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - with get_session() as db: - users = ( - db.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + users = ( + Session.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - with get_session() as db: - return db.query(User).count() + return Session.query(User).count() def get_first_user(self) -> UserModel: - with get_session() as db: - try: - user = db.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) + except: + return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update({"role": role}) - db.commit() + try: + Session.query(User).filter_by(id=id).update({"role": role}) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_profile_image_url_by_id( self, id: str, profile_image_url: str ) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - db.commit() + try: + Session.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) + try: + Session.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_oauth_sub_by_id( self, id: str, oauth_sub: str ) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + try: + Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update(updated) - db.commit() + try: + Session.query(User).filter_by(id=id).update(updated) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) - except Exception as e: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) + except Exception as e: + return None def delete_user_by_id(self, id: str) -> bool: - with get_session() as db: - try: - # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + try: + # Delete User Chats + result = Chats.delete_chats_by_user_id(id) - if result: - # Delete User - db.query(User).filter_by(id=id).delete() - db.commit() + if result: + # Delete User + Session.query(User).filter_by(id=id).delete() + Session.commit() - return True - else: - return False - except: + return True + else: return False + except: + return False def update_user_api_key_by_id(self, id: str, api_key: str) -> str: - with get_session() as db: - try: - result = db.query(User).filter_by(id=id).update({"api_key": api_key}) - db.commit() - return True if result == 1 else False - except: - return False + try: + result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) + Session.commit() + return True if result == 1 else False + except: + return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: - with get_session() as db: - try: - user = db.query(User).filter_by(id=id).first() - return user.api_key - except Exception as e: - return None + try: + user = Session.query(User).filter_by(id=id).first() + return user.api_key + except Exception as e: + return None Users = UsersTable() diff --git a/backend/main.py b/backend/main.py index f35095bf1..2120b499a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -29,7 +29,6 @@ from fastapi import HTTPException from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from sqlalchemy import text -from sqlalchemy.orm import Session from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -57,7 +56,7 @@ from apps.webui.main import ( get_pipe_models, generate_function_chat_completion, ) -from apps.webui.internal.db import get_session, SessionLocal +from apps.webui.internal.db import Session, SessionLocal from pydantic import BaseModel @@ -794,6 +793,14 @@ app.add_middleware( allow_headers=["*"], ) +@app.middleware("http") +async def remove_session_after_request(request: Request, call_next): + response = await call_next(request) + log.debug("Removing session after request") + Session.commit() + Session.remove() + return response + @app.middleware("http") async def check_url(request: Request, call_next): @@ -2034,8 +2041,7 @@ async def healthcheck(): @app.get("/health/db") async def healthcheck_with_db(): - with get_session() as db: - result = db.execute(text("SELECT 1;")).all() + Session.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index ea4518eaf..6d2dd35b1 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") + from apps.webui.internal.db import Session + Session.commit() with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/all/archived")) assert response.status_code == 200 diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index 781fbfff8..f8d6d4ff7 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip from fastapi.testclient import TestClient from sqlalchemy import text, create_engine + log = logging.getLogger(__name__) @@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest): DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" docker_client: DockerClient - def get_db(self): - from apps.webui.internal.db import SessionLocal - - return SessionLocal() - @classmethod def _create_db_url(cls, env_vars_postgres: dict) -> str: host = get_docker_ip() @@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest): pytest.fail(f"Could not setup test environment: {ex}") def _check_db_connection(self): + from apps.webui.internal.db import Session retries = 10 while retries > 0: try: - self.db_session.execute(text("SELECT 1")) - self.db_session.commit() + Session.execute(text("SELECT 1")) + Session.commit() break except Exception as e: - self.db_session.rollback() + Session.rollback() log.warning(e) time.sleep(3) retries -= 1 def setup_method(self): super().setup_method() - self.db_session = self.get_db() self._check_db_connection() @classmethod @@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest): cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) def teardown_method(self): + from apps.webui.internal.db import Session # rollback everything not yet committed - self.db_session.commit() + Session.commit() # truncate all tables tables = [ @@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest): '"user"', ] for table in tables: - self.db_session.execute(text(f"TRUNCATE TABLE {table}")) - self.db_session.commit() + Session.execute(text(f"TRUNCATE TABLE {table}")) + Session.commit()