2019-10-25 19:28:44 +00:00
|
|
|
from __future__ import print_function, unicode_literals
|
|
|
|
|
|
|
|
import json
|
|
|
|
import logging
|
|
|
|
import os
|
|
|
|
import platform
|
|
|
|
import sys
|
|
|
|
from copy import deepcopy
|
|
|
|
from typing import Any, Callable
|
|
|
|
|
|
|
|
import attr
|
|
|
|
from pathlib2 import Path
|
|
|
|
from pyhocon import ConfigFactory, HOCONConverter, ConfigTree
|
|
|
|
|
|
|
|
from trains_agent.backend_api.session import Session as _Session, Request
|
|
|
|
from trains_agent.backend_api.session.client import APIClient
|
|
|
|
from trains_agent.backend_config.defs import LOCAL_CONFIG_FILE_OVERRIDE_VAR, LOCAL_CONFIG_FILES
|
|
|
|
from trains_agent.definitions import ENVIRONMENT_CONFIG
|
|
|
|
from trains_agent.errors import APIError
|
|
|
|
from trains_agent.helper.base import HOCONEncoder
|
|
|
|
from trains_agent.helper.process import Argv
|
|
|
|
from .version import __version__
|
|
|
|
|
|
|
|
# Identifier string "poetry" (presumably the Poetry package manager); it is not
# referenced in this module's visible code — NOTE(review): confirm its users elsewhere.
POETRY = "poetry"
|
|
|
|
|
|
|
|
|
|
|
|
@attr.s
class ConfigValue(object):
    """
    Manages a single config key

    Binds a configuration tree to one key path inside it (callers in this module
    pass dotted paths such as ``"agent.default_python"``) and offers small
    get / set / modify helpers for that key.
    """

    # configuration tree holding the managed value
    config = attr.ib(type=ConfigTree)
    # key path of the managed value inside ``config``
    key = attr.ib(type=str)

    def get(self, default=None):
        """
        Get value of key with default

        :param default: forwarded as ``default`` to ``ConfigTree.get``
        :return: whatever ``ConfigTree.get`` returns for ``key``
        """
        return self.config.get(self.key, default=default)

    def set(self, value):
        """
        Change the value of key

        :param value: new value, stored with ``ConfigTree.put``
        """
        self.config.put(self.key, value)

    def modify(self, fn):
        # type: (Callable[[Any], Any]) -> None
        """
        Change the value of a key using a function

        :param fn: called with the current value (``get()`` with no default);
            its return value becomes the new value
        """
        self.set(fn(self.get()))
|
|
|
|
|
|
|
|
|
|
|
|
def tree(*args):
    """
    Helper function for creating config trees

    All positional arguments are collected into one tuple and handed to the
    ``ConfigTree`` constructor as its initial contents (presumably
    ``(key, value)`` pairs — verify against callers).
    """
    initial_items = args
    return ConfigTree(initial_items)
