mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
fix peewee and playhouse connections to retry
This commit is contained in:
parent
75d713057c
commit
10fa887eab
@ -4,15 +4,13 @@ import json
|
|||||||
|
|
||||||
from peewee import *
|
from peewee import *
|
||||||
from peewee_migrate import Router
|
from peewee_migrate import Router
|
||||||
from playhouse.db_url import connect
|
|
||||||
|
|
||||||
from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases
|
from apps.webui.internal.wrappers import register_connection
|
||||||
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
|
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
|
||||||
|
|
||||||
class JSONField(TextField):
|
class JSONField(TextField):
|
||||||
def db_value(self, value):
|
def db_value(self, value):
|
||||||
return json.dumps(value)
|
return json.dumps(value)
|
||||||
@ -21,9 +19,6 @@ class JSONField(TextField):
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
return json.loads(value)
|
return json.loads(value)
|
||||||
|
|
||||||
|
|
||||||
register_peewee_databases()
|
|
||||||
|
|
||||||
# Check if the file exists
|
# Check if the file exists
|
||||||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||||
# Rename the file
|
# Rename the file
|
||||||
@ -32,13 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
DB = connect(DATABASE_URL)
|
|
||||||
DB._state = PeeweeConnectionState()
|
# The `register_connection` function encapsulates the logic for setting up
|
||||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
# the database connection based on the connection string, while `connect`
|
||||||
|
# is a Peewee-specific method to manage the connection state and avoid errors
|
||||||
|
# when a connection is already open.
|
||||||
|
try:
|
||||||
|
DB = register_connection(DATABASE_URL)
|
||||||
|
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to initialize the database connection: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
router = Router(
|
router = Router(
|
||||||
DB,
|
DB,
|
||||||
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
|
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
|
||||||
logger=log,
|
logger=log,
|
||||||
)
|
)
|
||||||
router.run()
|
router.run()
|
||||||
DB.connect(reuse_if_open=True)
|
try:
|
||||||
|
DB.connect()
|
||||||
|
except OperationalError as e:
|
||||||
|
log.info(f"Failed to connect to database again due to: {e}")
|
||||||
|
pass
|
@ -1,18 +1,13 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from peewee import *
|
||||||
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, _ConnectionState
|
from playhouse.db_url import connect
|
||||||
from playhouse.db_url import register_database
|
|
||||||
from playhouse.pool import PooledPostgresqlDatabase
|
from playhouse.pool import PooledPostgresqlDatabase
|
||||||
from playhouse.shortcuts import ReconnectMixin
|
from playhouse.shortcuts import ReconnectMixin
|
||||||
from psycopg2 import OperationalError
|
|
||||||
from psycopg2.errors import InterfaceError
|
|
||||||
|
|
||||||
|
|
||||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||||
|
|
||||||
|
class PeeweeConnectionState(object):
|
||||||
class PeeweeConnectionState(_ConnectionState):
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__setattr__("_state", db_state)
|
super().__setattr__("_state", db_state)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@ -21,29 +16,29 @@ class PeeweeConnectionState(_ConnectionState):
|
|||||||
self._state.get()[name] = value
|
self._state.get()[name] = value
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return self._state.get()[name]
|
value = self._state.get()[name]
|
||||||
|
return value
|
||||||
|
|
||||||
|
class ReconnectingPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase):
|
||||||
|
pass
|
||||||
|
|
||||||
class CustomReconnectMixin(ReconnectMixin):
|
class ReconnectingPooledPostgresqlDatabase(ReconnectMixin, PooledPostgresqlDatabase):
|
||||||
reconnect_errors = (
|
pass
|
||||||
# default ReconnectMixin exceptions
|
|
||||||
*ReconnectMixin.reconnect_errors,
|
|
||||||
# psycopg2
|
|
||||||
(OperationalError, 'termin'),
|
|
||||||
(InterfaceError, 'closed'),
|
|
||||||
# peewee
|
|
||||||
(PeeWeeInterfaceError, 'closed'),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
class ReconnectingSqliteDatabase(ReconnectMixin, SqliteDatabase):
|
||||||
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase):
|
def register_connection(db_url):
|
||||||
pass
|
# Connect using the playhouse.db_url module, which supports multiple
|
||||||
|
# database types, then wrap the connection in a ReconnectMixin to handle dropped connections
|
||||||
|
db = connect(db_url)
|
||||||
def register_peewee_databases():
|
if isinstance(db, PostgresqlDatabase):
|
||||||
register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql')
|
db = ReconnectingPostgresqlDatabase(db.database, **db.connect_params)
|
||||||
register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool')
|
elif isinstance(db, PooledPostgresqlDatabase):
|
||||||
|
db = ReconnectingPooledPostgresqlDatabase(db.database, **db.connect_params)
|
||||||
|
elif isinstance(db, SqliteDatabase):
|
||||||
|
db = ReconnectingSqliteDatabase(db.database, **db.connect_params)
|
||||||
|
else:
|
||||||
|
raise ValueError('Unsupported database connection')
|
||||||
|
return db
|
||||||
|
Loading…
Reference in New Issue
Block a user