From ff7e55756c991936df47bc43e13f22412bf532c9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 21 Feb 2021 14:56:30 +0200 Subject: [PATCH] Fix server updated with the argparse in remote before Task.init() is called (respect skipped args) Fix nonstandard argparse with default value that is not of defined type --- clearml/backend_interface/task/args.py | 5 +++- clearml/utilities/args.py | 34 ++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/clearml/backend_interface/task/args.py b/clearml/backend_interface/task/args.py index 1e83b7fd..02383917 100644 --- a/clearml/backend_interface/task/args.py +++ b/clearml/backend_interface/task/args.py @@ -356,6 +356,9 @@ class _Arguments(object): bool_value = cast_to_bool_int(v, strip=True) if bool_value is not None and current_action.default == bool(bool_value): continue + elif str(current_action.default) == v: + # if we changed nothing, leave it as is (i.e. default value) + v = current_action.default arg_parser_arguments[k] = v # noinspection PyBroadException @@ -397,7 +400,7 @@ class _Arguments(object): pass # if API supports sections, we can update back the Args section with all the missing default - if Session.check_min_api_version('2.9'): + if Session.check_min_api_version('2.9') and not self._exclude_parser_args.get('*', None): # noinspection PyBroadException try: task_defaults, task_defaults_descriptions, task_defaults_types = \ diff --git a/clearml/utilities/args.py b/clearml/utilities/args.py index b806ca72..367131b3 100644 --- a/clearml/utilities/args.py +++ b/clearml/utilities/args.py @@ -21,6 +21,7 @@ class PatchArgumentParser: _calling_current_task = False _last_parsed_args = None _last_arg_parser = None + _recursion_guard = False @staticmethod def add_subparsers(self, **kwargs): @@ -34,13 +35,31 @@ class PatchArgumentParser: @staticmethod def parse_args(self, args=None, namespace=None): - return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_args, - self, args=args, namespace=namespace) + if PatchArgumentParser._recursion_guard: + return {} if not PatchArgumentParser._original_parse_args else \ + PatchArgumentParser._original_parse_args(self, args=args, namespace=namespace) + + PatchArgumentParser._recursion_guard = True + try: + result = PatchArgumentParser._patched_parse_args( + PatchArgumentParser._original_parse_args, self, args=args, namespace=namespace) + finally: + PatchArgumentParser._recursion_guard = False + return result @staticmethod def parse_known_args(self, args=None, namespace=None): - return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_known_args, - self, args=args, namespace=namespace) + if PatchArgumentParser._recursion_guard: + return {} if not PatchArgumentParser._original_parse_args else \ + PatchArgumentParser._original_parse_known_args(self, args=args, namespace=namespace) + + PatchArgumentParser._recursion_guard = True + try: + result = PatchArgumentParser._patched_parse_args( + PatchArgumentParser._original_parse_known_args, self, args=args, namespace=namespace) + finally: + PatchArgumentParser._recursion_guard = False + return result @staticmethod def _patched_parse_args(original_parse_fn, self, args=None, namespace=None): @@ -54,6 +73,11 @@ class PatchArgumentParser: # noinspection PyBroadException try: current_task = Task.get_task(task_id=get_remote_task_id()) + # make sure we do not store back the values + # (we will do that when we actually call parse args) + # this will make sure that if we have args we should not track we know them + # noinspection PyProtectedMember + current_task._arguments.exclude_parser_args({'*': True}) except Exception: pass # automatically connect to current task: @@ -147,7 +171,7 @@ class PatchArgumentParser: parsed_args_namespace = copy(parsed_args) parsed_args = (parsed_args_namespace, []) - # cast arguments in parsed_args_namespace entries to str + # cast arguments in parsed_args_namespace entries to str if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace): for k, v in parser._parsed_arg_string_lookup.items(): # noqa if hasattr(parsed_args_namespace, k):