from os import getenv

from boltons.iterutils import first
from furl import furl
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection, disconnect

from apiserver.config_repo import config
from .defs import Database
from .utils import get_items

log = config.logger("database")

strict = config.get("apiserver.mongo.strict", True)

# Environment variables checked, in order, for a MongoDB host override.
# The first one that is set to a non-empty value wins (see `first` below).
OVERRIDE_HOST_ENV_KEY = (
    "CLEARML_MONGODB_SERVICE_HOST",
    "TRAINS_MONGODB_SERVICE_HOST",
    "MONGODB_SERVICE_HOST",
    "MONGODB_SERVICE_SERVICE_HOST",
)
# Environment variables checked, in order, for a MongoDB port override.
OVERRIDE_PORT_ENV_KEY = (
    "CLEARML_MONGODB_SERVICE_PORT",
    "TRAINS_MONGODB_SERVICE_PORT",
    "MONGODB_SERVICE_PORT",
)

# When set, this value replaces the configured host of every database entry,
# and the host/port overrides above are ignored.
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"

# Placeholder written into log lines in place of a connection-string password.
_REDACTED_PASSWORD = "********"


def _redact_credentials(url: str) -> str:
    """
    Return `url` with the password in its userinfo part (if any) masked.

    This is used only for logging, so that connection strings can be written to
    the log without leaking credentials. Only the authority section is
    inspected: the text between "://" and the first "/", "?" or "#". This also
    covers multi-host MongoDB URIs such as
    "mongodb://user:pass@h1:27017,h2:27017/db". Strings without a scheme, or
    without a "user:password@" part, are returned unchanged.
    """
    scheme_end = url.find("://")
    if scheme_end < 0:
        return url
    authority_start = scheme_end + len("://")

    authority_end = len(url)
    for delimiter in "/?#":
        pos = url.find(delimiter, authority_start)
        if pos != -1:
            authority_end = min(authority_end, pos)

    at = url.rfind("@", authority_start, authority_end)
    if at < 0:
        return url

    user, sep, _password = url[authority_start:at].partition(":")
    if not sep:
        # Username only, no password to hide.
        return url
    return f"{url[:authority_start]}{user}:{_REDACTED_PASSWORD}{url[at:]}"


class DatabaseEntry(models.Base):
    """
    Connection settings for a single database alias.

    The struct form of an entry is passed as keyword arguments to
    `mongoengine.register_connection`.
    """

    host = StringField(required=True)
    alias = StringField()


class DatabaseFactory:
    """Register and manage the mongoengine connections listed in `Database`."""

    # Entries registered by `initialize()`. This list is class-level and shared.
    # NOTE(review): calling `initialize()` more than once appends duplicates.
    _entries = []

    @classmethod
    def _create_db_entry(cls, alias: str, settings: dict) -> DatabaseEntry:
        """Build a `DatabaseEntry` for `alias` from its configuration settings."""
        return DatabaseEntry(alias=alias, **settings)

    @classmethod
    def initialize(cls):
        """
        Register a mongoengine connection for every database in `Database`.

        Settings are read from the "hosts.mongo" configuration section, and
        each entry's host is then adjusted by environment overrides:

        - If OVERRIDE_CONNECTION_STRING_ENV_KEY is set, it replaces the host
          entirely.
        - Otherwise, the first non-empty host and/or port variable is patched
          into the configured host URL.

        Raises:
            Exception: if an entry fails validation. The underlying
                `ValidationError` is chained as the cause.
            ValueError: if configuration is missing for one or more databases.
                This is raised only after all configured entries have been
                registered.
        """
        db_entries = config.get("hosts.mongo", {})
        missing = []
        log.info("Initializing database connections")

        override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
        # `first` picks the first truthy value, so unset (None) and empty
        # variables are both skipped.
        override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
        override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)

        if override_connection_string:
            # The connection string may embed credentials, so never log it raw.
            log.info(
                f"Using override mongodb connection string "
                f"{_redact_credentials(override_connection_string)}"
            )
        else:
            if override_hostname:
                log.info(f"Using override mongodb host {override_hostname}")
            if override_port:
                log.info(f"Using override mongodb port {override_port}")

        for key, alias in get_items(Database).items():
            if key not in db_entries:
                missing.append(key)
                continue

            entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))

            if override_connection_string:
                entry.host = override_connection_string
            else:
                if override_hostname:
                    entry.host = furl(entry.host).set(host=override_hostname).url
                if override_port:
                    entry.host = furl(entry.host).set(port=override_port).url

            try:
                entry.validate()
            except ValidationError as ex:
                raise Exception(
                    "Invalid database entry `%s`: %s" % (key, ex.args[0])
                ) from ex

            log.info(
                "Registering connection to %(alias)s (%(host)s)"
                % dict(alias=entry.alias, host=_redact_credentials(entry.host))
            )
            register_connection(**entry.to_struct())
            cls._entries.append(entry)

        if missing:
            raise ValueError(
                "Missing database configuration for %s" % ", ".join(missing)
            )

    @classmethod
    def get_entries(cls):
        """Return the list of registered `DatabaseEntry` objects."""
        return cls._entries

    @classmethod
    def get_hosts(cls):
        """Return the host URL of every registered entry."""
        return [entry.host for entry in cls.get_entries()]

    @classmethod
    def get_aliases(cls):
        """Return the connection alias of every registered entry."""
        return [entry.alias for entry in cls.get_entries()]

    @classmethod
    def reconnect(cls):
        """Drop and re-establish the connection for every registered entry."""
        for entry in cls.get_entries():
            # Workaround: the original author reports that mongoengine's
            # `get_connection(alias, reconnect=True)` does not reconnect
            # correctly. Instead, explicitly disconnect, re-register and
            # reconnect each alias.
            disconnect(entry.alias)
            register_connection(**entry.to_struct())
            get_connection(entry.alias)


db = DatabaseFactory()