Refactor database into a separate class

This commit is contained in:
allegroai 2021-01-05 16:31:25 +02:00
parent c7bbac73d0
commit b8e62f27e2
3 changed files with 62 additions and 56 deletions

View File

@ -4,7 +4,7 @@ from hashlib import md5
from flask import Flask
from semantic_version import Version
from apiserver import database
from apiserver.database import database
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
from apiserver.config import config, info
from apiserver.elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError

View File

@ -6,7 +6,7 @@ from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection
from mongoengine.connection import get_connection, disconnect
from apiserver.config import config
from .defs import Database
@ -23,70 +23,76 @@ OVERRIDE_HOST_ENV_KEY = (
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
_entries = []
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
@property
def health_alias(self):
return "__health__" + self.alias
class DatabaseFactory:
_entries = []
def initialize():
db_entries = config.get("hosts.mongo", {})
missing = []
log.info("Initializing database connections")
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
@classmethod
def initialize(cls):
db_entries = config.get("hosts.mongo", {})
missing = []
log.info("Initializing database connections")
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
log.info(f"Using override mongodb port {override_port}")
try:
entry.validate()
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
_entries.append(entry)
except ValidationError as ex:
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
try:
entry.validate()
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
cls._entries.append(entry)
except ValidationError as ex:
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
@classmethod
def get_entries(cls):
return cls._entries
@classmethod
def get_hosts(cls):
return [entry.host for entry in cls.get_entries()]
@classmethod
def get_aliases(cls):
return [entry.alias for entry in cls.get_entries()]
@classmethod
def reconnect(cls):
for entry in cls.get_entries():
# there is bug in the current implementation that prevents
# reconnection from work so workaround this
# get_connection(entry.alias, reconnect=True)
disconnect(entry.alias)
register_connection(alias=entry.alias, host=entry.host)
get_connection(entry.alias)
def get_entries():
return _entries
def get_hosts():
return [entry.host for entry in get_entries()]
def get_aliases():
return [entry.alias for entry in get_entries()]
def reconnect():
for entry in get_entries():
get_connection(entry.alias, reconnect=True)
database = DatabaseFactory()

View File

@ -6,8 +6,8 @@ from humanfriendly import parse_timespan
def setup():
from apiserver.database import initialize
initialize()
from apiserver.database import database
database.initialize()
def gen_token(args):