diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 7420bd019..a4b86ae1f 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -4,15 +4,13 @@ import json from peewee import * from peewee_migrate import Router -from playhouse.db_url import connect -from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases +from apps.webui.internal.wrappers import register_connection from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) - class JSONField(TextField): def db_value(self, value): return json.dumps(value) @@ -21,9 +19,6 @@ class JSONField(TextField): if value is not None: return json.loads(value) - -register_peewee_databases() - # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -32,13 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass -DB = connect(DATABASE_URL) -DB._state = PeeweeConnectionState() -log.info(f"Connected to a {DB.__class__.__name__} database.") + +# The `register_connection` function encapsulates the logic for setting up +# the database connection based on the connection string, while `connect` +# is a Peewee-specific method to manage the connection state and avoid errors +# when a connection is already open. +try: + DB = register_connection(DATABASE_URL) + log.info(f"Connected to a {DB.__class__.__name__} database.") +except Exception as e: + log.error(f"Failed to initialize the database connection: {e}") + raise + router = Router( DB, migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", logger=log, ) router.run() -DB.connect(reuse_if_open=True) +try: + DB.connect() +except OperationalError as e: + log.info(f"Failed to connect to database again due to: {e}") + pass \ No newline at end of file diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py index 53869b916..11c91034f 100644 --- a/backend/apps/webui/internal/wrappers.py +++ b/backend/apps/webui/internal/wrappers.py @@ -1,18 +1,13 @@ from contextvars import ContextVar - -from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, _ConnectionState -from playhouse.db_url import register_database +from peewee import * +from playhouse.db_url import connect from playhouse.pool import PooledPostgresqlDatabase from playhouse.shortcuts import ReconnectMixin -from psycopg2 import OperationalError -from psycopg2.errors import InterfaceError - db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} db_state = ContextVar("db_state", default=db_state_default.copy()) - -class PeeweeConnectionState(_ConnectionState): +class PeeweeConnectionState(object): def __init__(self, **kwargs): super().__setattr__("_state", db_state) super().__init__(**kwargs) @@ -21,29 +16,29 @@ class PeeweeConnectionState(_ConnectionState): self._state.get()[name] = value def __getattr__(self, name): - return self._state.get()[name] + value = self._state.get()[name] + return value +class ReconnectingPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase): + pass -class CustomReconnectMixin(ReconnectMixin): - reconnect_errors = ( - # default ReconnectMixin exceptions - *ReconnectMixin.reconnect_errors, - # psycopg2 - (OperationalError, 'termin'), - (InterfaceError, 'closed'), - # peewee - (PeeWeeInterfaceError, 'closed'), - ) +class ReconnectingPooledPostgresqlDatabase(ReconnectMixin, PooledPostgresqlDatabase): + pass - -class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): +class ReconnectingSqliteDatabase(ReconnectMixin, SqliteDatabase): pass -class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase): - pass - - -def register_peewee_databases(): - register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql') - register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool') +def register_connection(db_url): + # Connect using the playhouse.db_url module, which supports multiple + # database types, then wrap the connection in a ReconnectMixin to handle dropped connections + db = connect(db_url) + if isinstance(db, PostgresqlDatabase): + db = ReconnectingPostgresqlDatabase(db.database, **db.connect_params) + elif isinstance(db, PooledPostgresqlDatabase): + db = ReconnectingPooledPostgresqlDatabase(db.database, **db.connect_params) + elif isinstance(db, SqliteDatabase): + db = ReconnectingSqliteDatabase(db.database, **db.connect_params) + else: + raise ValueError('Unsupported database connection') + return db