From a6a0b01f7164bbf1e9581abd652557c767184618 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 1 Apr 2020 19:11:37 +0300 Subject: [PATCH] Remove deprecated OS environment variables --- trains_agent/backend_api/session/session.py | 3 +++ trains_agent/backend_config/environment.py | 28 +++++++++++++++++++++ trains_agent/commands/worker.py | 2 +- trains_agent/definitions.py | 28 ++++++++++----------- trains_agent/session.py | 8 +++--- 5 files changed, 51 insertions(+), 18 deletions(-) diff --git a/trains_agent/backend_api/session/session.py b/trains_agent/backend_api/session/session.py index 34ceb45..05b464a 100644 --- a/trains_agent/backend_api/session/session.py +++ b/trains_agent/backend_api/session/session.py @@ -16,6 +16,7 @@ from .request import Request, BatchRequest from .token_manager import TokenManager from ..config import load from ..utils import get_http_session_with_retry, urllib_log_warning_setup +from ...backend_config.environment import backward_compatibility_support from ...version import __version__ @@ -86,6 +87,8 @@ class Session(TokenManager): config=None, **kwargs ): + # add backward compatibility support for old environment variables + backward_compatibility_support() if config is not None: self.config = config diff --git a/trains_agent/backend_config/environment.py b/trains_agent/backend_config/environment.py index 30ca80b..57b1eca 100644 --- a/trains_agent/backend_config/environment.py +++ b/trains_agent/backend_config/environment.py @@ -23,3 +23,31 @@ class EnvEntry(Entry): def error(self, message): print("Environment configuration: {}".format(message)) + + +def backward_compatibility_support(): + from ..definitions import ENVIRONMENT_CONFIG, ENVIRONMENT_SDK_PARAMS, ENVIRONMENT_BACKWARD_COMPATIBLE + if not ENVIRONMENT_BACKWARD_COMPATIBLE.get(): + return + + # Add ALG_ prefix on every TRAINS_ os environment we support + for k, v in ENVIRONMENT_CONFIG.items(): + try: + trains_vars = [var for var in v.vars if var.startswith('TRAINS_')] + if not trains_vars: + continue + alg_var = trains_vars[0].replace('TRAINS_', 'ALG_', 1) + if alg_var not in v.vars: + v.vars = tuple(list(v.vars) + [alg_var]) + except: + continue + for k, v in ENVIRONMENT_SDK_PARAMS.items(): + try: + trains_vars = [var for var in v if var.startswith('TRAINS_')] + if not trains_vars: + continue + alg_var = trains_vars[0].replace('TRAINS_', 'ALG_', 1) + if alg_var not in v: + ENVIRONMENT_SDK_PARAMS[k] = tuple(list(v) + [alg_var]) + except: + continue diff --git a/trains_agent/commands/worker.py b/trains_agent/commands/worker.py index b460f65..a7fbe2a 100644 --- a/trains_agent/commands/worker.py +++ b/trains_agent/commands/worker.py @@ -1817,7 +1817,7 @@ class Worker(ServiceCommandSection): args.update(kwargs) return self._get_docker_cmd(**args) - docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or os.environ.get("ALG_DOCKER_IMAGE") or + docker_image = str(os.environ.get("TRAINS_DOCKER_IMAGE") or self._session.config.get("agent.default_docker.image", "nvidia/cuda")) \ if not docker_args else docker_args[0] docker_arguments = docker_image.split(' ') if docker_image else [] diff --git a/trains_agent/definitions.py b/trains_agent/definitions.py index 11b8e08..8f29832 100644 --- a/trains_agent/definitions.py +++ b/trains_agent/definitions.py @@ -55,23 +55,23 @@ class EnvironmentConfig(object): ENVIRONMENT_CONFIG = { - "api.api_server": EnvironmentConfig("TRAINS_API_HOST", "ALG_API_HOST"), + "api.api_server": EnvironmentConfig("TRAINS_API_HOST", ), "api.credentials.access_key": EnvironmentConfig( - "TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY" + "TRAINS_API_ACCESS_KEY", ), "api.credentials.secret_key": EnvironmentConfig( - "TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY" + "TRAINS_API_SECRET_KEY", ), - "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", "ALG_WORKER_NAME"), - "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", "ALG_WORKER_ID"), + "agent.worker_name": EnvironmentConfig("TRAINS_WORKER_NAME", ), + "agent.worker_id": EnvironmentConfig("TRAINS_WORKER_ID", ), "agent.cuda_version": EnvironmentConfig( - "TRAINS_CUDA_VERSION", "ALG_CUDA_VERSION", "CUDA_VERSION" + "TRAINS_CUDA_VERSION", "CUDA_VERSION" ), "agent.cudnn_version": EnvironmentConfig( - "TRAINS_CUDNN_VERSION", "ALG_CUDNN_VERSION", "CUDNN_VERSION" + "TRAINS_CUDNN_VERSION", "CUDNN_VERSION" ), "agent.cpu_only": EnvironmentConfig( - "TRAINS_CPU_ONLY", "ALG_CPU_ONLY", "CPU_ONLY", type=bool + "TRAINS_CPU_ONLY", "CPU_ONLY", type=bool ), "sdk.aws.s3.key": EnvironmentConfig("AWS_ACCESS_KEY_ID"), "sdk.aws.s3.secret": EnvironmentConfig("AWS_SECRET_ACCESS_KEY"), @@ -81,15 +81,15 @@ ENVIRONMENT_CONFIG = { "sdk.google.storage.credentials_json": EnvironmentConfig("GOOGLE_APPLICATION_CREDENTIALS"), } -CONFIG_FILE_ENV = EnvironmentConfig("ALG_CONFIG_FILE") - ENVIRONMENT_SDK_PARAMS = { - "task_id": ("TRAINS_TASK_ID", "ALG_TASK_ID"), - "config_file": ("TRAINS_CONFIG_FILE", "ALG_CONFIG_FILE", "TRAINS_CONFIG_FILE"), - "log_level": ("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL"), - "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND"), + "task_id": ("TRAINS_TASK_ID", ), + "config_file": ("TRAINS_CONFIG_FILE", ), + "log_level": ("TRAINS_LOG_LEVEL", ), + "log_to_backend": ("TRAINS_LOG_TASK_TO_BACKEND", ), } +ENVIRONMENT_BACKWARD_COMPATIBLE = EnvironmentConfig("TRAINS_AGENT_ALG_ENV", type=bool) + VIRTUAL_ENVIRONMENT_PATH = { "python2": normalize_path(CONFIG_DIR, "py2venv"), "python3": normalize_path(CONFIG_DIR, "py3venv"), diff --git a/trains_agent/session.py b/trains_agent/session.py index 19ce98e..44164ab 100644 --- a/trains_agent/session.py +++ b/trains_agent/session.py @@ -15,7 +15,7 @@ from pyhocon import ConfigFactory, HOCONConverter, ConfigTree from trains_agent.backend_api.session import Session as _Session, Request from trains_agent.backend_api.session.client import APIClient from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES -from trains_agent.definitions import ENVIRONMENT_CONFIG, ENV_TASK_EXECUTE_AS_USER +from trains_agent.definitions import ENVIRONMENT_CONFIG, ENV_TASK_EXECUTE_AS_USER, ENVIRONMENT_BACKWARD_COMPATIBLE from trains_agent.errors import APIError from trains_agent.helper.base import HOCONEncoder from trains_agent.helper.process import Argv @@ -95,8 +95,10 @@ class Session(_Session): def_python.set("{version.major}.{version.minor}".format(version=sys.version_info)) # HACK: backwards compatibility - os.environ['ALG_CONFIG_FILE'] = self._config_file - os.environ['SM_CONFIG_FILE'] = self._config_file + if ENVIRONMENT_BACKWARD_COMPATIBLE.get(): + os.environ['ALG_CONFIG_FILE'] = self._config_file + os.environ['SM_CONFIG_FILE'] = self._config_file + if not self.config.get('api.host', None) and self.config.get('api.api_server', None): self.config['api']['host'] = self.config.get('api.api_server')