From 4e9fba5625c94ec2762c8b1a16a7c7cda1a2bcc0 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 30 Oct 2020 09:53:44 +0200 Subject: [PATCH] Fix initializing task on argparse parse in remote mode. Do not call Task.init() to avoid auto connect, use Task.get_task instead. --- trains/backend_interface/task/args.py | 31 +++++++++++++++------------ trains/task.py | 2 +- trains/utilities/args.py | 17 +++++++++------ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index cb83ff33..ca3fbe13 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -222,7 +222,7 @@ class _Arguments(object): task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() if k.startswith(prefix) and self._exclude_parser_args.get(k[len(prefix):], True)]) - arg_parser_argeuments = {} + arg_parser_arguments = {} for k, v in task_arguments.items(): # python2 unicode support # noinspection PyBroadException @@ -255,7 +255,7 @@ class _Arguments(object): except ValueError: pass if current_action.default is not None or const_value not in (None, ''): - arg_parser_argeuments[k] = const_value + arg_parser_arguments[k] = const_value elif current_action and (current_action.nargs in ('+', '*') or isinstance(current_action.nargs, int)): try: v = yaml.load(v.strip(), Loader=yaml.SafeLoader) @@ -269,7 +269,7 @@ class _Arguments(object): v = [v_type(a) for a in v] if current_action.default is not None or v not in (None, ''): - arg_parser_argeuments[k] = v + arg_parser_arguments[k] = v except Exception: pass elif current_action and not current_action.type: @@ -286,15 +286,15 @@ class _Arguments(object): v = var_type(v) # cast back to int if it's the same value if type(current_action.default) == int and int(v) == v: - arg_parser_argeuments[k] = v = int(v) + arg_parser_arguments[k] = v = int(v) elif current_action.default is None and v in (None, ''): # Do nothing, we should leave it as is. pass else: - arg_parser_argeuments[k] = v + arg_parser_arguments[k] = v except Exception: # if we failed, leave as string - arg_parser_argeuments[k] = v + arg_parser_arguments[k] = v elif current_action and current_action.type == bool: # parser.set_defaults cannot cast string `False`/`True` to boolean properly, # so we have to do it manually here @@ -310,7 +310,7 @@ class _Arguments(object): except ValueError: pass if v not in (None, ''): - arg_parser_argeuments[k] = v + arg_parser_arguments[k] = v elif current_action and current_action.type: # if we have an action type and value (v) is None, and cannot be casted, leave as is if isinstance(current_action.type, types.FunctionType) and not v: @@ -330,17 +330,17 @@ class _Arguments(object): if bool_value is not None and current_action.default == bool(bool_value): continue - arg_parser_argeuments[k] = v + arg_parser_arguments[k] = v # noinspection PyBroadException try: if current_action.default is None and current_action.type != str and not v: - arg_parser_argeuments[k] = v = None + arg_parser_arguments[k] = v = None elif current_action.default == current_action.type(v): # this will make sure that if we have type float and default value int, # we will keep the type as int, just like the original argparser - arg_parser_argeuments[k] = v = current_action.default + arg_parser_arguments[k] = v = current_action.default else: - arg_parser_argeuments[k] = v = current_action.type(v) + arg_parser_arguments[k] = v = current_action.type(v) except Exception: pass @@ -366,11 +366,14 @@ class _Arguments(object): pass # if we already have an instance of parsed args, we should update its values + # this instance should already contain our defaults if parsed_args: - for k, v in arg_parser_argeuments.items(): - if parsed_args.get(k) is not None or v not in (None, ''): + for k, v in arg_parser_arguments.items(): + cur_v = getattr(parsed_args, k, None) + # it should not happen... + if cur_v != v and (cur_v is not None or v not in (None, '')): setattr(parsed_args, k, v) - parser.set_defaults(**arg_parser_argeuments) + parser.set_defaults(**arg_parser_arguments) def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None): # add dict prefix diff --git a/trains/task.py b/trains/task.py index 38bd5be6..5c9e2107 100644 --- a/trains/task.py +++ b/trains/task.py @@ -2190,7 +2190,7 @@ class Task(_Task): if parsed_args is None and parser == _parser: parsed_args = _parsed_args - if running_remotely() and self.is_main_task(): + if running_remotely() and (self.is_main_task() or self.id == get_remote_task_id()): self._arguments.copy_to_parser(parser, parsed_args) else: self._arguments.copy_defaults_from_argparse( diff --git a/trains/utilities/args.py b/trains/utilities/args.py index 0649d220..7900d1b9 100644 --- a/trains/utilities/args.py +++ b/trains/utilities/args.py @@ -1,7 +1,11 @@ """ Argparse utilities""" import sys from six import PY2 -from argparse import ArgumentParser, _SubParsersAction +from argparse import ArgumentParser +try: + from argparse import _SubParsersAction +except ImportError: + _SubParsersAction = type(None) class PatchArgumentParser: @@ -36,19 +40,20 @@ class PatchArgumentParser: @staticmethod def _patched_parse_args(original_parse_fn, self, args=None, namespace=None): + current_task = PatchArgumentParser._current_task # if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible. - if not PatchArgumentParser._current_task: - from ..config import running_remotely + if not current_task: + from ..config import running_remotely, get_remote_task_id if running_remotely(): # this will cause the current_task() to set PatchArgumentParser._current_task from trains import Task # noinspection PyBroadException try: - Task.init() + current_task = Task.get_task(task_id=get_remote_task_id()) except Exception: pass # automatically connect to current task: - if PatchArgumentParser._current_task: + if current_task: from ..config import running_remotely if PatchArgumentParser._calling_current_task: @@ -70,7 +75,7 @@ class PatchArgumentParser: try: # sync to/from task # noinspection PyProtectedMember - PatchArgumentParser._current_task._connect_argparse( + current_task._connect_argparse( self, args=args, namespace=namespace, parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args )