|
|
|
|
|
|
|
|
|
|
|
|
class Session(_Session):
    """
    trains-agent session.

    Extends the backend API session. On construction it also prepares the agent
    configuration (default python, worker name, CUDA versions, cache folders)
    and several process environment variables.
    """

    version = __version__

    def __init__(self, *args, **kwargs):
        """
        Create the session and normalize the agent configuration.

        All ``args``/``kwargs`` are forwarded to the base session (with
        ``config_file`` replaced by its absolute, expanded form). This class
        additionally reads these keyword arguments:

        - ``config_file``: configuration file to use. It is exported through
          ``LOCAL_CONFIG_FILE_OVERRIDE_VAR`` so the parent session opens it.
        - ``cpu_only``: hide all GPUs (``CUDA_VISIBLE_DEVICES`` and
          ``NVIDIA_VISIBLE_DEVICES`` are set to ``'none'``) and report CUDA/cuDNN
          versions as ``'0'``.
        - ``gpus``: value exported to both GPU visibility variables.
        - ``only_load_config``: only load the configuration. The base session is
          not initialized and cache folders are not created.
        - ``trace``: stored on ``self.trace`` (default ``False``).

        :raises ValueError: if ``config_file`` is given but is not an existing file
        """
        # make sure we set the environment variable so the parent session opens the correct file
        if kwargs.get('config_file'):
            config_file = Path(os.path.expandvars(kwargs.get('config_file'))).expanduser().absolute().as_posix()
            # validate before touching os.environ, so a bad path leaves no side effects behind
            if not Path(config_file).is_file():
                raise ValueError("Could not open configuration file: {}".format(config_file))
            kwargs['config_file'] = config_file
            os.environ[LOCAL_CONFIG_FILE_OVERRIDE_VAR] = config_file

        cpu_only = kwargs.get('cpu_only')
        if cpu_only:
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = 'none'

        # an explicit gpus selection is applied after (and so overrides) the cpu_only masking
        if kwargs.get('gpus'):
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['NVIDIA_VISIBLE_DEVICES'] = kwargs.get('gpus')

        if kwargs.get('only_load_config'):
            # configuration-only mode: load the config without initializing the base session
            from trains_agent.backend_api.config import load
            self.config = load()
        else:
            super(Session, self).__init__(*args, **kwargs)

        self.log = self.get_logger(__name__)
        self.trace = kwargs.get('trace', False)
        self._config_file = kwargs.get('config_file') or \
            os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR) or LOCAL_CONFIG_FILES[0]
        self.api_client = APIClient(session=self, api_version="2.4")
        # HACK make sure we have python version to execute,
        # if nothing was specific, use the one that runs us
        def_python = ConfigValue(self.config, "agent.default_python")
        if not def_python.get():
            def_python.set("{version.major}.{version.minor}".format(version=sys.version_info))

        # HACK: backwards compatibility
        os.environ['ALG_CONFIG_FILE'] = self._config_file
        os.environ['SM_CONFIG_FILE'] = self._config_file
        if not self.config.get('api.host', None) and self.config.get('api.api_server', None):
            self.config['api']['host'] = self.config.get('api.api_server')

        # initialize nvidia visibility variable
        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        # mirror whichever of the two GPU visibility variables is set into the other one
        if os.environ.get('NVIDIA_VISIBLE_DEVICES') and not os.environ.get('CUDA_VISIBLE_DEVICES'):
            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('NVIDIA_VISIBLE_DEVICES')
        elif os.environ.get('CUDA_VISIBLE_DEVICES') and not os.environ.get('NVIDIA_VISIBLE_DEVICES'):
            os.environ['NVIDIA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES')

        # override with environment variables
        # cuda_version & cudnn_version are overridden with os.environ here, and normalized in the next section
        for config_key, env_config in ENVIRONMENT_CONFIG.items():
            value = env_config.get()
            if not value:
                continue
            env_key = ConfigValue(self.config, config_key)
            env_key.set(value)

        # initialize cuda versions
        try:
            from trains_agent.helper.package.requirements import RequirementsManager
            agent = self.config['agent']
            agent['cuda_version'], agent['cudnn_version'] = \
                RequirementsManager.get_cuda_version(self.config) if not cpu_only else ('0', '0')
        except Exception:
            # best effort: a failed CUDA probe must not prevent session creation,
            # but leave a trace for debugging instead of failing silently
            self.log.debug("Failed detecting CUDA/cuDNN versions", exc_info=True)

        # initialize worker name
        worker_name = ConfigValue(self.config, "agent.worker_name")
        if not worker_name.get():
            worker_name.set(platform.node())

        if not kwargs.get('only_load_config'):
            self.create_cache_folders()

    @staticmethod
    def get_logger(name):
        """
        Return a ``TrainsAgentLogger`` wrapping ``logging.getLogger(name)``,
        with record propagation to parent loggers enabled.
        """
        logger = logging.getLogger(name)
        logger.propagate = True
        return TrainsAgentLogger(logger)

    @property
    def debug_mode(self):
        """Value of ``agent.debug`` in the configuration (``False`` when unset)."""
        return self.config.get("agent.debug", False)

    @property
    def config_file(self):
        """Configuration file path resolved during ``__init__``."""
        return self._config_file

    def create_cache_folders(self, slot_index=0):
        """
        create and update the cache folders
        notice we support multiple instances sharing the same cache on some folders
        and on some we use "instance slot" numbers in order to differentiate between the different instances running
        notice slot_index=0 is the default, meaning no suffix is added to the singleton_folders

        Note: do not call this function twice with non zero slot_index
              it will add a suffix to the folders on each call

        Each configured folder is rewritten in the configuration as an absolute
        posix path. Creation is best effort: a folder that cannot be created
        is reported with a warning and skipped.

        :param slot_index: integer
        """
        # create target folders:
        folder_keys = ('agent.venvs_dir', 'agent.vcs_cache.path',
                       'agent.pip_download_cache.path',
                       'agent.docker_pip_cache', 'agent.docker_apt_cache')
        singleton_folders = ('agent.venvs_dir', 'agent.vcs_cache.path',)

        for key in folder_keys:
            folder_key = ConfigValue(self.config, key)
            if not folder_key.get():
                continue

            if slot_index and key in singleton_folders:
                f = folder_key.get()
                if f.endswith(os.path.sep):
                    f = f[:-1]
                folder_key.set(f + '.{}'.format(slot_index))

            # update the configuration for full path
            folder = Path(os.path.expandvars(folder_key.get())).expanduser().absolute()
            folder_key.set(folder.as_posix())
            try:
                folder.mkdir(parents=True, exist_ok=True)
            except Exception as ex:
                # best effort (as before), but no longer a bare except: interrupts
                # propagate and the failure is visible in the log
                self.log.warning("Could not create folder %s: %s", folder.as_posix(), ex)

    def print_configuration(self, remove_secret_keys=("secret", "pass", "token", "account_key")):
        """
        Print the current configuration in "properties" format.

        The ``env`` section is dropped. Every key whose name contains one of
        ``remove_secret_keys`` is removed recursively, also inside dicts nested
        in lists. The match is case-insensitive.

        :param remove_secret_keys: substrings identifying secret keys; falsy disables filtering
        """
        # remove all the secrets from the print
        def recursive_remove_secrets(dictionary, secret_keys=()):
            for k in list(dictionary):
                # case-insensitive, so keys such as "Password" or "SECRET_KEY" are hidden too
                key_lower = k.lower()
                for s in secret_keys:
                    if s.lower() in key_lower:
                        dictionary.pop(k)
                        break
                if isinstance(dictionary.get(k, None), dict):
                    recursive_remove_secrets(dictionary[k], secret_keys=secret_keys)
                elif isinstance(dictionary.get(k, None), (list, tuple)):
                    for item in dictionary[k]:
                        if isinstance(item, dict):
                            recursive_remove_secrets(item, secret_keys=secret_keys)

        config = deepcopy(self.config.to_dict())
        # remove the env variable, it's not important
        config.pop('env', None)
        if remove_secret_keys:
            recursive_remove_secrets(config, secret_keys=remove_secret_keys)

        config = ConfigFactory.from_dict(config)
        self.log.debug("Run by interpreter: %s", sys.executable)
        print(
            "Current configuration (trains_agent v{}, location: {}):\n"
            "----------------------\n{}\n".format(
                self.version, self._config_file, HOCONConverter.convert(config, "properties")
            )
        )

    def send_api(self, request):
        # type: (Request) -> Any
        """
        Send ``request`` and return its response.

        :raises APIError: if the call result is not ok, or the response is empty
        """
        result = self.send(request)
        if not result.ok():
            raise APIError(result)
        if not result.response:
            raise APIError(result, extra_info="Invalid response")
        return result.response

    def get(self, service, action, version=None, headers=None,
            data=None, json=None, async_enable=False, **kwargs):
        """GET ``service.action``; extra keyword arguments form the JSON body when ``json`` is falsy."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="get", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def post(self, service, action, version=None, headers=None,
             data=None, json=None, async_enable=False, **kwargs):
        """POST ``service.action``; extra keyword arguments form the JSON body when ``json`` is falsy."""
        return self._manual_request(service=service, action=action,
                                    version=version, method="post", headers=headers,
                                    data=data, async_enable=async_enable,
                                    json=json or kwargs)

    def _manual_request(self, service, action, version=None, method="get", headers=None,
                        data=None, json=None, async_enable=False, **kwargs):
        """
        Send a raw request and return the ``data`` field of its JSON reply.

        :raises APIError: if the reply is not JSON, lacks ``meta.result_code``,
            has a result code other than 200, or lacks ``data``
        """
        res = self.send_request(service=service, action=action,
                                version=version, method=method, headers=headers,
                                data=data, async_enable=async_enable,
                                json=json or kwargs)

        try:
            res_json = res.json()
            return_code = res_json["meta"]["result_code"]
        except (ValueError, KeyError, TypeError):
            raise APIError(res)

        # check return code
        if return_code != 200:
            raise APIError(res)

        try:
            return res_json["data"]
        except KeyError:
            # a "successful" reply without a payload is still malformed; report it
            # as an API error like the other malformed-reply cases above
            raise APIError(res)

    def to_json(self):
        """Serialize the configuration as 4-space indented JSON (via ``HOCONEncoder``)."""
        return json.dumps(
            self.config.as_plain_ordered_dict(), cls=HOCONEncoder, indent=4
        )

    def command(self, *args):
        """Build an ``Argv`` from ``args`` with a logger named after ``Argv``'s module."""
        return Argv(*args, log=self.get_logger(Argv.__module__))
