Improve argparser automagic support

This commit is contained in:
allegroai 2019-11-28 00:49:19 +02:00
parent 09257b7247
commit 7e7329f7a0
2 changed files with 34 additions and 15 deletions

View File

@ -120,7 +120,7 @@ class _Arguments(object):
for k, v in task_defaults.items(): for k, v in task_defaults.items():
try: try:
if type(v) is list: if type(v) is list:
task_defaults[k] = '[' + ', '.join(map("{0}".format, v)) + ']' task_defaults[k] = str(v)
elif type(v) not in (str, int, float, bool): elif type(v) not in (str, int, float, bool):
task_defaults[k] = str(v) task_defaults[k] = str(v)
except Exception: except Exception:
@ -154,6 +154,7 @@ class _Arguments(object):
# if k.startswith(self._prefix_args)]) # if k.startswith(self._prefix_args)])
task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items() task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items()
if not k.startswith(self._prefix_tf_defines)]) if not k.startswith(self._prefix_tf_defines)])
arg_parser_argeuments = {}
for k, v in task_arguments.items(): for k, v in task_arguments.items():
# if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False # if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False
# with the rest we have to make sure the type is correct # with the rest we have to make sure the type is correct
@ -185,16 +186,22 @@ class _Arguments(object):
current_action.const = const_value current_action.const = const_value
except ValueError: except ValueError:
pass pass
task_arguments[k] = const_value if current_action.default is not None or const_value not in (None, ''):
elif current_action and current_action.nargs == '+': arg_parser_argeuments[k] = const_value
elif current_action and (current_action.nargs in ('+', '*') or isinstance(current_action.nargs, int)):
try: try:
v = yaml.load(v.strip(), Loader=yaml.SafeLoader) v = yaml.load(v.strip(), Loader=yaml.SafeLoader)
if current_action.type: if not isinstance(v, (list, tuple)):
# do nothing, we have no idea what happened
pass
elif current_action.type:
v = [current_action.type(a) for a in v] v = [current_action.type(a) for a in v]
elif current_action.default: elif current_action.default:
v_type = type(current_action.default[0]) v_type = type(current_action.default[0])
v = [v_type(a) for a in v] v = [v_type(a) for a in v]
task_arguments[k] = v
if current_action.default is not None or v not in (None, ''):
arg_parser_argeuments[k] = v
except Exception: except Exception:
pass pass
elif current_action and not current_action.type: elif current_action and not current_action.type:
@ -208,7 +215,14 @@ class _Arguments(object):
# now we should try and cast the value if we can # now we should try and cast the value if we can
try: try:
v = var_type(v) v = var_type(v)
task_arguments[k] = v # cast back to int if it's the same value
if type(current_action.default) == int and int(v) == v:
arg_parser_argeuments[k] = int(v)
if current_action.default is None and v in (None, ''):
# Do nothing, we should leave it as is.
pass
else:
arg_parser_argeuments[k] = v
except Exception: except Exception:
pass pass
elif current_action and current_action.type == bool: elif current_action and current_action.type == bool:
@ -225,14 +239,17 @@ class _Arguments(object):
v = int(strip_v) v = int(strip_v)
except ValueError: except ValueError:
pass pass
task_arguments[k] = v if v not in (None, ''):
arg_parser_argeuments[k] = v
# add as default # add as default
try: try:
if current_action and isinstance(current_action, _SubParsersAction): if current_action and isinstance(current_action, _SubParsersAction):
if v not in (None, '') or current_action.default not in (None, ''):
current_action.default = v current_action.default = v
current_action.required = False current_action.required = False
elif current_action and isinstance(current_action, _StoreAction): elif current_action and isinstance(current_action, _StoreAction):
if v not in (None, '') or current_action.default not in (None, ''):
current_action.default = v current_action.default = v
current_action.required = False current_action.required = False
# python2 doesn't support defaults for positional arguments, unless used with nargs=? # python2 doesn't support defaults for positional arguments, unless used with nargs=?
@ -253,11 +270,13 @@ class _Arguments(object):
pass pass
except Exception: except Exception:
pass pass
# if we already have an instance of parsed args, we should update its values # if we already have an instance of parsed args, we should update its values
if parsed_args: if parsed_args:
for k, v in task_arguments.items(): for k, v in arg_parser_argeuments.items():
if parsed_args.get(k) is not None or v not in (None, ''):
setattr(parsed_args, k, v) setattr(parsed_args, k, v)
parser.set_defaults(**task_arguments) parser.set_defaults(**arg_parser_argeuments)
def copy_from_dict(self, dictionary, prefix=None): def copy_from_dict(self, dictionary, prefix=None):
# TODO: add dict prefix # TODO: add dict prefix

View File

@ -10,7 +10,7 @@ class ProxyDictPostWrite(dict):
self._update_func = None self._update_func = None
for k, i in self.items(): for k, i in self.items():
if isinstance(i, dict): if isinstance(i, dict):
super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)}) super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, i)})
self._update_func = update_func self._update_func = update_func
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -32,7 +32,7 @@ class ProxyDictPostWrite(dict):
def update(self, E=None, **F): def update(self, E=None, **F):
return super(ProxyDictPostWrite, self).update( return super(ProxyDictPostWrite, self).update(
ProxyDictPostWrite(self._update_obj, self._set_callback, **E) if E is not None else ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else
ProxyDictPostWrite(self._update_obj, self._set_callback, **F)) ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
@ -44,7 +44,7 @@ class ProxyDictPreWrite(dict):
self._update_func = None self._update_func = None
for k, i in self.items(): for k, i in self.items():
if isinstance(i, dict): if isinstance(i, dict):
self.update({k: ProxyDictPreWrite(k, self._nested_callback, **i)}) self.update({k: ProxyDictPreWrite(k, self._nested_callback, i)})
self._update_obj = update_obj self._update_obj = update_obj
self._update_func = update_func self._update_func = update_func