mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
enh: handle peewee migration
This commit is contained in:
parent
d60f06608e
commit
1436bb7c61
@ -2,6 +2,10 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from peewee_migrate import Router
|
||||||
|
from apps.webui.internal.wrappers import register_connection
|
||||||
|
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Workaround to handle the peewee migration
|
||||||
|
# This is required to ensure the peewee migration is handled before the alembic migration
|
||||||
|
def handle_peewee_migration():
|
||||||
|
try:
|
||||||
|
db = register_connection(DATABASE_URL)
|
||||||
|
migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
|
||||||
|
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
||||||
|
router.run()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# check if db connection has been closed
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Failed to initialize the database connection: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Properly closing the database connection
|
||||||
|
if db and not db.is_closed():
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Assert if db connection has been closed
|
||||||
|
assert db.is_closed(), "Database connection is still open."
|
||||||
|
|
||||||
|
|
||||||
|
handle_peewee_migration()
|
||||||
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||||
if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@ -62,9 +95,6 @@ Base = declarative_base()
|
|||||||
Session = scoped_session(SessionLocal)
|
Session = scoped_session(SessionLocal)
|
||||||
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
|
|
||||||
# Dependency
|
# Dependency
|
||||||
def get_session():
|
def get_session():
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
|
72
backend/apps/webui/internal/wrappers.py
Normal file
72
backend/apps/webui/internal/wrappers.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
from peewee import *
|
||||||
|
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from playhouse.db_url import connect, parse
|
||||||
|
from playhouse.shortcuts import ReconnectMixin
|
||||||
|
|
||||||
|
from config import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||||
|
|
||||||
|
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||||
|
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||||
|
|
||||||
|
|
||||||
|
class PeeweeConnectionState(object):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__setattr__("_state", db_state)
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def __setattr__(self, name, value):
|
||||||
|
self._state.get()[name] = value
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
value = self._state.get()[name]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class CustomReconnectMixin(ReconnectMixin):
|
||||||
|
reconnect_errors = (
|
||||||
|
# psycopg2
|
||||||
|
(OperationalError, "termin"),
|
||||||
|
(InterfaceError, "closed"),
|
||||||
|
# peewee
|
||||||
|
(PeeWeeInterfaceError, "closed"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def register_connection(db_url):
|
||||||
|
db = connect(db_url)
|
||||||
|
if isinstance(db, PostgresqlDatabase):
|
||||||
|
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||||
|
db.autoconnect = True
|
||||||
|
db.reuse_if_open = True
|
||||||
|
log.info("Connected to PostgreSQL database")
|
||||||
|
|
||||||
|
# Get the connection details
|
||||||
|
connection = parse(db_url)
|
||||||
|
|
||||||
|
# Use our custom database class that supports reconnection
|
||||||
|
db = ReconnectingPostgresqlDatabase(
|
||||||
|
connection["database"],
|
||||||
|
user=connection["user"],
|
||||||
|
password=connection["password"],
|
||||||
|
host=connection["host"],
|
||||||
|
port=connection["port"],
|
||||||
|
)
|
||||||
|
db.connect(reuse_if_open=True)
|
||||||
|
elif isinstance(db, SqliteDatabase):
|
||||||
|
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||||
|
db.autoconnect = True
|
||||||
|
db.reuse_if_open = True
|
||||||
|
log.info("Connected to SQLite database")
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported database connection")
|
||||||
|
return db
|
Loading…
Reference in New Issue
Block a user