Add more visibility when overriding jsonargparse arguments

This commit is contained in:
allegroai 2023-10-24 18:44:24 +03:00
parent f2057febd0
commit b8ceba38dc

View File

@ -28,7 +28,9 @@ class PatchJsonArgParse(object):
_commands_sep = "."
_command_type = "jsonargparse.Command"
_command_name = "subcommand"
_special_fields = ["config", "subcommand"]
_section_name = "Args"
_allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
__remote_task_params = {}
__remote_task_params_dict = {}
__patched = False
@ -60,6 +62,7 @@ class PatchJsonArgParse(object):
return
args = {}
args_type = {}
have_config_file = False
for k, v in cls._args.items():
key_with_section = cls._section_name + cls._args_sep + k
args[key_with_section] = v
@ -75,27 +78,18 @@ class PatchJsonArgParse(object):
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v))
args_type[key_with_section] = PatchJsonArgParse.path_type
have_config_file = True
else:
args[key_with_section] = str(v)
except Exception:
pass
args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand)
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
@classmethod
def __delete_config_args(cls, parser, args, args_type, subcommand=None):
if not parser:
return args, args_type
paths = PatchJsonArgParse.__get_paths_from_dict(cls._args)
for path in paths:
args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
for arg_to_delete_key, arg_to_delete_value in args_to_delete.items():
key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key
if key_with_section in args and args[key_with_section] == arg_to_delete_value:
del args[key_with_section]
if key_with_section in args_type:
del args_type[key_with_section]
return args, args_type
if have_config_file:
cls._current_task.set_parameter(
cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
False,
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
)
@staticmethod
def _adapt_typehints(original_fn, val, *args, **kwargs):
@ -103,6 +97,17 @@ class PatchJsonArgParse(object):
return original_fn(val, *args, **kwargs)
return original_fn(val, *args, **kwargs)
@staticmethod
def __restore_args(parser, args, subcommand=None):
paths = PatchJsonArgParse.__get_paths_from_dict(args)
for path in paths:
args_to_restore = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
for arg_to_restore_key, arg_to_restore_value in args_to_restore.items():
if arg_to_restore_key in PatchJsonArgParse._special_fields:
continue
args[arg_to_restore_key] = arg_to_restore_value
return args
@staticmethod
def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
@ -119,6 +124,13 @@ class PatchJsonArgParse(object):
params_namespace = Namespace()
for k, v in params.items():
params_namespace[k] = v
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
if not allow_jsonargparse_overrides_value:
params_namespace = PatchJsonArgParse.__restore_args(
obj,
params_namespace,
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
)
return params_namespace
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
@ -210,7 +222,7 @@ class PatchJsonArgParse(object):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg