fix peewee and playhouse connections to retry

This commit is contained in:
perf3ct 2024-06-16 15:25:48 -07:00
parent 75d713057c
commit 10fa887eab
No known key found for this signature in database
GPG Key ID: 569C4EEC436F5232
2 changed files with 41 additions and 38 deletions

View File

@ -4,15 +4,13 @@ import json
from peewee import *
from peewee_migrate import Router
from playhouse.db_url import connect
from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases
from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField):
def db_value(self, value):
return json.dumps(value)
@ -21,9 +19,6 @@ class JSONField(TextField):
if value is not None:
return json.loads(value)
register_peewee_databases()
# Check if the file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
# Rename the file
@ -32,13 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
DB = connect(DATABASE_URL)
DB._state = PeeweeConnectionState()
log.info(f"Connected to a {DB.__class__.__name__} database.")
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
DB.connect(reuse_if_open=True)
try:
DB.connect()
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass

View File

@ -1,18 +1,13 @@
from contextvars import ContextVar
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, _ConnectionState
from playhouse.db_url import register_database
from peewee import *
from playhouse.db_url import connect
from playhouse.pool import PooledPostgresqlDatabase
from playhouse.shortcuts import ReconnectMixin
from psycopg2 import OperationalError
from psycopg2.errors import InterfaceError
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
class PeeweeConnectionState(_ConnectionState):
class PeeweeConnectionState(object):
def __init__(self, **kwargs):
super().__setattr__("_state", db_state)
super().__init__(**kwargs)
@ -21,29 +16,29 @@ class PeeweeConnectionState(_ConnectionState):
self._state.get()[name] = value
def __getattr__(self, name):
return self._state.get()[name]
value = self._state.get()[name]
return value
class ReconnectingPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase):
pass
class CustomReconnectMixin(ReconnectMixin):
reconnect_errors = (
# default ReconnectMixin exceptions
*ReconnectMixin.reconnect_errors,
# psycopg2
(OperationalError, 'termin'),
(InterfaceError, 'closed'),
# peewee
(PeeWeeInterfaceError, 'closed'),
)
class ReconnectingPooledPostgresqlDatabase(ReconnectMixin, PooledPostgresqlDatabase):
pass
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
class ReconnectingSqliteDatabase(ReconnectMixin, SqliteDatabase):
pass
class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
pass
def register_peewee_databases():
register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql')
register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')
def register_connection(db_url):
# Connect using the playhouse.db_url module, which supports multiple
# database types, then wrap the connection in a ReconnectMixin to handle dropped connections
db = connect(db_url)
if isinstance(db, PostgresqlDatabase):
db = ReconnectingPostgresqlDatabase(db.database, **db.connect_params)
elif isinstance(db, PooledPostgresqlDatabase):
db = ReconnectingPooledPostgresqlDatabase(db.database, **db.connect_params)
elif isinstance(db, SqliteDatabase):
db = ReconnectingSqliteDatabase(db.database, **db.connect_params)
else:
raise ValueError('Unsupported database connection')
return db