Fix initializing task on argparse parse in remote mode. Do not call Task.init() to avoid auto connect, use Task.get_task instead.

This commit is contained in:
allegroai 2020-10-30 09:53:44 +02:00
parent 44d9a03e56
commit 4e9fba5625
3 changed files with 29 additions and 21 deletions

View File

@ -222,7 +222,7 @@ class _Arguments(object):
task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() task_arguments = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
if k.startswith(prefix) and if k.startswith(prefix) and
self._exclude_parser_args.get(k[len(prefix):], True)]) self._exclude_parser_args.get(k[len(prefix):], True)])
arg_parser_argeuments = {} arg_parser_arguments = {}
for k, v in task_arguments.items(): for k, v in task_arguments.items():
# python2 unicode support # python2 unicode support
# noinspection PyBroadException # noinspection PyBroadException
@ -255,7 +255,7 @@ class _Arguments(object):
except ValueError: except ValueError:
pass pass
if current_action.default is not None or const_value not in (None, ''): if current_action.default is not None or const_value not in (None, ''):
arg_parser_argeuments[k] = const_value arg_parser_arguments[k] = const_value
elif current_action and (current_action.nargs in ('+', '*') or isinstance(current_action.nargs, int)): elif current_action and (current_action.nargs in ('+', '*') or isinstance(current_action.nargs, int)):
try: try:
v = yaml.load(v.strip(), Loader=yaml.SafeLoader) v = yaml.load(v.strip(), Loader=yaml.SafeLoader)
@ -269,7 +269,7 @@ class _Arguments(object):
v = [v_type(a) for a in v] v = [v_type(a) for a in v]
if current_action.default is not None or v not in (None, ''): if current_action.default is not None or v not in (None, ''):
arg_parser_argeuments[k] = v arg_parser_arguments[k] = v
except Exception: except Exception:
pass pass
elif current_action and not current_action.type: elif current_action and not current_action.type:
@ -286,15 +286,15 @@ class _Arguments(object):
v = var_type(v) v = var_type(v)
# cast back to int if it's the same value # cast back to int if it's the same value
if type(current_action.default) == int and int(v) == v: if type(current_action.default) == int and int(v) == v:
arg_parser_argeuments[k] = v = int(v) arg_parser_arguments[k] = v = int(v)
elif current_action.default is None and v in (None, ''): elif current_action.default is None and v in (None, ''):
# Do nothing, we should leave it as is. # Do nothing, we should leave it as is.
pass pass
else: else:
arg_parser_argeuments[k] = v arg_parser_arguments[k] = v
except Exception: except Exception:
# if we failed, leave as string # if we failed, leave as string
arg_parser_argeuments[k] = v arg_parser_arguments[k] = v
elif current_action and current_action.type == bool: elif current_action and current_action.type == bool:
# parser.set_defaults cannot cast string `False`/`True` to boolean properly, # parser.set_defaults cannot cast string `False`/`True` to boolean properly,
# so we have to do it manually here # so we have to do it manually here
@ -310,7 +310,7 @@ class _Arguments(object):
except ValueError: except ValueError:
pass pass
if v not in (None, ''): if v not in (None, ''):
arg_parser_argeuments[k] = v arg_parser_arguments[k] = v
elif current_action and current_action.type: elif current_action and current_action.type:
# if we have an action type and value (v) is None, and cannot be casted, leave as is # if we have an action type and value (v) is None, and cannot be casted, leave as is
if isinstance(current_action.type, types.FunctionType) and not v: if isinstance(current_action.type, types.FunctionType) and not v:
@ -330,17 +330,17 @@ class _Arguments(object):
if bool_value is not None and current_action.default == bool(bool_value): if bool_value is not None and current_action.default == bool(bool_value):
continue continue
arg_parser_argeuments[k] = v arg_parser_arguments[k] = v
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if current_action.default is None and current_action.type != str and not v: if current_action.default is None and current_action.type != str and not v:
arg_parser_argeuments[k] = v = None arg_parser_arguments[k] = v = None
elif current_action.default == current_action.type(v): elif current_action.default == current_action.type(v):
# this will make sure that if we have type float and default value int, # this will make sure that if we have type float and default value int,
# we will keep the type as int, just like the original argparser # we will keep the type as int, just like the original argparser
arg_parser_argeuments[k] = v = current_action.default arg_parser_arguments[k] = v = current_action.default
else: else:
arg_parser_argeuments[k] = v = current_action.type(v) arg_parser_arguments[k] = v = current_action.type(v)
except Exception: except Exception:
pass pass
@ -366,11 +366,14 @@ class _Arguments(object):
pass pass
# if we already have an instance of parsed args, we should update its values # if we already have an instance of parsed args, we should update its values
# this instance should already contain our defaults
if parsed_args: if parsed_args:
for k, v in arg_parser_argeuments.items(): for k, v in arg_parser_arguments.items():
if parsed_args.get(k) is not None or v not in (None, ''): cur_v = getattr(parsed_args, k, None)
# it should not happen...
if cur_v != v and (cur_v is not None or v not in (None, '')):
setattr(parsed_args, k, v) setattr(parsed_args, k, v)
parser.set_defaults(**arg_parser_argeuments) parser.set_defaults(**arg_parser_arguments)
def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None): def copy_from_dict(self, dictionary, prefix=None, descriptions=None, param_types=None):
# add dict prefix # add dict prefix

