From dfbc12594702bf601c4c391cfe080f8e0844d01e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=91=D0=B5=D0=BA=D0=BB=D0=B5=D0=BC=D0=B8=D1=88=D0=B5?= =?UTF-8?q?=D0=B2=20=D0=9F=D0=B5=D1=82=D1=80=20=D0=90=D0=BB=D0=B5=D0=BA?= =?UTF-8?q?=D1=81=D0=B5=D0=B5=D0=B2=D0=B8=D1=87?= Date: Thu, 30 May 2024 18:55:58 +0700 Subject: [PATCH] Reconnect to postgresql & mysql external databases when getting disconnected --- backend/apps/webui/internal/db.py | 9 ++++ backend/apps/webui/internal/wrappers.py | 59 +++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 backend/apps/webui/internal/wrappers.py diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 0e7b1f95d..dda58a4e1 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -7,6 +7,12 @@ from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR import os import logging +from peewee_migrate import Router +from playhouse.db_url import connect + +from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases +from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) @@ -20,6 +26,8 @@ class JSONField(TextField): return json.loads(value) +register_peewee_databases() + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -29,6 +37,7 @@ else: pass DB = connect(DATABASE_URL) +DB._state = PeeweeConnectionState() log.info(f"Connected to a {DB.__class__.__name__} database.") router = Router( DB, diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..406599b5a --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,59 @@ +from contextvars import ContextVar + +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, MySQLDatabase, _ConnectionState +from playhouse.db_url import register_database +from playhouse.pool import PooledPostgresqlDatabase, PooledMySQLDatabase +from playhouse.shortcuts import ReconnectMixin +from psycopg2 import OperationalError +from psycopg2.errors import InterfaceError + + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(_ConnectionState): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + return self._state.get()[name] + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # default ReconnectMixin exceptions (MySQL specific) + *ReconnectMixin.reconnect_errors, + # psycopg2 + (OperationalError, 'termin'), + (InterfaceError, 'closed'), + # peewee + (PeeWeeInterfaceError, 'closed'), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase): + pass + + +class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase): + pass + + +class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase): + pass + + +def register_peewee_databases(): + register_database(MySQLDatabase, 'mysql') + register_database(PooledMySQLDatabase, 'mysql+pool') + register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql') + register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')