mirror of
https://github.com/open-webui/open-webui
synced 2025-03-24 06:37:14 +00:00
Merge pull request #3221 from perfectra1n/feature-external-db-reconnect
feat: external db reconnect
This commit is contained in:
commit
1e0453221d
20
.github/workflows/integration-test.yml
vendored
20
.github/workflows/integration-test.yml
vendored
@ -170,6 +170,26 @@ jobs:
|
||||
echo "Server has stopped"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check that service will reconnect to postgres when connection will be closed
|
||||
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags)
|
||||
if [[ "$status_code" -ne 200 ]] ; then
|
||||
echo "Server has failed before postgres reconnect check"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Terminating all connections to postgres..."
|
||||
python -c "import os, psycopg2 as pg2; \
|
||||
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
|
||||
cur = conn.cursor(); \
|
||||
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
|
||||
|
||||
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/api/tags)
|
||||
if [[ "$status_code" -ne 200 ]] ; then
|
||||
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# - name: Test backend with MySQL
|
||||
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
|
||||
|
@ -1,16 +1,16 @@
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
|
||||
from peewee import *
|
||||
from peewee_migrate import Router
|
||||
from playhouse.db_url import connect
|
||||
|
||||
from apps.webui.internal.wrappers import register_connection
|
||||
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
|
||||
import os
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
|
||||
class JSONField(TextField):
|
||||
def db_value(self, value):
|
||||
return json.dumps(value)
|
||||
@ -19,7 +19,6 @@ class JSONField(TextField):
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
# Check if the file exists
|
||||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
# Rename the file
|
||||
@ -28,12 +27,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
else:
|
||||
pass
|
||||
|
||||
DB = connect(DATABASE_URL)
|
||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||
|
||||
# The `register_connection` function encapsulates the logic for setting up
|
||||
# the database connection based on the connection string, while `connect`
|
||||
# is a Peewee-specific method to manage the connection state and avoid errors
|
||||
# when a connection is already open.
|
||||
try:
|
||||
DB = register_connection(DATABASE_URL)
|
||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize the database connection: {e}")
|
||||
raise
|
||||
|
||||
router = Router(
|
||||
DB,
|
||||
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
|
||||
logger=log,
|
||||
)
|
||||
router.run()
|
||||
DB.connect(reuse_if_open=True)
|
||||
try:
|
||||
DB.connect(reuse_if_open=True)
|
||||
except OperationalError as e:
|
||||
log.info(f"Failed to connect to database again due to: {e}")
|
||||
pass
|
72
backend/apps/webui/internal/wrappers.py
Normal file
72
backend/apps/webui/internal/wrappers.py
Normal file
@ -0,0 +1,72 @@
|
||||
from contextvars import ContextVar
|
||||
from peewee import *
|
||||
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
|
||||
|
||||
import logging
|
||||
from playhouse.db_url import connect, parse
|
||||
from playhouse.shortcuts import ReconnectMixin
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||
|
||||
|
||||
class PeeweeConnectionState(object):
|
||||
def __init__(self, **kwargs):
|
||||
super().__setattr__("_state", db_state)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self._state.get()[name] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
value = self._state.get()[name]
|
||||
return value
|
||||
|
||||
|
||||
class CustomReconnectMixin(ReconnectMixin):
|
||||
reconnect_errors = (
|
||||
# psycopg2
|
||||
(OperationalError, "termin"),
|
||||
(InterfaceError, "closed"),
|
||||
# peewee
|
||||
(PeeWeeInterfaceError, "closed"),
|
||||
)
|
||||
|
||||
|
||||
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||
pass
|
||||
|
||||
|
||||
def register_connection(db_url):
|
||||
db = connect(db_url)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(
|
||||
connection["database"],
|
||||
user=connection["user"],
|
||||
password=connection["password"],
|
||||
host=connection["host"],
|
||||
port=connection["port"],
|
||||
)
|
||||
db.connect(reuse_if_open=True)
|
||||
elif isinstance(db, SqliteDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
return db
|
Loading…
Reference in New Issue
Block a user