Fix jsonargparse not loading new values from changed config files

This commit is contained in:
Alex Burlacu 2023-06-01 15:12:30 +03:00
parent 56d4de04e4
commit f207e72200

View File

@ -4,7 +4,7 @@ import logging
try:
from jsonargparse import ArgumentParser
from jsonargparse.namespace import Namespace
from jsonargparse.util import Path
from jsonargparse.util import Path, change_to_path_dir
except ImportError:
ArgumentParser = None
@ -30,6 +30,7 @@ class PatchJsonArgParse(object):
_command_name = "subcommand"
_section_name = "Args"
__remote_task_params = {}
__remote_task_params_dict = {}
__patched = False
@classmethod
@ -54,7 +55,7 @@ class PatchJsonArgParse(object):
)
@classmethod
def _update_task_args(cls):
def _update_task_args(cls, parser=None, subcommand=None):
if running_remotely() or not cls._current_task or not cls._args:
return
args = {}
@ -65,7 +66,7 @@ class PatchJsonArgParse(object):
if k in cls._args_type:
args_type[key_with_section] = cls._args_type[k]
continue
if not verify_basic_type(v) and v:
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
# noinspection PyBroadException
try:
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
@ -78,8 +79,24 @@ class PatchJsonArgParse(object):
args[key_with_section] = str(v)
except Exception:
pass
args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand)
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
@classmethod
def __delete_config_args(cls, parser, args, args_type, subcommand=None):
if not parser:
return args, args_type
paths = PatchJsonArgParse.__get_paths_from_dict(cls._args)
for path in paths:
args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
for arg_to_delete_key, arg_to_delete_value in args_to_delete.items():
key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key
if key_with_section in args and args[key_with_section] == arg_to_delete_value:
del args[key_with_section]
if key_with_section in args_type:
del args_type[key_with_section]
return args, args_type
@staticmethod
def _adapt_typehints(original_fn, val, *args, **kwargs):
if not PatchJsonArgParse._current_task or not running_remotely():
@ -97,7 +114,7 @@ class PatchJsonArgParse(object):
return original_fn(obj, *args, **kwargs)
if running_remotely():
try:
PatchJsonArgParse._load_task_params()
PatchJsonArgParse._load_task_params(parser=obj)
params = PatchJsonArgParse.__remote_task_params_dict
params_namespace = Namespace()
for k, v in params.items():
@ -132,37 +149,68 @@ class PatchJsonArgParse(object):
del PatchJsonArgParse._args[subcommand]
PatchJsonArgParse._args.update(subcommand_args)
PatchJsonArgParse._args = {k: v for k, v in PatchJsonArgParse._args.items()}
PatchJsonArgParse._update_task_args()
PatchJsonArgParse._update_task_args(parser=obj, subcommand=subcommand)
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
return parsed_args
@staticmethod
def _load_task_params():
if not PatchJsonArgParse.__remote_task_params:
from clearml import Task
@classmethod
def _load_task_params(cls, parser=None):
if cls.__remote_task_params:
return
from clearml import Task
t = Task.get_task(task_id=get_remote_task_id())
# noinspection PyProtectedMember
PatchJsonArgParse.__remote_task_params = t._get_task_property("hyperparams") or {}
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
for key, section_param in PatchJsonArgParse.__remote_task_params[PatchJsonArgParse._section_name].items():
if section_param.type == PatchJsonArgParse.namespace_type:
params_dict[
"{}/{}".format(PatchJsonArgParse._section_name, key)
] = PatchJsonArgParse._get_namespace_from_json(section_param.value)
elif section_param.type == PatchJsonArgParse.path_type:
params_dict[
"{}/{}".format(PatchJsonArgParse._section_name, key)
] = PatchJsonArgParse._get_path_from_json(section_param.value)
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
params_dict["{}/{}".format(PatchJsonArgParse._section_name, key)] = None
skip = len(PatchJsonArgParse._section_name) + 1
PatchJsonArgParse.__remote_task_params_dict = {
k[skip:]: v
for k, v in params_dict.items()
if k.startswith(PatchJsonArgParse._section_name + PatchJsonArgParse._args_sep)
}
t = Task.get_task(task_id=get_remote_task_id())
# noinspection PyProtectedMember
cls.__remote_task_params = t._get_task_property("hyperparams") or {}
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
for key, section_param in cls.__remote_task_params[cls._section_name].items():
if section_param.type == cls.namespace_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_namespace_from_json(section_param.value)
elif section_param.type == cls.path_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_path_from_json(section_param.value)
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
params_dict["{}/{}".format(cls._section_name, key)] = None
skip = len(cls._section_name) + 1
cls.__remote_task_params_dict = {
k[skip:]: v
for k, v in params_dict.items()
if k.startswith(cls._section_name + cls._args_sep)
}
cls.__update_remote_task_params_dict_based_on_paths(parser)
@classmethod
def __update_remote_task_params_dict_based_on_paths(cls, parser):
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
for path in paths:
args = PatchJsonArgParse.__get_args_from_path(
parser,
path,
subcommand=cls.__remote_task_params_dict.get("subcommand")
)
for subarg_key, subarg_value in args.items():
if subarg_key not in cls.__remote_task_params_dict:
cls.__remote_task_params_dict[subarg_key] = subarg_value
@staticmethod
def __get_paths_from_dict(dict_):
paths = [path for path in dict_.values() if isinstance(path, Path)]
for subargs in dict_.values():
if isinstance(subargs, list) and all(isinstance(path, Path) for path in subargs):
paths.extend(subargs)
return paths
@staticmethod
def __get_args_from_path(parser, path, subcommand=None):
with change_to_path_dir(path):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {subcommand + PatchJsonArgParse._commands_sep + k: v for k, v in parsed_cfg.items()}
return parsed_cfg
@staticmethod
def _handle_namespace(value):