mirror of
https://github.com/clearml/clearml
synced 2025-02-12 07:35:08 +00:00
Fix jsonargparse binding does not capture parameters before Task.init is called (#1164)
This commit is contained in:
parent
65c6ba33e4
commit
23bdbe4b87
@ -44,7 +44,7 @@ class PatchJsonArgParse(object):
|
|||||||
cls.patch(task)
|
cls.patch(task)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def patch(cls, task):
|
def patch(cls, task=None):
|
||||||
if ArgumentParser is None:
|
if ArgumentParser is None:
|
||||||
return
|
return
|
||||||
PatchJsonArgParse._update_task_args()
|
PatchJsonArgParse._update_task_args()
|
||||||
@ -73,7 +73,9 @@ class PatchJsonArgParse(object):
|
|||||||
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
|
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
|
if isinstance(v, Namespace) or (
|
||||||
|
isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)
|
||||||
|
):
|
||||||
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
|
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
|
||||||
args_type[key_with_section] = PatchJsonArgParse.namespace_type
|
args_type[key_with_section] = PatchJsonArgParse.namespace_type
|
||||||
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
||||||
@ -89,7 +91,7 @@ class PatchJsonArgParse(object):
|
|||||||
cls._current_task.set_parameter(
|
cls._current_task.set_parameter(
|
||||||
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
|
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
|
||||||
False,
|
False,
|
||||||
description="If False, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
|
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -111,8 +113,6 @@ class PatchJsonArgParse(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||||
if not PatchJsonArgParse._current_task:
|
|
||||||
return original_fn(obj, *args, **kwargs)
|
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
kwargs["args"] = args[0]
|
kwargs["args"] = args[0]
|
||||||
args = []
|
args = []
|
||||||
@ -132,9 +132,7 @@ class PatchJsonArgParse(object):
|
|||||||
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
|
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
|
||||||
if not allow_jsonargparse_overrides_value:
|
if not allow_jsonargparse_overrides_value:
|
||||||
params_namespace = PatchJsonArgParse.__restore_args(
|
params_namespace = PatchJsonArgParse.__restore_args(
|
||||||
obj,
|
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||||
params_namespace,
|
|
||||||
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
|
||||||
)
|
)
|
||||||
return params_namespace
|
return params_namespace
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -154,6 +152,7 @@ class PatchJsonArgParse(object):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
try:
|
||||||
import pytorch_lightning
|
import pytorch_lightning
|
||||||
|
|
||||||
lightning = pytorch_lightning
|
lightning = pytorch_lightning
|
||||||
except ImportError:
|
except ImportError:
|
||||||
lightning = None
|
lightning = None
|
||||||
@ -183,20 +182,14 @@ class PatchJsonArgParse(object):
|
|||||||
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
|
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
|
||||||
for key, section_param in cls.__remote_task_params[cls._section_name].items():
|
for key, section_param in cls.__remote_task_params[cls._section_name].items():
|
||||||
if section_param.type == cls.namespace_type:
|
if section_param.type == cls.namespace_type:
|
||||||
params_dict[
|
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value)
|
||||||
"{}/{}".format(cls._section_name, key)
|
|
||||||
] = cls._get_namespace_from_json(section_param.value)
|
|
||||||
elif section_param.type == cls.path_type:
|
elif section_param.type == cls.path_type:
|
||||||
params_dict[
|
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value)
|
||||||
"{}/{}".format(cls._section_name, key)
|
|
||||||
] = cls._get_path_from_json(section_param.value)
|
|
||||||
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
|
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
|
||||||
params_dict["{}/{}".format(cls._section_name, key)] = None
|
params_dict["{}/{}".format(cls._section_name, key)] = None
|
||||||
skip = len(cls._section_name) + 1
|
skip = len(cls._section_name) + 1
|
||||||
cls.__remote_task_params_dict = {
|
cls.__remote_task_params_dict = {
|
||||||
k[skip:]: v
|
k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep)
|
||||||
for k, v in params_dict.items()
|
|
||||||
if k.startswith(cls._section_name + cls._args_sep)
|
|
||||||
}
|
}
|
||||||
cls.__update_remote_task_params_dict_based_on_paths(parser)
|
cls.__update_remote_task_params_dict_based_on_paths(parser)
|
||||||
|
|
||||||
@ -205,9 +198,7 @@ class PatchJsonArgParse(object):
|
|||||||
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
|
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
|
||||||
for path in paths:
|
for path in paths:
|
||||||
args = PatchJsonArgParse.__get_args_from_path(
|
args = PatchJsonArgParse.__get_args_from_path(
|
||||||
parser,
|
parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand")
|
||||||
path,
|
|
||||||
subcommand=cls.__remote_task_params_dict.get("subcommand")
|
|
||||||
)
|
)
|
||||||
for subarg_key, subarg_value in args.items():
|
for subarg_key, subarg_value in args.items():
|
||||||
if subarg_key not in cls.__remote_task_params_dict:
|
if subarg_key not in cls.__remote_task_params_dict:
|
||||||
@ -227,7 +218,12 @@ class PatchJsonArgParse(object):
|
|||||||
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
||||||
if subcommand:
|
if subcommand:
|
||||||
parsed_cfg = {
|
parsed_cfg = {
|
||||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
|
(
|
||||||
|
(subcommand + PatchJsonArgParse._commands_sep)
|
||||||
|
if k not in PatchJsonArgParse._special_fields
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
+ k: v
|
||||||
for k, v in parsed_cfg.items()
|
for k, v in parsed_cfg.items()
|
||||||
}
|
}
|
||||||
return parsed_cfg
|
return parsed_cfg
|
||||||
@ -257,3 +253,7 @@ class PatchJsonArgParse(object):
|
|||||||
if isinstance(json_, list):
|
if isinstance(json_, list):
|
||||||
return [Path(**dict_) for dict_ in json_]
|
return [Path(**dict_) for dict_ in json_]
|
||||||
return Path(**json_)
|
return Path(**json_)
|
||||||
|
|
||||||
|
|
||||||
|
# patch jsonargparse before anything else
|
||||||
|
PatchJsonArgParse.patch()
|
||||||
|
Loading…
Reference in New Issue
Block a user