From b8ceba38dc802f3eecdfd2fb1511632ebd8e89bf Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Oct 2023 18:44:24 +0300 Subject: [PATCH] Add more visibility when overriding jsonargparse arguments --- clearml/binding/jsonargs_bind.py | 46 ++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py index a52c8edc..f0c6e902 100644 --- a/clearml/binding/jsonargs_bind.py +++ b/clearml/binding/jsonargs_bind.py @@ -28,7 +28,9 @@ class PatchJsonArgParse(object): _commands_sep = "." _command_type = "jsonargparse.Command" _command_name = "subcommand" + _special_fields = ["config", "subcommand"] _section_name = "Args" + _allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_" __remote_task_params = {} __remote_task_params_dict = {} __patched = False @@ -60,6 +62,7 @@ class PatchJsonArgParse(object): return args = {} args_type = {} + have_config_file = False for k, v in cls._args.items(): key_with_section = cls._section_name + cls._args_sep + k args[key_with_section] = v @@ -75,27 +78,18 @@ class PatchJsonArgParse(object): elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)): args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v)) args_type[key_with_section] = PatchJsonArgParse.path_type + have_config_file = True else: args[key_with_section] = str(v) except Exception: pass - args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand) cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type) - - @classmethod - def __delete_config_args(cls, parser, args, args_type, subcommand=None): - if not parser: - return args, args_type - paths = PatchJsonArgParse.__get_paths_from_dict(cls._args) - for path in paths: - args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand) - for arg_to_delete_key, arg_to_delete_value in args_to_delete.items(): - key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key - if key_with_section in args and args[key_with_section] == arg_to_delete_value: - del args[key_with_section] - if key_with_section in args_type: - del args_type[key_with_section] - return args, args_type + if have_config_file: + cls._current_task.set_parameter( + cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides, + False, + description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority" + ) @staticmethod def _adapt_typehints(original_fn, val, *args, **kwargs): @@ -103,6 +97,17 @@ class PatchJsonArgParse(object): return original_fn(val, *args, **kwargs) return original_fn(val, *args, **kwargs) + @staticmethod + def __restore_args(parser, args, subcommand=None): + paths = PatchJsonArgParse.__get_paths_from_dict(args) + for path in paths: + args_to_restore = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand) + for arg_to_restore_key, arg_to_restore_value in args_to_restore.items(): + if arg_to_restore_key in PatchJsonArgParse._special_fields: + continue + args[arg_to_restore_key] = arg_to_restore_value + return args + @staticmethod def _parse_args(original_fn, obj, *args, **kwargs): if not PatchJsonArgParse._current_task: @@ -119,6 +124,13 @@ class PatchJsonArgParse(object): params_namespace = Namespace() for k, v in params.items(): params_namespace[k] = v + allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True) + if not allow_jsonargparse_overrides_value: + params_namespace = PatchJsonArgParse.__restore_args( + obj, + params_namespace, + subcommand=params_namespace.get(PatchJsonArgParse._command_name) + ) return params_namespace except Exception as e: logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e)) @@ -210,7 +222,7 @@ class PatchJsonArgParse(object): parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False) if subcommand: parsed_cfg = { - ((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v + ((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v for k, v in parsed_cfg.items() } return parsed_cfg