From 7e7329f7a0d758a0b5cc8a5c6859184df741c2fd Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 28 Nov 2019 00:49:19 +0200 Subject: [PATCH] Improve argparser automagic support --- trains/backend_interface/task/args.py | 43 +++++++++++++++++++-------- trains/utilities/proxy_object.py | 6 ++-- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index 2fdbdca2..b5167880 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -120,7 +120,7 @@ class _Arguments(object): for k, v in task_defaults.items(): try: if type(v) is list: - task_defaults[k] = '[' + ', '.join(map("{0}".format, v)) + ']' + task_defaults[k] = str(v) elif type(v) not in (str, int, float, bool): task_defaults[k] = str(v) except Exception: @@ -154,6 +154,7 @@ class _Arguments(object): # if k.startswith(self._prefix_args)]) task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(self._prefix_tf_defines)]) + arg_parser_argeuments = {} for k, v in task_arguments.items(): # if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False # with the rest we have to make sure the type is correct @@ -185,16 +186,22 @@ class _Arguments(object): current_action.const = const_value except ValueError: pass - task_arguments[k] = const_value - elif current_action and current_action.nargs == '+': + if current_action.default is not None or const_value not in (None, ''): + arg_parser_argeuments[k] = const_value + elif current_action and (current_action.nargs in ('+', '*') or isinstance(current_action.nargs, int)): try: v = yaml.load(v.strip(), Loader=yaml.SafeLoader) - if current_action.type: + if not isinstance(v, (list, tuple)): + # do nothing, we have no idea what happened + pass + elif current_action.type: v = [current_action.type(a) for a in v] elif current_action.default: v_type = type(current_action.default[0]) v = [v_type(a) for a in v] - task_arguments[k] = v + + if current_action.default is not None or v not in (None, ''): + arg_parser_argeuments[k] = v except Exception: pass elif current_action and not current_action.type: @@ -208,7 +215,14 @@ class _Arguments(object): # now we should try and cast the value if we can try: v = var_type(v) - task_arguments[k] = v + # cast back to int if it's the same value + if type(current_action.default) == int and int(v) == v: + arg_parser_argeuments[k] = int(v) + if current_action.default is None and v in (None, ''): + # Do nothing, we should leave it as is. + pass + else: + arg_parser_argeuments[k] = v except Exception: pass elif current_action and current_action.type == bool: @@ -225,15 +239,18 @@ class _Arguments(object): v = int(strip_v) except ValueError: pass - task_arguments[k] = v + if v not in (None, ''): + arg_parser_argeuments[k] = v # add as default try: if current_action and isinstance(current_action, _SubParsersAction): - current_action.default = v + if v not in (None, '') or current_action.default not in (None, ''): + current_action.default = v current_action.required = False elif current_action and isinstance(current_action, _StoreAction): - current_action.default = v + if v not in (None, '') or current_action.default not in (None, ''): + current_action.default = v current_action.required = False # python2 doesn't support defaults for positional arguments, unless used with nargs=? if PY2 and not current_action.nargs: @@ -253,11 +270,13 @@ class _Arguments(object): pass except Exception: pass + # if we already have an instance of parsed args, we should update its values if parsed_args: - for k, v in task_arguments.items(): - setattr(parsed_args, k, v) - parser.set_defaults(**task_arguments) + for k, v in arg_parser_argeuments.items(): + if parsed_args.get(k) is not None or v not in (None, ''): + setattr(parsed_args, k, v) + parser.set_defaults(**arg_parser_argeuments) def copy_from_dict(self, dictionary, prefix=None): # TODO: add dict prefix diff --git a/trains/utilities/proxy_object.py b/trains/utilities/proxy_object.py index 928c7809..bb63d6d2 100644 --- a/trains/utilities/proxy_object.py +++ b/trains/utilities/proxy_object.py @@ -10,7 +10,7 @@ class ProxyDictPostWrite(dict): self._update_func = None for k, i in self.items(): if isinstance(i, dict): - super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)}) + super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, i)}) self._update_func = update_func def __setitem__(self, key, value): @@ -32,7 +32,7 @@ class ProxyDictPostWrite(dict): def update(self, E=None, **F): return super(ProxyDictPostWrite, self).update( - ProxyDictPostWrite(self._update_obj, self._set_callback, **E) if E is not None else + ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else ProxyDictPostWrite(self._update_obj, self._set_callback, **F)) @@ -44,7 +44,7 @@ class ProxyDictPreWrite(dict): self._update_func = None for k, i in self.items(): if isinstance(i, dict): - self.update({k: ProxyDictPreWrite(k, self._nested_callback, **i)}) + self.update({k: ProxyDictPreWrite(k, self._nested_callback, i)}) self._update_obj = update_obj self._update_func = update_func