mirror of
https://github.com/clearml/clearml-agent
synced 2025-06-26 18:16:15 +00:00
Rename trains-agent -> clearml-agent
This commit is contained in:
596
clearml_agent/helper/base.py
Normal file
596
clearml_agent/helper/base.py
Normal file
@@ -0,0 +1,596 @@
|
||||
""" TRAINS-AGENT Stdout Helper Functions """
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from abc import ABCMeta
|
||||
from collections import OrderedDict
|
||||
from distutils.spawn import find_executable
|
||||
from functools import total_ordering
|
||||
from typing import Text, Dict, Any, Optional, AnyStr, IO, Union
|
||||
|
||||
import attr
|
||||
import furl
|
||||
import pyhocon
|
||||
import yaml
|
||||
from attr import fields_dict
|
||||
from pathlib2 import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
import six
|
||||
from six.moves import reduce
|
||||
from trains_agent.errors import CommandFailedError
|
||||
from trains_agent.helper.dicts import filter_keys
|
||||
|
||||
pretty_lines = False
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def which(cmd, path=None):
|
||||
result = find_executable(cmd, path)
|
||||
if not result:
|
||||
raise ValueError('command "{}" not found'.format(cmd))
|
||||
return result
|
||||
|
||||
|
||||
def select_for_platform(linux, windows):
|
||||
"""
|
||||
Select between multiple values according to the OS
|
||||
:param linux: value to return if OS is linux
|
||||
:param windows: value to return if OS is Windows
|
||||
"""
|
||||
return windows if is_windows_platform() else linux
|
||||
|
||||
|
||||
def bash_c():
|
||||
return 'bash -c' if not is_windows_platform() else 'cmd /c'
|
||||
|
||||
|
||||
def return_list(arg):
|
||||
if arg and not isinstance(arg, (tuple, list)):
|
||||
return [arg]
|
||||
|
||||
return arg
|
||||
|
||||
|
||||
def print_table(entries, columns=(), titles=(), csv=None, headers=True):
|
||||
table = create_table(entries, columns=columns, titles=titles, csv=csv, headers=headers)
|
||||
if csv:
|
||||
with open(csv, 'w') as output:
|
||||
print(table, file=output)
|
||||
else:
|
||||
print(table)
|
||||
|
||||
|
||||
def create_table(entries, columns=(), titles=(), csv=None, headers=True):
|
||||
table = [
|
||||
[
|
||||
reduce(
|
||||
lambda obj, key: obj.get(key, {}),
|
||||
column.split('.'),
|
||||
entry
|
||||
) or ''
|
||||
for column in columns
|
||||
]
|
||||
for entry in entries
|
||||
]
|
||||
if headers:
|
||||
headers = [titles[i] if i < len(titles) and titles[i] else c for i, c in enumerate(columns)]
|
||||
else:
|
||||
headers = []
|
||||
output = ''
|
||||
if csv:
|
||||
if headers:
|
||||
output += ','.join(headers) + '\n'
|
||||
for entry in table:
|
||||
output += ','.join(map(str, entry)) + '\n'
|
||||
else:
|
||||
min_col_width = 3
|
||||
col_widths = [max(min_col_width, len(h)+1) for h in (headers or table[0])]
|
||||
for e in table:
|
||||
col_widths = list(map(max, zip(col_widths, [len(h)+1 for h in e])))
|
||||
|
||||
output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
|
||||
if headers:
|
||||
output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*headers) + ' |' + '\n'
|
||||
|
||||
output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
|
||||
|
||||
for entry in table:
|
||||
line = map(str, entry)
|
||||
output += '| ' + '| '.join(['{: <%d}' % c for c in col_widths]).format(*line) + ' |' + '\n'
|
||||
|
||||
output += '+-' + '+-'.join(['-' * c for c in col_widths]) + '-+' + '\n'
|
||||
return output
|
||||
|
||||
|
||||
def create_tree(entries, id='id', parent='parent', node_title='%(id)'):
|
||||
tree = OrderedDict()
|
||||
all_nodes = dict()
|
||||
for t in entries:
|
||||
i = t.get(id, None)
|
||||
p = t.get(parent, None)
|
||||
if not p and i not in tree:
|
||||
# push roots
|
||||
myd = all_nodes.get(i, OrderedDict())
|
||||
# add node title
|
||||
tree[node_title % t] = myd
|
||||
all_nodes[i] = myd
|
||||
elif p:
|
||||
# update parent dictionary
|
||||
d = all_nodes.get(p, OrderedDict())
|
||||
# get node dictionary
|
||||
myd = all_nodes.get(i, OrderedDict())
|
||||
# add node title
|
||||
d[node_title % t] = myd
|
||||
all_nodes[p] = d
|
||||
all_nodes[i] = myd
|
||||
else:
|
||||
pass
|
||||
return {'': tree}
|
||||
|
||||
|
||||
def print_parameters(param_struct, indent=1):
|
||||
text = yaml.safe_dump(param_struct, allow_unicode=True, indent=indent, default_flow_style=False)
|
||||
print(text)
|
||||
|
||||
|
||||
def get_list_files(basefolder, filext=('.jpg')):
|
||||
filext = [e.lower() for e in filext]
|
||||
fileiter = (os.path.join(root, f)
|
||||
for root, _, files in os.walk(basefolder)
|
||||
for f in files if os.path.splitext(f)[1].lower() in filext)
|
||||
return fileiter
|
||||
|
||||
|
||||
def is_windows_platform():
|
||||
return any(platform.win32_ver())
|
||||
|
||||
|
||||
def is_linux_platform():
|
||||
return 'linux' in platform.system().lower()
|
||||
|
||||
|
||||
def normalize_path(*paths):
|
||||
"""
|
||||
normalize_path
|
||||
|
||||
Joins ``*paths``, expands ``~`` and normalizes path separators.
|
||||
|
||||
:param paths: path components to create path from
|
||||
"""
|
||||
return os.path.normpath(os.path.expandvars(os.path.expanduser(os.path.join(*map(str, paths)))))
|
||||
|
||||
|
||||
def safe_remove_file(filename, error_message=None):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if filename:
|
||||
os.remove(filename)
|
||||
except Exception:
|
||||
if error_message:
|
||||
print(error_message)
|
||||
|
||||
|
||||
def safe_remove_tree(filename):
|
||||
if not filename:
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
shutil.rmtree(filename, ignore_errors=True)
|
||||
except Exception:
|
||||
pass
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
os.remove(filename)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_python_path(script_dir, entry_point, package_api, is_conda_env=False):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
python_path_sep = ';' if is_windows_platform() else ':'
|
||||
python_path_cmd = package_api.get_python_command(
|
||||
["-c", "import sys; print('{}'.join(sys.path))".format(python_path_sep)])
|
||||
org_python_path = python_path_cmd.get_output(cwd=script_dir)
|
||||
# Add path of the script directory and executable directory
|
||||
python_path = '{}{python_path_sep}{}{python_path_sep}'.format(
|
||||
Path(script_dir).absolute().as_posix(),
|
||||
(Path(script_dir) / Path(entry_point)).parent.absolute().as_posix(),
|
||||
python_path_sep=python_path_sep)
|
||||
if is_windows_platform():
|
||||
python_path = python_path.replace('/', '\\')
|
||||
|
||||
return python_path if is_conda_env else (python_path + org_python_path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def add_python_path(base_path, extra_path):
|
||||
try:
|
||||
if not extra_path:
|
||||
return base_path
|
||||
python_path_sep = ';' if is_windows_platform() else ':'
|
||||
base_path = base_path or ''
|
||||
if not base_path.endswith(python_path_sep):
|
||||
base_path += python_path_sep
|
||||
base_path += extra_path.replace(':', python_path_sep)
|
||||
except:
|
||||
pass
|
||||
return base_path
|
||||
|
||||
|
||||
class Singleton(ABCMeta):
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class CompareAnything(object):
|
||||
"""
|
||||
CompareAnything
|
||||
|
||||
Creates an object which is always the smallest when compared to other objects.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __eq__(_):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __lt__(_):
|
||||
return True
|
||||
|
||||
|
||||
def nonstrict_in_place_sort(lst, reverse, *keys):
|
||||
"""
|
||||
nonstrict_in_place_sort
|
||||
|
||||
Sorts a list of dictionaries in-place by ``keys``.
|
||||
An element without a certain ``key`` will be considered the smallest in respect to that key.
|
||||
|
||||
:param lst: list to sort
|
||||
:type lst: ``[dict]``
|
||||
:param reverse: whether to reverse sorting
|
||||
:type reverse: ``bool``
|
||||
:param keys: Keys to sort by.
|
||||
Elements will be sorted pseudo-lexicographically by the values corresponding to ``*keys``, i.e:
|
||||
the list will be first sorted by the first element of ``*keys``,
|
||||
elements which are equal by the first sort will be internally sorted by
|
||||
the second element of ``*keys`` and so on.
|
||||
:type keys: ``[str]``
|
||||
"""
|
||||
lst.sort(
|
||||
key=lambda item: tuple(item.get(key, CompareAnything()) for key in keys),
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
|
||||
def load_yaml(path):
|
||||
if isinstance(path, Path):
|
||||
path = str(path)
|
||||
try:
|
||||
with open(path) as data_file:
|
||||
return yaml.safe_load(data_file) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError('Failed parsing yaml file [{}]: {}'.format(path, e))
|
||||
|
||||
|
||||
def dump_yaml(obj, path=None, dump_all=False, **kwargs):
|
||||
base_kwargs = dict(indent=4, allow_unicode=True, default_flow_style=False)
|
||||
base_kwargs.update(kwargs)
|
||||
if dump_all:
|
||||
base_kwargs['Dumper'] = AllDumper
|
||||
dump_func = yaml.dump
|
||||
else:
|
||||
dump_func = yaml.safe_dump
|
||||
if not path:
|
||||
return dump_func(obj, **base_kwargs)
|
||||
path = str(path)
|
||||
with open(path, 'w') as output:
|
||||
dump_func(obj, output, **base_kwargs)
|
||||
|
||||
|
||||
def one_value(dct):
|
||||
return next(iter(six.itervalues(dct)))
|
||||
|
||||
|
||||
@attr.s
|
||||
class RepoInfo(object):
|
||||
type = attr.ib(type=str)
|
||||
url = attr.ib(type=str)
|
||||
branch = attr.ib(type=str)
|
||||
commit = attr.ib(type=str)
|
||||
root = attr.ib(type=str)
|
||||
|
||||
|
||||
def get_repo_info(repo_type, path):
|
||||
assert repo_type in ['git', 'hg']
|
||||
if repo_type == 'git':
|
||||
commands = dict(
|
||||
url='git remote get-url origin',
|
||||
branch='git rev-parse --abbrev-ref HEAD',
|
||||
commit='git rev-parse HEAD',
|
||||
root='git rev-parse --show-toplevel'
|
||||
)
|
||||
elif repo_type == 'hg':
|
||||
commands = dict(
|
||||
url='hg paths --verbose',
|
||||
branch='hg --debug id -b',
|
||||
commit='hg --debug id -i',
|
||||
root='hg root'
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unknown repository type '{}'".format(repo_type))
|
||||
|
||||
commands_result = {
|
||||
name: subprocess.check_output(command.split(), cwd=path).decode().strip()
|
||||
for name, command in commands.items()
|
||||
}
|
||||
return RepoInfo(type=repo_type, **commands_result)
|
||||
|
||||
|
||||
def reverse_home_folder_expansion(path):
|
||||
path = str(path)
|
||||
if is_windows_platform():
|
||||
return path
|
||||
return re.sub('^{}/'.format(re.escape(str(Path.home()))), '~/', path)
|
||||
|
||||
|
||||
def represent_ordered_dict(dumper, data):
|
||||
"""
|
||||
Serializes ``OrderedDict`` to YAML by its proper order.
|
||||
Registering this function to ``yaml.SafeDumper`` enables using ``yaml.safe_dump`` with ``OrderedDict``s.
|
||||
"""
|
||||
return dumper.represent_mapping(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, data.items())
|
||||
|
||||
|
||||
def construct_mapping(loader, node):
|
||||
"""
|
||||
Deserialize YAML mappings as ``OrderedDict``s.
|
||||
"""
|
||||
loader.flatten_mapping(node)
|
||||
return OrderedDict(loader.construct_pairs(node))
|
||||
|
||||
|
||||
yaml.SafeDumper.add_representer(OrderedDict, represent_ordered_dict)
|
||||
yaml.SafeLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
|
||||
|
||||
|
||||
class AllDumper(yaml.SafeDumper):
|
||||
pass
|
||||
|
||||
|
||||
AllDumper.add_multi_representer(object, lambda dumper, data: dumper.represent_str(str(data)))
|
||||
|
||||
|
||||
def error(message):
|
||||
print('\ntrains_agent: ERROR: {}\n'.format(message))
|
||||
|
||||
|
||||
def warning(message):
|
||||
print('trains_agent: Warning: {}'.format(message))
|
||||
|
||||
|
||||
class TqdmStream(object):
|
||||
|
||||
def __init__(self, file_object):
|
||||
self.buffer = file_object
|
||||
|
||||
def write(self, data):
|
||||
self.buffer.write(data.strip())
|
||||
|
||||
def flush(self):
|
||||
self.buffer.write('\n')
|
||||
|
||||
|
||||
class TqdmLog(tqdm):
|
||||
|
||||
def __init__(self, iterable=None, file=None, **kwargs):
|
||||
super(TqdmLog, self).__init__(iterable, file=TqdmStream(file or sys.stderr), **kwargs)
|
||||
|
||||
|
||||
def url_join(first, *rest):
|
||||
"""
|
||||
Join url parts similarly to Path.join
|
||||
"""
|
||||
return str(furl.furl(first).path.add(rest)).lstrip('/')
|
||||
|
||||
|
||||
class LowercaseFormatter(logging.Formatter):
|
||||
def format(self, record, *args, **kwargs):
|
||||
record.levelname = record.levelname.lower()
|
||||
return super(LowercaseFormatter, self).format(record, *args, **kwargs)
|
||||
|
||||
|
||||
def mkstemp(
|
||||
open_kwargs=None, # type: Optional[Dict[Text, Any]]
|
||||
text=True, # type: bool
|
||||
name_only=False, # type: bool
|
||||
*args,
|
||||
**kwargs):
|
||||
# type: (...) -> Union[(IO[AnyStr], Text), Text]
|
||||
"""
|
||||
WARNING: the returned file object is strict about its input type,
|
||||
make sure to feed it binary/text input in correspondence to the ``text`` argument
|
||||
:param open_kwargs: keyword arguments for ``io.open``
|
||||
:param text: open in text mode
|
||||
:param name_only: close the file and return its name
|
||||
:param args: tempfile.mkstemp args
|
||||
:param kwargs: tempfile.mkstemp kwargs
|
||||
"""
|
||||
fd, name = tempfile.mkstemp(text=text, *args, **kwargs)
|
||||
mode = 'w+'
|
||||
if not text:
|
||||
mode += 'b'
|
||||
if name_only:
|
||||
os.close(fd)
|
||||
return name
|
||||
return io.open(fd, mode, **open_kwargs or {}), name
|
||||
|
||||
|
||||
def named_temporary_file(*args, **kwargs):
|
||||
if six.PY2:
|
||||
buffering = kwargs.pop('buffering', None)
|
||||
if buffering:
|
||||
kwargs['bufsize'] = buffering
|
||||
return tempfile.NamedTemporaryFile(*args, **kwargs)
|
||||
|
||||
|
||||
def parse_override(string):
|
||||
return pyhocon.ConfigFactory.parse_string(string).as_plain_ordered_dict()
|
||||
|
||||
|
||||
def chain_map(*args):
|
||||
return reduce(lambda x, y: x.update(y) or x, args, {})
|
||||
|
||||
|
||||
def check_directory_path(path, check_whitespace_in_path=True):
|
||||
message = 'Could not create directory "{}": {}'
|
||||
if not is_windows_platform() and check_whitespace_in_path:
|
||||
match = re.search(r'\s', path)
|
||||
if match:
|
||||
raise CommandFailedError(
|
||||
'directories may not contain whitespace (char: {!r}, position: {})'.format(match.group(0),
|
||||
match.endpos))
|
||||
try:
|
||||
Path(os.path.expandvars(path)).expanduser().mkdir(parents=True, exist_ok=True)
|
||||
except OSError as e:
|
||||
raise CommandFailedError(message.format(path, e.strerror))
|
||||
except Exception as e:
|
||||
raise CommandFailedError(message.format(path, e))
|
||||
|
||||
|
||||
def create_file_if_not_exists(path):
|
||||
if not os.path.exists(os.path.expanduser(os.path.expandvars(path))):
|
||||
open(path, "w").close()
|
||||
|
||||
|
||||
def rm_tree(root): # type: (Union[Path, Text]) -> None
|
||||
"""
|
||||
A version of shutil.rmtree that handles access errors, specifically hidden files on Windows
|
||||
"""
|
||||
def on_error(func, path, _):
|
||||
try:
|
||||
if os.path.exists(path) and not os.access(path, os.W_OK):
|
||||
os.chmod(path, stat.S_IWUSR)
|
||||
func(path)
|
||||
except Exception:
|
||||
pass
|
||||
return shutil.rmtree(os.path.expanduser(os.path.expandvars(Text(root))), onerror=on_error)
|
||||
|
||||
|
||||
def rm_file(filename): # type: (Union[Path, Text]) -> None
|
||||
"""
|
||||
A version of os.unlink that will not raise error
|
||||
"""
|
||||
try:
|
||||
os.unlink(os.path.expanduser(os.path.expandvars(Text(filename))))
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_conda(config):
|
||||
return config['agent.package_manager.type'].lower() == 'conda'
|
||||
|
||||
|
||||
class NonStrictAttrs(object):
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs):
|
||||
fields = fields_dict(cls)
|
||||
return cls(**filter_keys(lambda key: key in fields, kwargs))
|
||||
|
||||
|
||||
def python_version_string():
|
||||
return '{v.major}.{v.minor}'.format(v=sys.version_info)
|
||||
|
||||
|
||||
join_lines = '\n'.join
|
||||
|
||||
|
||||
class HOCONEncoder(json.JSONEncoder):
|
||||
"""
|
||||
pyhocon bugs:
|
||||
1. "\\t" is dumped as "\t" instead of "\\t", which is read as the character "\t".
|
||||
2. parsed config trees have dummy `pyhocon.config_tree.NoneValue` in them.
|
||||
(see: https://github.com/chimpler/pyhocon/issues/111)
|
||||
Workaround: dump HOCON to JSON, of which it is a subset, taking care of `NoneValue`s.
|
||||
"""
|
||||
|
||||
def default(self, o):
|
||||
"""
|
||||
If o is `pyhocon.config_tree.NoneValue`, encode it the same way as `None`.
|
||||
"""
|
||||
if isinstance(o, pyhocon.config_tree.NoneValue):
|
||||
return super(HOCONEncoder, self).encode(None)
|
||||
return super(HOCONEncoder, self).default(o)
|
||||
|
||||
|
||||
nullable_string = attr.ib(default="", converter=lambda x: x.strip())
|
||||
normal_path = attr.ib(default="", converter=lambda p: p and normalize_path(p))
|
||||
|
||||
|
||||
@attr.s
|
||||
class ExecutionInfo(NonStrictAttrs):
|
||||
repository = nullable_string
|
||||
entry_point = normal_path
|
||||
working_dir = normal_path
|
||||
branch = nullable_string
|
||||
version_num = nullable_string
|
||||
tag = nullable_string
|
||||
docker_cmd = nullable_string
|
||||
|
||||
@classmethod
|
||||
def from_task(cls, task_info):
|
||||
# type: (...) -> ExecutionInfo
|
||||
"""
|
||||
extract ExecutionInfo tuple from task parameters
|
||||
"""
|
||||
if not task_info.script:
|
||||
raise CommandFailedError("can not run task without script information")
|
||||
execution = cls.from_dict(task_info.script.to_dict())
|
||||
if not execution.entry_point:
|
||||
log.warning("notice: `script.entry_point` is empty")
|
||||
if not execution.working_dir:
|
||||
entry_point, _, working_dir = execution.entry_point.partition(":")
|
||||
execution.entry_point = entry_point
|
||||
execution.working_dir = working_dir or ""
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
execution.docker_cmd = task_info.execution.docker_cmd
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return execution
|
||||
|
||||
|
||||
class safe_furl(furl.furl):
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self._port
|
||||
|
||||
@port.setter
|
||||
def port(self, port):
|
||||
"""
|
||||
Any port value is valid
|
||||
"""
|
||||
self._port = port
|
||||
Reference in New Issue
Block a user