diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 0e7b1f95d..dda58a4e1 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -7,6 +7,12 @@ from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR import os import logging +from peewee_migrate import Router +from playhouse.db_url import connect + +from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases +from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) @@ -20,6 +26,8 @@ class JSONField(TextField): return json.loads(value) +register_peewee_databases() + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -29,6 +37,7 @@ else: pass DB = connect(DATABASE_URL) +DB._state = PeeweeConnectionState() log.info(f"Connected to a {DB.__class__.__name__} database.") router = Router( DB, diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..406599b5a --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,59 @@ +from contextvars import ContextVar + +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, MySQLDatabase, _ConnectionState +from playhouse.db_url import register_database +from playhouse.pool import PooledPostgresqlDatabase, PooledMySQLDatabase +from playhouse.shortcuts import ReconnectMixin +from psycopg2 import OperationalError +from psycopg2.errors import InterfaceError + + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(_ConnectionState): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + return self._state.get()[name] + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # default ReconnectMixin exceptions (MySQL specific) + *ReconnectMixin.reconnect_errors, + # psycopg2 + (OperationalError, 'termin'), + (InterfaceError, 'closed'), + # peewee + (PeeWeeInterfaceError, 'closed'), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase): + pass + + +class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase): + pass + + +class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase): + pass + + +def register_peewee_databases(): + register_database(MySQLDatabase, 'mysql') + register_database(PooledMySQLDatabase, 'mysql+pool') + register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql') + register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')