From b8e62f27e2c8503b766e75d568455f1f149ca4e0 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 16:31:25 +0200 Subject: [PATCH] Refactor database into a separate class --- apiserver/app_sequence.py | 2 +- apiserver/database/__init__.py | 112 +++++++++++++++++---------------- apiserver/tools.py | 4 +- 3 files changed, 62 insertions(+), 56 deletions(-) diff --git a/apiserver/app_sequence.py b/apiserver/app_sequence.py index 7a5163e..2ed2a0c 100644 --- a/apiserver/app_sequence.py +++ b/apiserver/app_sequence.py @@ -4,7 +4,7 @@ from hashlib import md5 from flask import Flask from semantic_version import Version -from apiserver import database +from apiserver.database import database from apiserver.bll.statistics.stats_reporter import StatisticsReporter from apiserver.config import config, info from apiserver.elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError diff --git a/apiserver/database/__init__.py b/apiserver/database/__init__.py index 78209f3..52be3a8 100644 --- a/apiserver/database/__init__.py +++ b/apiserver/database/__init__.py @@ -6,7 +6,7 @@ from jsonmodels import models from jsonmodels.errors import ValidationError from jsonmodels.fields import StringField from mongoengine import register_connection -from mongoengine.connection import get_connection +from mongoengine.connection import get_connection, disconnect from apiserver.config import config from .defs import Database @@ -23,70 +23,76 @@ OVERRIDE_HOST_ENV_KEY = ( ) OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT") -_entries = [] - class DatabaseEntry(models.Base): host = StringField(required=True) alias = StringField() - @property - def health_alias(self): - return "__health__" + self.alias +class DatabaseFactory: + _entries = [] -def initialize(): - db_entries = config.get("hosts.mongo", {}) - missing = [] - log.info("Initializing database connections") - - override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None) - if override_hostname: - log.info(f"Using override mongodb host {override_hostname}") - - override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None) - if override_port: - log.info(f"Using override mongodb port {override_port}") - - for key, alias in get_items(Database).items(): - if key not in db_entries: - missing.append(key) - continue - - entry = DatabaseEntry(alias=alias, **db_entries.get(key)) + @classmethod + def initialize(cls): + db_entries = config.get("hosts.mongo", {}) + missing = [] + log.info("Initializing database connections") + override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None) if override_hostname: - entry.host = furl(entry.host).set(host=override_hostname).url + log.info(f"Using override mongodb host {override_hostname}") + override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None) if override_port: - entry.host = furl(entry.host).set(port=override_port).url + log.info(f"Using override mongodb port {override_port}") - try: - entry.validate() - log.info( - "Registering connection to %(alias)s (%(host)s)" % entry.to_struct() - ) - register_connection(alias=alias, host=entry.host) + for key, alias in get_items(Database).items(): + if key not in db_entries: + missing.append(key) + continue - _entries.append(entry) - except ValidationError as ex: - raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0])) - if missing: - raise ValueError("Missing database configuration for %s" % ", ".join(missing)) + entry = DatabaseEntry(alias=alias, **db_entries.get(key)) + + if override_hostname: + entry.host = furl(entry.host).set(host=override_hostname).url + + if override_port: + entry.host = furl(entry.host).set(port=override_port).url + + try: + entry.validate() + log.info( + "Registering connection to %(alias)s (%(host)s)" % entry.to_struct() + ) + register_connection(alias=alias, host=entry.host) + + cls._entries.append(entry) + except ValidationError as ex: + raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0])) + if missing: + raise ValueError("Missing database configuration for %s" % ", ".join(missing)) + + @classmethod + def get_entries(cls): + return cls._entries + + @classmethod + def get_hosts(cls): + return [entry.host for entry in cls.get_entries()] + + @classmethod + def get_aliases(cls): + return [entry.alias for entry in cls.get_entries()] + + @classmethod + def reconnect(cls): + for entry in cls.get_entries(): + # there is bug in the current implementation that prevents + # reconnection from work so workaround this + # get_connection(entry.alias, reconnect=True) + disconnect(entry.alias) + register_connection(alias=entry.alias, host=entry.host) + get_connection(entry.alias) -def get_entries(): - return _entries - - -def get_hosts(): - return [entry.host for entry in get_entries()] - - -def get_aliases(): - return [entry.alias for entry in get_entries()] - - -def reconnect(): - for entry in get_entries(): - get_connection(entry.alias, reconnect=True) +database = DatabaseFactory() diff --git a/apiserver/tools.py b/apiserver/tools.py index 9895c71..2e234c5 100644 --- a/apiserver/tools.py +++ b/apiserver/tools.py @@ -6,8 +6,8 @@ from humanfriendly import parse_timespan def setup(): - from apiserver.database import initialize - initialize() + from apiserver.database import database + database.initialize() def gen_token(args):