mirror of
https://github.com/clearml/clearml
synced 2025-02-07 05:18:50 +00:00
Fix jsonargparse binding does not capture parameters before Task.init is called (#1164)
This commit is contained in:
parent
65c6ba33e4
commit
23bdbe4b87
@ -44,7 +44,7 @@ class PatchJsonArgParse(object):
|
||||
cls.patch(task)
|
||||
|
||||
@classmethod
|
||||
def patch(cls, task):
|
||||
def patch(cls, task=None):
|
||||
if ArgumentParser is None:
|
||||
return
|
||||
PatchJsonArgParse._update_task_args()
|
||||
@ -73,7 +73,9 @@ class PatchJsonArgParse(object):
|
||||
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
|
||||
if isinstance(v, Namespace) or (
|
||||
isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)
|
||||
):
|
||||
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
|
||||
args_type[key_with_section] = PatchJsonArgParse.namespace_type
|
||||
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
||||
@ -89,7 +91,7 @@ class PatchJsonArgParse(object):
|
||||
cls._current_task.set_parameter(
|
||||
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
|
||||
False,
|
||||
description="If False, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
|
||||
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -111,8 +113,6 @@ class PatchJsonArgParse(object):
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||
if not PatchJsonArgParse._current_task:
|
||||
return original_fn(obj, *args, **kwargs)
|
||||
if len(args) == 1:
|
||||
kwargs["args"] = args[0]
|
||||
args = []
|
||||
@ -132,9 +132,7 @@ class PatchJsonArgParse(object):
|
||||
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
|
||||
if not allow_jsonargparse_overrides_value:
|
||||
params_namespace = PatchJsonArgParse.__restore_args(
|
||||
obj,
|
||||
params_namespace,
|
||||
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||
)
|
||||
return params_namespace
|
||||
except Exception as e:
|
||||
@ -154,6 +152,7 @@ class PatchJsonArgParse(object):
|
||||
except ImportError:
|
||||
try:
|
||||
import pytorch_lightning
|
||||
|
||||
lightning = pytorch_lightning
|
||||
except ImportError:
|
||||
lightning = None
|
||||
@ -183,20 +182,14 @@ class PatchJsonArgParse(object):
|
||||
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
|
||||
for key, section_param in cls.__remote_task_params[cls._section_name].items():
|
||||
if section_param.type == cls.namespace_type:
|
||||
params_dict[
|
||||
"{}/{}".format(cls._section_name, key)
|
||||
] = cls._get_namespace_from_json(section_param.value)
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value)
|
||||
elif section_param.type == cls.path_type:
|
||||
params_dict[
|
||||
"{}/{}".format(cls._section_name, key)
|
||||
] = cls._get_path_from_json(section_param.value)
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value)
|
||||
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = None
|
||||
skip = len(cls._section_name) + 1
|
||||
cls.__remote_task_params_dict = {
|
||||
k[skip:]: v
|
||||
for k, v in params_dict.items()
|
||||
if k.startswith(cls._section_name + cls._args_sep)
|
||||
k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep)
|
||||
}
|
||||
cls.__update_remote_task_params_dict_based_on_paths(parser)
|
||||
|
||||
@ -205,9 +198,7 @@ class PatchJsonArgParse(object):
|
||||
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
|
||||
for path in paths:
|
||||
args = PatchJsonArgParse.__get_args_from_path(
|
||||
parser,
|
||||
path,
|
||||
subcommand=cls.__remote_task_params_dict.get("subcommand")
|
||||
parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand")
|
||||
)
|
||||
for subarg_key, subarg_value in args.items():
|
||||
if subarg_key not in cls.__remote_task_params_dict:
|
||||
@ -227,7 +218,12 @@ class PatchJsonArgParse(object):
|
||||
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
||||
if subcommand:
|
||||
parsed_cfg = {
|
||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
|
||||
(
|
||||
(subcommand + PatchJsonArgParse._commands_sep)
|
||||
if k not in PatchJsonArgParse._special_fields
|
||||
else ""
|
||||
)
|
||||
+ k: v
|
||||
for k, v in parsed_cfg.items()
|
||||
}
|
||||
return parsed_cfg
|
||||
@ -257,3 +253,7 @@ class PatchJsonArgParse(object):
|
||||
if isinstance(json_, list):
|
||||
return [Path(**dict_) for dict_ in json_]
|
||||
return Path(**json_)
|
||||
|
||||
|
||||
# patch jsonargparse before anything else
|
||||
PatchJsonArgParse.patch()
|
||||
|
Loading…
Reference in New Issue
Block a user