mirror of
https://github.com/clearml/clearml
synced 2025-02-01 01:26:49 +00:00
323 lines
14 KiB
Python
323 lines
14 KiB
Python
import yaml
|
|
|
|
from six import PY2
|
|
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
|
|
from copy import copy
|
|
|
|
from ...utilities.args import call_original_argparser
|
|
|
|
|
|
class _Arguments(object):
    """
    Synchronize a task's stored parameters with argparse parsers and plain dictionaries.

    All persistence is delegated to the ``task`` object passed to the constructor. This class
    only calls ``set_parameter``, ``set_parameters``, ``update_parameters``, ``get_parameters``,
    ``connect_argparse`` and ``log.warning`` on it.
    """

    # Separator between a parameter-name prefix and the parameter name itself
    _prefix_sep = '/'
    # TODO: separate dict and argparse after we add UI support
    _prefix_dict = 'dict' + _prefix_sep
    _prefix_args = 'argparse' + _prefix_sep
    # Parameters starting with this prefix are skipped by copy_to_parser / copy_to_dict
    _prefix_tf_defines = 'TF_DEFINE' + _prefix_sep

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """

        def __init__(self, arguments, *args, **kwargs):
            super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._arguments = arguments

        def __setitem__(self, key, value):
            super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
            if self._arguments:
                # push the whole (already updated) dictionary to the arguments instance
                self._arguments.copy_from_dict(self)

    class _ProxyDictReadOnly(dict):
        """
        Dictionary wrapper that passes every item assignment through the arguments instance.

        On ``d[key] = value`` the pair is first handed to ``copy_to_dict`` and the value it
        returns for ``key`` is the one actually stored. Assignments are not blocked.
        """

        def __init__(self, arguments, *args, **kwargs):
            super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
            self._arguments = arguments

        def __setitem__(self, key, value):
            if self._arguments:
                param_dict = self._arguments.copy_to_dict({key: value})
                value = param_dict.get(key, value)
            super(_Arguments._ProxyDictReadOnly, self).__setitem__(key, value)

    def __init__(self, task):
        """
        :param task: object used to store and retrieve the parameters
        """
        super(_Arguments, self).__init__()
        self._task = task

    def set_defaults(self, *dicts, **kwargs):
        """ Forward all arguments, unchanged, to the task's ``set_parameters`` """
        self._task.set_parameters(*dicts, **kwargs)

    def add_argument(self, option_strings, type=None, default=None, help=None):
        """
        Register a single argument as a task parameter.

        :param option_strings: option string (e.g. ``'--lr'``) or a list/tuple of them, of which
            only the first is used. Leading/trailing dashes, spaces and tabs are stripped.
        :param type: accepted for signature compatibility, not used
        :param default: stored as the parameter value
        :param help: stored as the parameter description
        :raises ValueError: if ``option_strings`` is empty
        """
        if not option_strings:
            raise ValueError('Expected at least one argument name (option string)')
        first = option_strings[0] if isinstance(option_strings, (list, tuple)) else option_strings
        name = first.strip('- \t')
        # TODO: add argparse prefix (name = self._prefix_args + name)
        self._task.set_parameter(name=name, value=default, description=help)

    def connect(self, parser):
        """ Forward the parser to the task's ``connect_argparse`` """
        self._task.connect_argparse(parser)

    @classmethod
    def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None):
        """
        Collect the values of all store / store-const actions of a parser (and its sub-parsers).

        :param a_parser: argparse parser to inspect
        :param defaults: dictionary updated in place with ``{dest: value}``; ``None`` becomes ``''``
        :param a_args: forwarded to ``call_original_argparser`` when no parsed args are given
        :param a_namespace: forwarded to ``call_original_argparser`` when no parsed args are given
        :param a_parsed_args: already parsed arguments, either a dict or an object with ``__dict__``
        :return: the updated ``defaults`` dictionary
        """
        actions = [a for a in a_parser._actions if isinstance(a, (_StoreAction, _StoreConstAction))]
        args_dict = {}
        try:
            if isinstance(a_parsed_args, dict):
                args_dict = a_parsed_args
            elif a_parsed_args:
                args_dict = a_parsed_args.__dict__
            else:
                args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
            defaults_ = {
                a.dest: args_dict.get(a.dest) if (args_dict.get(a.dest) is not None) else ''
                for a in actions
            }
        except Exception:
            # don't crash us if we failed parsing the inputs, fall back to the declared defaults
            defaults_ = {
                a.dest: a.default if a.default is not None else ''
                for a in actions
            }

        # snapshot taken before this parser's values are merged, used for the sub-parsers below
        full_args_dict = copy(defaults)
        full_args_dict.update(args_dict)
        defaults.update(defaults_)

        # deal with sub parsers
        for sub_parser in (a for a in a_parser._actions if isinstance(a, _SubParsersAction)):
            if sub_parser.dest and sub_parser.dest != SUPPRESS:
                defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest)
            for choice in sub_parser.choices.values():
                # recursively parse
                defaults = cls._add_to_defaults(
                    a_parser=choice,
                    defaults=defaults,
                    a_parsed_args=a_parsed_args or full_args_dict
                )

        return defaults

    def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
        """
        Store the parser's argument values as task parameters (``update_parameters``).

        Lists are stored as ``'[a, b, ...]'`` strings; values that are not str/int/float/bool
        are stored as ``str(value)``; values that fail to convert are dropped.

        :param parser: argparse parser to inspect
        :param args: forwarded to ``_add_to_defaults``
        :param namespace: forwarded to ``_add_to_defaults``
        :param parsed_args: already parsed arguments (dict or namespace-like object), optional
        """
        task_defaults = {}
        self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args)

        # Make sure we didn't miss anything
        if parsed_args:
            # _add_to_defaults accepts a dict as well, so do the same here
            extra_items = parsed_args if isinstance(parsed_args, dict) else vars(parsed_args)
            for k, v in extra_items.items():
                # NOTE: the original None check here (`type(v) == None`) could never match, so
                # None values were never recorded; that effective behaviour is kept.
                if k not in task_defaults and type(v) in (str, int, float, bool, list):
                    task_defaults[k] = v

        # Verify arguments (iterate over a snapshot, entries may be deleted inside the loop)
        for k, v in list(task_defaults.items()):
            try:
                if type(v) is list:
                    task_defaults[k] = '[' + ', '.join(map("{0}".format, v)) + ']'
                elif type(v) not in (str, int, float, bool):
                    task_defaults[k] = str(v)
            except Exception:
                del task_defaults[k]

        # TODO: add argparse prefix (self._prefix_args + k) to every key
        # Store to task
        self._task.update_parameters(task_defaults)

    @classmethod
    def _find_parser_action(cls, a_parser, name):
        """
        Find the actions whose ``dest`` equals ``name``.

        Actions of ``a_parser`` itself take precedence; sub-parsers are searched (recursively)
        only when the parser has no direct match.

        :return: list of ``(parser, action)`` tuples, possibly empty
        """
        # find by name
        direct = [(a_parser, a) for a in a_parser._actions if a.dest == name]
        if direct:
            return direct

        # iterate over subparsers
        found = []
        for sub_parser in (a for a in a_parser._actions if isinstance(a, _SubParsersAction)):
            for choice in sub_parser.choices.values():
                # recursively parse
                found.extend(cls._find_parser_action(choice, name))
        return found

    def copy_to_parser(self, parser, parsed_args):
        """
        Push the task parameters into an argparse parser as its defaults.

        Every parameter (except ``TF_DEFINE/`` prefixed ones) is cast according to the matching
        parser action, set as that action's default and finally passed to ``parser.set_defaults``.

        :param parser: argparse parser to update
        :param parsed_args: optional already parsed namespace; if truthy, its attributes are
            overwritten with the (cast) task parameters
        """
        # TODO: change to argparse prefix only (keep keys starting with self._prefix_args, stripped)
        task_arguments = {
            k: v for k, v in self._task.get_parameters().items()
            if not k.startswith(self._prefix_tf_defines)
        }
        for k, v in task_arguments.items():
            # if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False
            # with the rest we have to make sure the type is correct
            for _parent_parser, current_action in self._find_parser_action(parser, k):
                if not current_action:
                    continue

                if isinstance(current_action, _StoreConstAction):
                    # reference value that decides the target type: const, else default, else True
                    if current_action.const is not None:
                        const_value = current_action.const
                    elif current_action.default is not None:
                        const_value = current_action.default
                    else:
                        const_value = True
                    const_type = type(const_value)
                    strip_v = str(v).lower().strip()
                    if const_type == bool:
                        if strip_v == 'false' or not strip_v:
                            const_value = False
                        elif strip_v == 'true':
                            const_value = True
                        else:
                            # first try to cast to integer
                            try:
                                const_value = int(strip_v)
                            except ValueError:
                                pass
                    else:
                        const_value = strip_v
                    # then cast to const type (might be boolean)
                    try:
                        const_value = const_type(const_value)
                        current_action.const = const_value
                    except (ValueError, TypeError):
                        pass
                    task_arguments[k] = const_value
                elif current_action.nargs == '+':
                    try:
                        # safe_load: the value is a stored parameter string, never build
                        # arbitrary objects from it (plain yaml.load also requires a Loader
                        # argument in recent PyYAML releases)
                        v = yaml.safe_load(v.strip())
                        if current_action.type:
                            v = [current_action.type(a) for a in v]
                        elif current_action.default:
                            v_type = type(current_action.default[0])
                            v = [v_type(a) for a in v]
                        task_arguments[k] = v
                    except Exception:
                        # best effort: keep whatever value we have
                        pass
                elif not current_action.type:
                    # cast manually if there is no type
                    var_type = type(current_action.default)
                    # if we have an int, we should cast to float, because it is more generic
                    if var_type == int:
                        var_type = float
                    elif var_type == type(None):
                        var_type = str
                    # now we should try and cast the value if we can
                    try:
                        v = var_type(v)
                        task_arguments[k] = v
                    except Exception:
                        pass

                # add as default
                try:
                    if isinstance(current_action, _SubParsersAction):
                        current_action.default = v
                        current_action.required = False
                    elif isinstance(current_action, _StoreAction):
                        current_action.default = v
                        current_action.required = False
                        # python2 doesn't support defaults for positional arguments, unless used with nargs=?
                        if PY2 and not current_action.nargs:
                            current_action.nargs = '?'
                    # other actions are left alone: parameters that do not exist in the
                    # argparser are not added, they might belong to a connected dict
                except Exception:
                    # best effort, never fail because a default could not be set
                    pass

        # if we already have an instance of parsed args, we should update its values
        if parsed_args:
            for k, v in task_arguments.items():
                setattr(parsed_args, k, v)
        parser.set_defaults(**task_arguments)

    def copy_from_dict(self, dictionary, prefix=None):
        """
        Store a dictionary as task parameters.

        Without a prefix the dictionary is merged with ``update_parameters``. With a prefix,
        all existing parameters starting with that prefix are replaced by the (prefixed)
        dictionary entries using ``set_parameters``.

        :param dictionary: mapping of parameter name to value
        :param prefix: optional string prepended to every key
        :return: ``dictionary`` itself if it already is a ``_ProxyDictWrite``, otherwise a
            ``_ProxyDictWrite`` copy of it
        """
        # TODO: add dict prefix (self._prefix_dict)
        prefix = prefix or ''
        if prefix:
            cur_params = {
                k: v for k, v in self._task.get_parameters().items() if not k.startswith(prefix)
            }
            cur_params.update((prefix + k, v) for k, v in dictionary.items())
            self._task.set_parameters(cur_params)
        else:
            self._task.update_parameters(dictionary)

        if not isinstance(dictionary, self._ProxyDictWrite):
            # pass the mapping positionally: ** expansion fails for non-string keys
            return self._ProxyDictWrite(self, dictionary)
        return dictionary

    def _warn_failed_parse(self, key, param, default):
        """ Log that a task parameter could not be parsed and the default value is kept """
        self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                               (str(key), str(param), str(key), str(default)))

    def copy_to_dict(self, dictionary, prefix=None):
        """
        Override dictionary values with the matching task parameters.

        For every key that has a (non-None) task parameter, the parameter is cast to the type
        of the dictionary's current value and written back into ``dictionary`` in place.
        Keys without a matching parameter, and parameters that fail to cast, keep their value
        (a warning is logged for the latter).

        :param dictionary: mapping of parameter name to default value, updated in place
        :param prefix: optional prefix; when given, only parameters starting with it are used
            and the prefix is stripped before matching
        :return: ``dictionary`` itself if it already is a ``_ProxyDictReadOnly``, otherwise a
            ``_ProxyDictReadOnly`` copy of it
        """
        # iterate over keys and merge values according to parameter type in dictionary
        # TODO: add dict prefix (self._prefix_dict)
        prefix = prefix or ''
        task_parameters = self._task.get_parameters()
        if prefix:
            parameters = {
                k[len(prefix):]: v for k, v in task_parameters.items() if k.startswith(prefix)
            }
        else:
            parameters = {
                k: v for k, v in task_parameters.items() if not k.startswith(self._prefix_tf_defines)
            }

        for k, v in dictionary.items():
            param = parameters.get(k, None)
            if param is None:
                continue
            v_type = type(v)
            if v_type == int:
                # assume more general purpose type int -> float
                # NOTE(review): `v` is already an int here, so this check looks like it can
                # never select float; possibly `param` was meant - confirm before changing.
                if int(v) != float(v):
                    v_type = float
            elif v_type == bool:
                # cast based on string or int
                try:
                    param = bool(float(param))
                except (ValueError, TypeError):
                    param = str(param).lower().strip() == 'true'
            elif v_type in (list, dict):
                try:
                    # safe_load: see copy_to_parser
                    param = yaml.safe_load(str(param).strip())
                except Exception:
                    self._warn_failed_parse(k, param, v)
                    continue
            elif v_type == type(None):
                v_type = str

            try:
                dictionary[k] = v_type(param)
            except (ValueError, TypeError):
                self._warn_failed_parse(k, param, v)
                continue

        # parameters that are missing from the dictionary are deliberately not added to it

        if not isinstance(dictionary, self._ProxyDictReadOnly):
            # pass the mapping positionally: ** expansion fails for non-string keys
            return self._ProxyDictReadOnly(self, dictionary)
        return dictionary
|