From 23bdbe4b87a144d7356ead9e96386dc5fead1f48 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 13 Dec 2023 17:50:03 +0200 Subject: [PATCH] Fix jsonargparse binding does not capture parameters before Task.init is called (#1164) --- clearml/binding/jsonargs_bind.py | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py index 37de83bc..c4876c7a 100644 --- a/clearml/binding/jsonargs_bind.py +++ b/clearml/binding/jsonargs_bind.py @@ -44,7 +44,7 @@ class PatchJsonArgParse(object): cls.patch(task) @classmethod - def patch(cls, task): + def patch(cls, task=None): if ArgumentParser is None: return PatchJsonArgParse._update_task_args() @@ -73,7 +73,9 @@ class PatchJsonArgParse(object): if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v: # noinspection PyBroadException try: - if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)): + if isinstance(v, Namespace) or ( + isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v) + ): args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v)) args_type[key_with_section] = PatchJsonArgParse.namespace_type elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)): @@ -89,7 +91,7 @@ class PatchJsonArgParse(object): cls._current_task.set_parameter( cls._section_name + cls._args_sep + cls._ignore_ui_overrides, False, - description="If False, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority" + description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa ) @staticmethod @@ -111,8 +113,6 @@ class PatchJsonArgParse(object): @staticmethod def _parse_args(original_fn, obj, *args, **kwargs): - if not PatchJsonArgParse._current_task: - return original_fn(obj, *args, **kwargs) if len(args) == 1: kwargs["args"] = args[0] args = [] @@ -132,9 +132,7 @@ class PatchJsonArgParse(object): allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides) if not allow_jsonargparse_overrides_value: params_namespace = PatchJsonArgParse.__restore_args( - obj, - params_namespace, - subcommand=params_namespace.get(PatchJsonArgParse._command_name) + obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name) ) return params_namespace except Exception as e: @@ -154,6 +152,7 @@ class PatchJsonArgParse(object): except ImportError: try: import pytorch_lightning + lightning = pytorch_lightning except ImportError: lightning = None @@ -183,20 +182,14 @@ class PatchJsonArgParse(object): params_dict = t.get_parameters(backwards_compatibility=False, cast=True) for key, section_param in cls.__remote_task_params[cls._section_name].items(): if section_param.type == cls.namespace_type: - params_dict[ - "{}/{}".format(cls._section_name, key) - ] = cls._get_namespace_from_json(section_param.value) + params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value) elif section_param.type == cls.path_type: - params_dict[ - "{}/{}".format(cls._section_name, key) - ] = cls._get_path_from_json(section_param.value) + params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value) elif (not section_param.type or section_param.type == "NoneType") and not section_param.value: params_dict["{}/{}".format(cls._section_name, key)] = None skip = len(cls._section_name) + 1 cls.__remote_task_params_dict = { - k[skip:]: v - for k, v in params_dict.items() - if k.startswith(cls._section_name + cls._args_sep) + k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep) } cls.__update_remote_task_params_dict_based_on_paths(parser) @@ -205,9 +198,7 @@ class PatchJsonArgParse(object): paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict) for path in paths: args = PatchJsonArgParse.__get_args_from_path( - parser, - path, - subcommand=cls.__remote_task_params_dict.get("subcommand") + parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand") ) for subarg_key, subarg_value in args.items(): if subarg_key not in cls.__remote_task_params_dict: @@ -227,7 +218,12 @@ class PatchJsonArgParse(object): parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False) if subcommand: parsed_cfg = { - ((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v + ( + (subcommand + PatchJsonArgParse._commands_sep) + if k not in PatchJsonArgParse._special_fields + else "" + ) + + k: v for k, v in parsed_cfg.items() } return parsed_cfg @@ -257,3 +253,7 @@ class PatchJsonArgParse(object): if isinstance(json_, list): return [Path(**dict_) for dict_ in json_] return Path(**json_) + + +# patch jsonargparse before anything else +PatchJsonArgParse.patch()