diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 2426aff27..e64f93bc1 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -170,6 +170,26 @@ jobs: echo "Server has stopped" exit 1 fi + + # Check that service will reconnect to postgres when connection will be closed + status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags) + if [[ "$status_code" -ne 200 ]] ; then + echo "Server has failed before postgres reconnect check" + exit 1 + fi + + echo "Terminating all connections to postgres..." + python -c "import os, psycopg2 as pg2; \ + conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \ + cur = conn.cursor(); \ + cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')" + + status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags) + if [[ "$status_code" -ne 200 ]] ; then + echo "Server has not reconnected to postgres after connection was closed: returned status $status_code" + exit 1 + fi + # - name: Test backend with MySQL # if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure' diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 0e7b1f95d..b61eb012d 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,16 +1,16 @@ +import os +import logging import json from peewee import * from peewee_migrate import Router -from playhouse.db_url import connect + +from apps.webui.internal.wrappers import register_connection from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR -import os -import logging log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) - class JSONField(TextField): def db_value(self, value): return json.dumps(value) @@ -19,7 +19,6 @@ class JSONField(TextField): if value is not None: return json.loads(value) - # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file @@ -28,12 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass -DB = connect(DATABASE_URL) -log.info(f"Connected to a {DB.__class__.__name__} database.") + +# The `register_connection` function encapsulates the logic for setting up +# the database connection based on the connection string, while `connect` +# is a Peewee-specific method to manage the connection state and avoid errors +# when a connection is already open. +try: + DB = register_connection(DATABASE_URL) + log.info(f"Connected to a {DB.__class__.__name__} database.") +except Exception as e: + log.error(f"Failed to initialize the database connection: {e}") + raise + router = Router( DB, migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", logger=log, ) router.run() -DB.connect(reuse_if_open=True) +try: + DB.connect(reuse_if_open=True) +except OperationalError as e: + log.info(f"Failed to connect to database again due to: {e}") + pass \ No newline at end of file diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..2b5551ce2 --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,72 @@ +from contextvars import ContextVar +from peewee import * +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError + +import logging +from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["DB"]) + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(object): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + value = self._state.get()[name] + return value + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # psycopg2 + (OperationalError, "termin"), + (InterfaceError, "closed"), + # peewee + (PeeWeeInterfaceError, "closed"), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +def register_connection(db_url): + db = connect(db_url) + if isinstance(db, PostgresqlDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to PostgreSQL database") + + # Get the connection details + connection = parse(db_url) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase( + connection["database"], + user=connection["user"], + password=connection["password"], + host=connection["host"], + port=connection["port"], + ) + db.connect(reuse_if_open=True) + elif isinstance(db, SqliteDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to SQLite database") + else: + raise ValueError("Unsupported database connection") + return db