|
|
|
|
|
|
|
|
|
|
|
|
@attr.s
class TrainsAgentLogger(object):
    """
    Proxy around logging.Logger because inheriting from it is difficult.

    ``warning`` and ``error`` attach exception info automatically when the
    wrapped logger is enabled for DEBUG. Every other attribute lookup that
    fails on the proxy is delegated to the wrapped logger.
    """

    # the wrapped standard-library logger
    logger = attr.ib(type=logging.Logger)

    def _log_with_error(self, level, *args, **kwargs):
        """
        Include error information when in debug mode

        An ``exc_info`` explicitly passed by the caller is left untouched.
        """
        kwargs.setdefault("exc_info", self.logger.isEnabledFor(logging.DEBUG))
        return self.logger.log(level, *args, **kwargs)

    def warning(self, *args, **kwargs):
        """Log at WARNING level, with exception info when DEBUG is enabled."""
        return self._log_with_error(logging.WARNING, *args, **kwargs)

    def error(self, *args, **kwargs):
        """Log at ERROR level, with exception info when DEBUG is enabled."""
        return self._log_with_error(logging.ERROR, *args, **kwargs)

    def __getattr__(self, item):
        """
        Delegate attributes not found on the proxy to the wrapped logger.

        ``__getattr__`` only runs after normal lookup fails. If ``logger``
        itself is missing, for example on an instance created without
        ``__init__`` as ``copy``/``pickle`` do, evaluating ``self.logger`` here
        would re-enter ``__getattr__`` forever. Raise AttributeError instead.
        """
        if item == "logger":
            raise AttributeError(item)
        return getattr(self.logger, item)

    def __call__(self, *args, **kwargs):
        """
        Compatibility with old ``Command.log()`` method

        Logs at INFO level.
        """
        return self.logger.info(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_cuda_version(value):
    # type: (Any) -> str
    """
    Take variably formatted cuda version string/number and return it in the same format:
    string decimal representation of 10 * major + minor.

    Values without a dot are returned as their string form unchanged. Values
    whose first two dotted components cannot be parsed as a number are
    returned unchanged as well.

    >>> normalize_cuda_version(100)
    '100'
    >>> normalize_cuda_version("100")
    '100'
    >>> normalize_cuda_version(10)
    '10'
    >>> normalize_cuda_version(10.0)
    '100'
    >>> normalize_cuda_version("10.0")
    '100'
    >>> normalize_cuda_version("10.0.130")
    '100'
    """
    text = str(value)
    # already in the "compact" form (or not a dotted version at all)
    if "." not in text:
        return text

    # keep only "major.minor", dropping patch/build components such as ".130"
    major_minor = ".".join(text.split(".")[:2])
    try:
        return str(int(float(major_minor) * 10))
    except (ValueError, TypeError):
        return text
|