diff --git a/apiserver/database/__init__.py b/apiserver/database/__init__.py index d35710e..b962db2 100644 --- a/apiserver/database/__init__.py +++ b/apiserver/database/__init__.py @@ -28,6 +28,8 @@ OVERRIDE_PORT_ENV_KEY = ( "MONGODB_SERVICE_PORT", ) +OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING" + class DatabaseEntry(models.Base): host = StringField(required=True) @@ -47,14 +49,18 @@ class DatabaseFactory: missing = [] log.info("Initializing database connections") + override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY) override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None) - if override_hostname: - log.info(f"Using override mongodb host {override_hostname}") - override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None) - if override_port: - log.info(f"Using override mongodb port {override_port}") + if override_connection_string: + log.info(f"Using override mongodb connection string {override_connection_string}") + else: + if override_hostname: + log.info(f"Using override mongodb host {override_hostname}") + if override_port: + log.info(f"Using override mongodb port {override_port}") + for key, alias in get_items(Database).items(): if key not in db_entries: missing.append(key) @@ -62,11 +68,13 @@ class DatabaseFactory: entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key)) - if override_hostname: - entry.host = furl(entry.host).set(host=override_hostname).url - - if override_port: - entry.host = furl(entry.host).set(port=override_port).url + if override_connection_string: + entry.host = override_connection_string + else: + if override_hostname: + entry.host = furl(entry.host).set(host=override_hostname).url + if override_port: + entry.host = furl(entry.host).set(port=override_port).url try: entry.validate()