View File

@ -2190,7 +2190,7 @@ class Task(_Task):
if parsed_args is None and parser == _parser: if parsed_args is None and parser == _parser:
parsed_args = _parsed_args parsed_args = _parsed_args
if running_remotely() and self.is_main_task(): if running_remotely() and (self.is_main_task() or self.id == get_remote_task_id()):
self._arguments.copy_to_parser(parser, parsed_args) self._arguments.copy_to_parser(parser, parsed_args)
else: else:
self._arguments.copy_defaults_from_argparse( self._arguments.copy_defaults_from_argparse(

View File

@ -1,7 +1,11 @@
""" Argparse utilities""" """ Argparse utilities"""
import sys import sys
from six import PY2 from six import PY2
from argparse import ArgumentParser, _SubParsersAction from argparse import ArgumentParser
try:
from argparse import _SubParsersAction
except ImportError:
_SubParsersAction = type(None)
class PatchArgumentParser: class PatchArgumentParser:
@ -36,19 +40,20 @@ class PatchArgumentParser:
@staticmethod @staticmethod
def _patched_parse_args(original_parse_fn, self, args=None, namespace=None): def _patched_parse_args(original_parse_fn, self, args=None, namespace=None):
current_task = PatchArgumentParser._current_task
# if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible. # if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible.
if not PatchArgumentParser._current_task: if not current_task:
from ..config import running_remotely from ..config import running_remotely, get_remote_task_id
if running_remotely(): if running_remotely():
# this will cause the current_task() to set PatchArgumentParser._current_task # this will cause the current_task() to set PatchArgumentParser._current_task
from trains import Task from trains import Task
# noinspection PyBroadException # noinspection PyBroadException
try: try:
Task.init() current_task = Task.get_task(task_id=get_remote_task_id())
except Exception: except Exception:
pass pass
# automatically connect to current task: # automatically connect to current task:
if PatchArgumentParser._current_task: if current_task:
from ..config import running_remotely from ..config import running_remotely
if PatchArgumentParser._calling_current_task: if PatchArgumentParser._calling_current_task:
@ -70,7 +75,7 @@ class PatchArgumentParser:
try: try:
# sync to/from task # sync to/from task
# noinspection PyProtectedMember # noinspection PyProtectedMember
PatchArgumentParser._current_task._connect_argparse( current_task._connect_argparse(
self, args=args, namespace=namespace, self, args=args, namespace=namespace,
parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args
) )