Fix server updated with the argparse in remote before Task.init() is called (respect skipped args)

Fix nonstandard argparse with default value that is not of defined type
This commit is contained in:
allegroai 2021-02-21 14:56:30 +02:00
parent 41230ac2c7
commit ff7e55756c
2 changed files with 33 additions and 6 deletions

View File

@ -356,6 +356,9 @@ class _Arguments(object):
bool_value = cast_to_bool_int(v, strip=True) bool_value = cast_to_bool_int(v, strip=True)
if bool_value is not None and current_action.default == bool(bool_value): if bool_value is not None and current_action.default == bool(bool_value):
continue continue
elif str(current_action.default) == v:
# if we changed nothing, leave it as is (i.e. default value)
v = current_action.default
arg_parser_arguments[k] = v arg_parser_arguments[k] = v
# noinspection PyBroadException # noinspection PyBroadException
@ -397,7 +400,7 @@ class _Arguments(object):
pass pass
# if API supports sections, we can update back the Args section with all the missing default # if API supports sections, we can update back the Args section with all the missing default
if Session.check_min_api_version('2.9'): if Session.check_min_api_version('2.9') and not self._exclude_parser_args.get('*', None):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
task_defaults, task_defaults_descriptions, task_defaults_types = \ task_defaults, task_defaults_descriptions, task_defaults_types = \

View File

@ -21,6 +21,7 @@ class PatchArgumentParser:
_calling_current_task = False _calling_current_task = False
_last_parsed_args = None _last_parsed_args = None
_last_arg_parser = None _last_arg_parser = None
_recursion_guard = False
@staticmethod @staticmethod
def add_subparsers(self, **kwargs): def add_subparsers(self, **kwargs):
@ -34,13 +35,31 @@ class PatchArgumentParser:
@staticmethod @staticmethod
def parse_args(self, args=None, namespace=None): def parse_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_args, if PatchArgumentParser._recursion_guard:
self, args=args, namespace=namespace) return {} if not PatchArgumentParser._original_parse_args else \
PatchArgumentParser._original_parse_args(self, args=args, namespace=namespace)
PatchArgumentParser._recursion_guard = True
try:
result = PatchArgumentParser._patched_parse_args(
PatchArgumentParser._original_parse_args, self, args=args, namespace=namespace)
finally:
PatchArgumentParser._recursion_guard = False
return result
@staticmethod @staticmethod
def parse_known_args(self, args=None, namespace=None): def parse_known_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_known_args, if PatchArgumentParser._recursion_guard:
self, args=args, namespace=namespace) return {} if not PatchArgumentParser._original_parse_args else \
PatchArgumentParser._original_parse_known_args(self, args=args, namespace=namespace)
PatchArgumentParser._recursion_guard = True
try:
result = PatchArgumentParser._patched_parse_args(
PatchArgumentParser._original_parse_known_args, self, args=args, namespace=namespace)
finally:
PatchArgumentParser._recursion_guard = False
return result
@staticmethod @staticmethod
def _patched_parse_args(original_parse_fn, self, args=None, namespace=None): def _patched_parse_args(original_parse_fn, self, args=None, namespace=None):
@ -54,6 +73,11 @@ class PatchArgumentParser:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
current_task = Task.get_task(task_id=get_remote_task_id()) current_task = Task.get_task(task_id=get_remote_task_id())
# make sure we do not store back the values
# (we will do that when we actually call parse args)
# this will make sure that if we have args we should not track we know them
# noinspection PyProtectedMember
current_task._arguments.exclude_parser_args({'*': True})
except Exception: except Exception:
pass pass
# automatically connect to current task: # automatically connect to current task:
@ -147,7 +171,7 @@ class PatchArgumentParser:
parsed_args_namespace = copy(parsed_args) parsed_args_namespace = copy(parsed_args)
parsed_args = (parsed_args_namespace, []) parsed_args = (parsed_args_namespace, [])
# cast arguments in parsed_args_namespace entries to str # cast arguments in parsed_args_namespace entries to str
if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace): if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace):
for k, v in parser._parsed_arg_string_lookup.items(): # noqa for k, v in parser._parsed_arg_string_lookup.items(): # noqa
if hasattr(parsed_args_namespace, k): if hasattr(parsed_args_namespace, k):