from __future__ import print_function, unicode_literals

import json
import logging
import os
import platform
import sys
from copy import deepcopy
from typing import Any, Callable

import attr
from pathlib2 import Path
from pyhocon import ConfigFactory, HOCONConverter, ConfigTree

from trains_agent.backend_api.session import Session as _Session, Request
from trains_agent.backend_api.session.client import APIClient
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES
from trains_agent.definitions import ENVIRONMENT_CONFIG
from trains_agent.errors import APIError
from trains_agent.helper.base import HOCONEncoder
from trains_agent.helper.process import Argv
from .version import __version__

POETRY = "poetry"


@attr.s
class ConfigValue(object):
    """
    Manages a single config key

    Thin accessor binding a ``ConfigTree`` to one key path, so the key
    can be read, written or transformed without repeating the path.
    """

    config = attr.ib(type=ConfigTree)
    key = attr.ib(type=str)

    def get(self, default=None):
        """
        Get value of key with default

        :param default: value returned by the config lookup when the key is missing
        """
        return self.config.get(self.key, default=default)

    def set(self, value):
        """
        Change the value of key

        :param value: new value stored under the key
        """
        self.config.put(self.key, value)

    def modify(self, fn):
        # type: (Callable[[Any], Any]) -> ()
        """
        Change the value of a key using a function

        :param fn: called with the current value (``None`` if unset); its result is stored
        """
        self.set(fn(self.get()))


def tree(*args):
    """
    Helper function for creating config trees

    The positional arguments are passed, as one tuple, to the ``ConfigTree`` constructor.
    """
    return ConfigTree(args)


class Session(_Session):
    """
    Agent session: extends the backend API session with configuration-file handling,
    environment-variable setup and a few request helpers.
    """

    version = __version__

    def __init__(self, *args, **kwargs):
        """
        Load the configuration and prepare the process environment.

        Keyword arguments read here (all are also forwarded to the parent constructor
        when it is called):

        :param config_file: optional path of the configuration file; environment variables
            and ``~`` are expanded and the path is made absolute
        :param only_load_config: if truthy, only load the configuration (the parent
            constructor is not called and cache folders are not created)
        :param trace: stored as ``self.trace`` (default False)
        :raises ValueError: if ``config_file`` is given but is not an existing file
        """
        # make sure we set the environment variable so the parent session opens the correct file
        if kwargs.get('config_file'):
            config_file = Path(os.path.expandvars(kwargs.get('config_file'))).expanduser().absolute().as_posix()
            kwargs['config_file'] = config_file
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file
            if not Path(config_file).is_file():
                raise ValueError("Could not open configuration file: {}".format(config_file))

        if kwargs.get('only_load_config'):
            from trains_agent.backend_api.config import load
            self.config = load()
        else:
            super(Session, self).__init__(*args, **kwargs)

        self.log = self.get_logger(__name__)
        self.trace = kwargs.get('trace', False)
        # explicit argument first, then the override environment variable, then the first default location
        self._config_file = (
            kwargs.get('config_file')
            or os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR)
            or LOCAL_CONFIG_FILES[0]
        )
        self.api_client = APIClient(session=self, api_version="2.4")

        # HACK make sure we have python version to execute,
        # if nothing was specific, use the one that runs us
        def_python = ConfigValue(self.config, "agent.default_python")
        if not def_python.get():
            def_python.set("{version.major}.{version.minor}".format(version=sys.version_info))

        # HACK: backwards compatibility
        os.environ['ALG_CONFIG_FILE'] = self._config_file
        os.environ['SM_CONFIG_FILE'] = self._config_file
        if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
            self.config['api']['host'] = self.config.get('api.api_server')

        # initialize nvidia visibility variable
        # (whichever of the two variables is set is mirrored into the one that is not)
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
        elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
            os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')

        # override with environment variables
        # cuda_version & cudnn_version are overridden with os.environ here, and normalized in the next section
        for config_key, env_config in ENVIRONMENT_CONFIG.items():
            value = env_config.get()
            if not value:
                continue
            env_key = ConfigValue(self.config, config_key)
            env_key.set(value)

        # initialize cuda versions
        # (deliberately best effort: any failure leaves the configuration untouched)
        try:
            from trains_agent.helper.package.requirements import RequirementsManager
            agent = self.config['agent']
            agent['cuda_version'], agent['cudnn_version'] = \
                RequirementsManager.get_cuda_version(self.config)
        except Exception:
            pass

        # initialize worker name
        worker_name = ConfigValue(self.config, "agent.worker_name")
        if not worker_name.get():
            worker_name.set(platform.node())

        if not kwargs.get('only_load_config'):
            self.create_cache_folders()

    @staticmethod
    def get_logger(name):
        """
        Return a ``TrainsAgentLogger`` proxy around the standard logger called ``name``.

        Propagation is switched on for the underlying logger.
        """
        logger = logging.getLogger(name)
        logger.propagate = True
        return TrainsAgentLogger(logger)

    @property
    def debug_mode(self):
        """Value of the ``agent.debug`` configuration key (False when unset)."""
        return self.config.get("agent.debug", False)

    @property
    def config_file(self):
        """Path of the configuration file resolved in the constructor."""
        return self._config_file

    def create_cache_folders(self, slot_index=0):
        """
        create and update the cache folders
        notice we support multiple instances sharing the same cache on some folders
        and on some we use "instance slot" numbers in order to differentiate between the different instances running
        notice slot_index=0 is the default, meaning no suffix is added to the singleton_folders

        Note: do not call this function twice with non zero slot_index
            it will add a suffix to the folders on each call

        :param slot_index: integer
        """
        # create target folders:
        folder_keys = ('agent.venvs_dir', 'agent.vcs_cache.path',
                       'agent.pip_download_cache.path',
                       'agent.docker_pip_cache', 'agent.docker_apt_cache')
        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path',)

        for key in folder_keys:
            folder_key = ConfigValue(self.config, key)
            if not folder_key.get():
                continue

            if slot_index and key in singleton_folders:
                f = folder_key.get()
                # drop one trailing separator so the suffix is attached to the folder name
                if f.endswith(os.path.sep):
                    f = f[:-1]
                folder_key.set(f + '.{}'.format(slot_index))

            # update the configuration for full path
            folder = Path(os.path.expandvars(folder_key.get())).expanduser().absolute()
            folder_key.set(folder.as_posix())
            try:
                folder.mkdir(parents=True, exist_ok=True)
            except Exception:
                # best effort: a folder that cannot be created is not fatal here.
                # Unlike a bare ``except:``, KeyboardInterrupt / SystemExit still propagate.
                pass

    def print_configuration(self, remove_secret_keys=("secret", "pass", "token", "account_key")):
        """
        Print the current configuration to stdout, without secrets.

        Works on a deep copy of the configuration; the top-level ``env`` entry is dropped
        and every key containing one of ``remove_secret_keys`` as a substring is removed
        (recursively, including dictionaries nested inside lists/tuples).

        :param remove_secret_keys: substrings marking a key as secret; pass an empty/falsy
            value to print everything
        """
        # remove all the secrets from the print
        def recursive_remove_secrets(dictionary, secret_keys=()):
            # iterate over a snapshot of the keys, since entries are removed while looping
            for k in list(dictionary):
                if any(s in k for s in secret_keys):
                    dictionary.pop(k)
                    continue
                value = dictionary[k]
                if isinstance(value, dict):
                    recursive_remove_secrets(value, secret_keys=secret_keys)
                elif isinstance(value, (list, tuple)):
                    for item in value:
                        if isinstance(item, dict):
                            recursive_remove_secrets(item, secret_keys=secret_keys)

        config = deepcopy(self.config.to_dict())
        # remove the env variable, it's not important
        config.pop('env', None)
        if remove_secret_keys:
            recursive_remove_secrets(config, secret_keys=remove_secret_keys)
        config = ConfigFactory.from_dict(config)
        self.log.debug("Run by interpreter: %s", sys.executable)
        print(
            "Current configuration (trains_agent v{}, location: {}):\n"
            "----------------------\n{}\n".format(
                self.version, self._config_file, HOCONConverter.convert(config, "properties")
            )
        )

    def send_api(self, request):
        # type: (Request) -> Any
        """
        Send ``request`` and return the response object.

        :raises APIError: if the result is not ok, or carries no response
        """
        result = self.send(request)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result.response

    def get(self, service, action, version=None, headers=None,
            data=None, json=None, async_enable=False, **kwargs):
        """
        Issue a manual "get" request; extra keyword arguments are used as the json
        payload when ``json`` is empty. See ``_manual_request``.
        """
        return self._manual_request(service=service, action=action,
                                    version=version, method="get", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def post(self, service, action, version=None, headers=None,
             data=None, json=None, async_enable=False, **kwargs):
        """
        Issue a manual "post" request; extra keyword arguments are used as the json
        payload when ``json`` is empty. See ``_manual_request``.
        """
        return self._manual_request(service=service, action=action,
                                    version=version, method="post", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def _manual_request(self, service, action, version=None, method="get", headers=None,
                        data=None, json=None, async_enable=False, **kwargs):
        """
        Send a raw request and return the ``data`` member of the json reply.

        :raises APIError: if the reply is not valid json, lacks ``meta.result_code``,
            or the result code is not 200
        """
        res = self.send_request(service=service, action=action,
                                version=version, method=method, headers=headers,
                                data=data, async_enable=async_enable,
                                json=json or kwargs)

        try:
            res_json = res.json()
            return_code = res_json["meta"]["result_code"]
        except (ValueError, KeyError, TypeError):
            raise APIError(res)

        # check return code
        if return_code != 200:
            raise APIError(res)

        return res_json["data"]

    def to_json(self):
        """Serialize the configuration to a json string (4-space indent, ``HOCONEncoder``)."""
        return json.dumps(
            self.config.as_plain_ordered_dict(), cls=HOCONEncoder, indent=4
        )

    def command(self, *args):
        """Build an ``Argv`` from ``args``, logging through a logger named after Argv's module."""
        return Argv(*args, log=self.get_logger(Argv.__module__))


@attr.s
class TrainsAgentLogger(object):
    """
    Proxy around logging.Logger because inheriting from it is difficult.
    """

    logger = attr.ib(type=logging.Logger)

    def _log_with_error(self, level, *args, **kwargs):
        """
        Include error information when in debug mode
        """
        # an explicit exc_info passed by the caller wins over the debug-level default
        kwargs.setdefault("exc_info", self.logger.isEnabledFor(logging.DEBUG))
        return self.logger.log(level, *args, **kwargs)

    def warning(self, *args, **kwargs):
        """Log at WARNING level (see ``_log_with_error``)."""
        return self._log_with_error(logging.WARNING, *args, **kwargs)

    def error(self, *args, **kwargs):
        """Log at ERROR level (see ``_log_with_error``)."""
        return self._log_with_error(logging.ERROR, *args, **kwargs)

    def __getattr__(self, item):
        """
        Delegate every attribute not found on the proxy to the wrapped logger.
        """
        # __getattr__ is only reached when normal lookup failed; if ``logger`` itself is
        # missing (instance created without __init__, e.g. by copy/pickle), delegating
        # would recurse forever, so fail with a regular AttributeError instead.
        if item == "logger":
            raise AttributeError(item)
        return getattr(self.logger, item)

    def __call__(self, *args, **kwargs):
        """
        Compatibility with old ``Command.log()`` method
        """
        return self.logger.info(*args, **kwargs)


def normalize_cuda_version(value):
    # type: (Any) -> str
    """
    Take variably formatted cuda version string/number and return it in the same format:
    string decimal representation of 10 * major + minor.

    Values without a "." are returned as their string form unchanged; values with a "."
    that cannot be parsed as a number are also returned unchanged (as a string).

    >>> normalize_cuda_version(100)
    '100'
    >>> normalize_cuda_version("100")
    '100'
    >>> normalize_cuda_version(10)
    '10'
    >>> normalize_cuda_version(10.0)
    '100'
    >>> normalize_cuda_version("10.0")
    '100'
    >>> normalize_cuda_version("10.0.130")
    '100'
    """
    value = str(value)
    if "." in value:
        try:
            # keep only "major.minor", then scale by 10 and truncate to an integer
            value = str(int(float(".".join(value.split(".")[:2])) * 10))
        except (ValueError, TypeError):
            pass
    return value