diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py index d1cdecd29..2b5551ce2 100644 --- a/backend/apps/webui/internal/wrappers.py +++ b/backend/apps/webui/internal/wrappers.py @@ -4,6 +4,7 @@ from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError import logging from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin from config import SRC_LOG_LEVELS @@ -13,6 +14,7 @@ log.setLevel(SRC_LOG_LEVELS["DB"]) db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} db_state = ContextVar("db_state", default=db_state_default.copy()) + class PeeweeConnectionState(object): def __init__(self, **kwargs): super().__setattr__("_state", db_state) @@ -25,18 +27,21 @@ class PeeweeConnectionState(object): value = self._state.get()[name] return value + class CustomReconnectMixin(ReconnectMixin): reconnect_errors = ( # psycopg2 - (OperationalError, 'termin'), - (InterfaceError, 'closed'), + (OperationalError, "termin"), + (InterfaceError, "closed"), # peewee - (PeeWeeInterfaceError, 'closed'), + (PeeWeeInterfaceError, "closed"), ) + class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): pass + def register_connection(db_url): db = connect(db_url) if isinstance(db, PostgresqlDatabase): @@ -44,8 +49,18 @@ def register_connection(db_url): db.autoconnect = True db.reuse_if_open = True log.info("Connected to PostgreSQL database") + + # Get the connection details connection = parse(db_url) - db = ReconnectingPostgresqlDatabase(connection['database'], user=connection['user'], password=connection['password'],host=connection['host'], port=connection['port']) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase( + connection["database"], + user=connection["user"], + password=connection["password"], + host=connection["host"], + port=connection["port"], + ) db.connect(reuse_if_open=True) elif isinstance(db, SqliteDatabase): # Enable autoconnect for SQLite databases, managed by Peewee @@ -53,5 +68,5 @@ def register_connection(db_url): db.reuse_if_open = True log.info("Connected to SQLite database") else: - raise ValueError('Unsupported database connection') - return db \ No newline at end of file + raise ValueError("Unsupported database connection") + return db