Fix server updated with the argparse in remote before Task.init() is called (respect skipped args)

Fix nonstandard argparse with default value that is not of defined type
This commit is contained in:
allegroai 2021-02-21 14:56:30 +02:00
parent 41230ac2c7
commit ff7e55756c
2 changed files with 33 additions and 6 deletions

View File

@ -356,6 +356,9 @@ class _Arguments(object):
bool_value = cast_to_bool_int(v, strip=True)
if bool_value is not None and current_action.default == bool(bool_value):
continue
elif str(current_action.default) == v:
# if we changed nothing, leave it as is (i.e. default value)
v = current_action.default
arg_parser_arguments[k] = v
# noinspection PyBroadException
@ -397,7 +400,7 @@ class _Arguments(object):
pass
# if API supports sections, we can update back the Args section with all the missing default
if Session.check_min_api_version('2.9'):
if Session.check_min_api_version('2.9') and not self._exclude_parser_args.get('*', None):
# noinspection PyBroadException
try:
task_defaults, task_defaults_descriptions, task_defaults_types = \

View File

@ -21,6 +21,7 @@ class PatchArgumentParser:
_calling_current_task = False
_last_parsed_args = None
_last_arg_parser = None
_recursion_guard = False
@staticmethod
def add_subparsers(self, **kwargs):
@ -34,13 +35,31 @@ class PatchArgumentParser:
@staticmethod
def parse_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_args,
self, args=args, namespace=namespace)
if PatchArgumentParser._recursion_guard:
return {} if not PatchArgumentParser._original_parse_args else \
PatchArgumentParser._original_parse_args(self, args=args, namespace=namespace)
PatchArgumentParser._recursion_guard = True
try:
result = PatchArgumentParser._patched_parse_args(
PatchArgumentParser._original_parse_args, self, args=args, namespace=namespace)
finally:
PatchArgumentParser._recursion_guard = False
return result
@staticmethod
def parse_known_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_known_args,
self, args=args, namespace=namespace)
if PatchArgumentParser._recursion_guard:
return {} if not PatchArgumentParser._original_parse_args else \
PatchArgumentParser._original_parse_known_args(self, args=args, namespace=namespace)
PatchArgumentParser._recursion_guard = True
try:
result = PatchArgumentParser._patched_parse_args(
PatchArgumentParser._original_parse_known_args, self, args=args, namespace=namespace)
finally:
PatchArgumentParser._recursion_guard = False
return result
@staticmethod
def _patched_parse_args(original_parse_fn, self, args=None, namespace=None):
@ -54,6 +73,11 @@ class PatchArgumentParser:
# noinspection PyBroadException
try:
current_task = Task.get_task(task_id=get_remote_task_id())
# make sure we do not store back the values
# (we will do that when we actually call parse args)
# this will make sure that if we have args we should not track we know them
# noinspection PyProtectedMember
current_task._arguments.exclude_parser_args({'*': True})
except Exception:
pass
# automatically connect to current task: