From 1436bb7c61b1df4dba2b5b383ecb8c86ec452f37 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek" <timothyjrbeck@gmail.com>
Date: Fri, 5 Jul 2024 23:38:53 -0700
Subject: [PATCH] enh: handle peewee migration
---
backend/apps/webui/internal/db.py | 36 +++++++++++--
backend/apps/webui/internal/wrappers.py | 72 +++++++++++++++++++++++++
2 files changed, 105 insertions(+), 3 deletions(-)
create mode 100644 backend/apps/webui/internal/wrappers.py
diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py
index 333e215ea..8437ae4fa 100644
--- a/backend/apps/webui/internal/db.py
+++ b/backend/apps/webui/internal/db.py
@@ -2,6 +2,10 @@ import os
import logging
import json
from contextlib import contextmanager
+
+from peewee_migrate import Router
+from apps.webui.internal.wrappers import register_connection
+
from typing import Optional, Any
from typing_extensions import Self
@@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
+
+# Workaround to handle the peewee migration
+# This is required to ensure the peewee migration is handled before the alembic migration
+def handle_peewee_migration():
+ try:
+ db = register_connection(DATABASE_URL)
+ migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
+ router = Router(db, logger=log, migrate_dir=migrate_dir)
+ router.run()
+ db.close()
+
+ # check if db connection has been closed
+
+ except Exception as e:
+ log.error(f"Failed to initialize the database connection: {e}")
+ raise
+
+ finally:
+ # Properly closing the database connection
+ if db and not db.is_closed():
+ db.close()
+
+ # Assert if db connection has been closed
+ assert db.is_closed(), "Database connection is still open."
+
+
+handle_peewee_migration()
+
+
SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
@@ -62,9 +95,6 @@ Base = declarative_base()
Session = scoped_session(SessionLocal)
-from contextlib import contextmanager
-
-
# Dependency
def get_session():
db = SessionLocal()
diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py
new file mode 100644
index 000000000..2b5551ce2
--- /dev/null
+++ b/backend/apps/webui/internal/wrappers.py
@@ -0,0 +1,72 @@
+from contextvars import ContextVar
+from peewee import *
+from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
+
+import logging
+from playhouse.db_url import connect, parse
+from playhouse.shortcuts import ReconnectMixin
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+
+class PeeweeConnectionState(object):
+ def __init__(self, **kwargs):
+ super().__setattr__("_state", db_state)
+ super().__init__(**kwargs)
+
+ def __setattr__(self, name, value):
+ self._state.get()[name] = value
+
+ def __getattr__(self, name):
+ value = self._state.get()[name]
+ return value
+
+
+class CustomReconnectMixin(ReconnectMixin):
+ reconnect_errors = (
+ # psycopg2
+ (OperationalError, "termin"),
+ (InterfaceError, "closed"),
+ # peewee
+ (PeeWeeInterfaceError, "closed"),
+ )
+
+
+class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+ pass
+
+
+def register_connection(db_url):
+ db = connect(db_url)
+ if isinstance(db, PostgresqlDatabase):
+ # Enable autoconnect for SQLite databases, managed by Peewee
+ db.autoconnect = True
+ db.reuse_if_open = True
+ log.info("Connected to PostgreSQL database")
+
+ # Get the connection details
+ connection = parse(db_url)
+
+ # Use our custom database class that supports reconnection
+ db = ReconnectingPostgresqlDatabase(
+ connection["database"],
+ user=connection["user"],
+ password=connection["password"],
+ host=connection["host"],
+ port=connection["port"],
+ )
+ db.connect(reuse_if_open=True)
+ elif isinstance(db, SqliteDatabase):
+ # Enable autoconnect for SQLite databases, managed by Peewee
+ db.autoconnect = True
+ db.reuse_if_open = True
+ log.info("Connected to SQLite database")
+ else:
+ raise ValueError("Unsupported database connection")
+ return db