mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 02:33:02 +00:00
Refactor database into a separate class
This commit is contained in:
parent
c7bbac73d0
commit
b8e62f27e2
@ -4,7 +4,7 @@ from hashlib import md5
|
||||
from flask import Flask
|
||||
from semantic_version import Version
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.database import database
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
from apiserver.config import config, info
|
||||
from apiserver.elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
|
||||
|
@ -6,7 +6,7 @@ from jsonmodels import models
|
||||
from jsonmodels.errors import ValidationError
|
||||
from jsonmodels.fields import StringField
|
||||
from mongoengine import register_connection
|
||||
from mongoengine.connection import get_connection
|
||||
from mongoengine.connection import get_connection, disconnect
|
||||
|
||||
from apiserver.config import config
|
||||
from .defs import Database
|
||||
@ -23,70 +23,76 @@ OVERRIDE_HOST_ENV_KEY = (
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
|
||||
_entries = []
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
alias = StringField()
|
||||
|
||||
@property
|
||||
def health_alias(self):
|
||||
return "__health__" + self.alias
|
||||
|
||||
class DatabaseFactory:
|
||||
_entries = []
|
||||
|
||||
def initialize():
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
@classmethod
|
||||
def initialize(cls):
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
_entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
|
||||
cls._entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
|
||||
@classmethod
|
||||
def get_entries(cls):
|
||||
return cls._entries
|
||||
|
||||
@classmethod
|
||||
def get_hosts(cls):
|
||||
return [entry.host for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def get_aliases(cls):
|
||||
return [entry.alias for entry in cls.get_entries()]
|
||||
|
||||
@classmethod
|
||||
def reconnect(cls):
|
||||
for entry in cls.get_entries():
|
||||
# there is bug in the current implementation that prevents
|
||||
# reconnection from work so workaround this
|
||||
# get_connection(entry.alias, reconnect=True)
|
||||
disconnect(entry.alias)
|
||||
register_connection(alias=entry.alias, host=entry.host)
|
||||
get_connection(entry.alias)
|
||||
|
||||
|
||||
def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
def reconnect():
|
||||
for entry in get_entries():
|
||||
get_connection(entry.alias, reconnect=True)
|
||||
database = DatabaseFactory()
|
||||
|
@ -6,8 +6,8 @@ from humanfriendly import parse_timespan
|
||||
|
||||
|
||||
def setup():
|
||||
from apiserver.database import initialize
|
||||
initialize()
|
||||
from apiserver.database import database
|
||||
database.initialize()
|
||||
|
||||
|
||||
def gen_token(args):
|
||||
|
Loading…
Reference in New Issue
Block a user