Fix jsonargparse binding does not capture parameters before Task.init is called (#1164)

This commit is contained in:
allegroai 2023-12-13 17:50:03 +02:00
parent 65c6ba33e4
commit 23bdbe4b87

View File

@ -44,7 +44,7 @@ class PatchJsonArgParse(object):
cls.patch(task)
@classmethod
def patch(cls, task):
def patch(cls, task=None):
if ArgumentParser is None:
return
PatchJsonArgParse._update_task_args()
@ -73,7 +73,9 @@ class PatchJsonArgParse(object):
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
# noinspection PyBroadException
try:
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
if isinstance(v, Namespace) or (
isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)
):
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
args_type[key_with_section] = PatchJsonArgParse.namespace_type
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
@ -89,7 +91,7 @@ class PatchJsonArgParse(object):
cls._current_task.set_parameter(
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
False,
description="If False, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa
)
@staticmethod
@ -111,8 +113,6 @@ class PatchJsonArgParse(object):
@staticmethod
def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
return original_fn(obj, *args, **kwargs)
if len(args) == 1:
kwargs["args"] = args[0]
args = []
@ -132,9 +132,7 @@ class PatchJsonArgParse(object):
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
if not allow_jsonargparse_overrides_value:
params_namespace = PatchJsonArgParse.__restore_args(
obj,
params_namespace,
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name)
)
return params_namespace
except Exception as e:
@ -154,6 +152,7 @@ class PatchJsonArgParse(object):
except ImportError:
try:
import pytorch_lightning
lightning = pytorch_lightning
except ImportError:
lightning = None
@ -183,20 +182,14 @@ class PatchJsonArgParse(object):
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
for key, section_param in cls.__remote_task_params[cls._section_name].items():
if section_param.type == cls.namespace_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_namespace_from_json(section_param.value)
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value)
elif section_param.type == cls.path_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_path_from_json(section_param.value)
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value)
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
params_dict["{}/{}".format(cls._section_name, key)] = None
skip = len(cls._section_name) + 1
cls.__remote_task_params_dict = {
k[skip:]: v
for k, v in params_dict.items()
if k.startswith(cls._section_name + cls._args_sep)
k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep)
}
cls.__update_remote_task_params_dict_based_on_paths(parser)
@ -205,9 +198,7 @@ class PatchJsonArgParse(object):
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
for path in paths:
args = PatchJsonArgParse.__get_args_from_path(
parser,
path,
subcommand=cls.__remote_task_params_dict.get("subcommand")
parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand")
)
for subarg_key, subarg_value in args.items():
if subarg_key not in cls.__remote_task_params_dict:
@ -227,7 +218,12 @@ class PatchJsonArgParse(object):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
(
(subcommand + PatchJsonArgParse._commands_sep)
if k not in PatchJsonArgParse._special_fields
else ""
)
+ k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg
@ -257,3 +253,7 @@ class PatchJsonArgParse(object):
if isinstance(json_, list):
return [Path(**dict_) for dict_ in json_]
return Path(**json_)
# patch jsonargparse before anything else
PatchJsonArgParse.patch()