diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py index 650124dc..df24bdb7 100644 --- a/clearml/binding/jsonargs_bind.py +++ b/clearml/binding/jsonargs_bind.py @@ -1,102 +1,166 @@ -""" jsonargparse binding utility functions """ -from ..config import running_remotely +import ast +import copy + +try: + from jsonargparse import ArgumentParser + from jsonargparse.namespace import Namespace +except ImportError: + ArgumentParser = None + +from ..config import running_remotely, get_remote_task_id +from .frameworks import _patched_call # noqa class PatchJsonArgParse(object): - _original_parse_call = None - _task = None + _args = {} + _main_task = None + _args_sep = "/" + _args_type = {} + _commands_sep = "." + _command_type = "jsonargparse.Command" + _command_name = "subcommand" + _section_name = "Args" + __remote_task_params = {} + __patched = False @classmethod - def update_current_task(cls, current_task): - cls._task = current_task - cls._patch_jsonargparse() - - @classmethod - def _patch_jsonargparse(cls): - # already patched - if cls._original_parse_call: + def patch(cls, task): + if ArgumentParser is None: return - # noinspection PyBroadException - try: - from jsonargparse import ArgumentParser # noqa - cls._original_parse_call = ArgumentParser._parse_common # noqa - ArgumentParser._parse_common = cls._patched_parse_known_args - except Exception: - # there is no jsonargparse - pass + if task: + cls._main_task = task + PatchJsonArgParse._update_task_args() + + if not cls.__patched: + cls.__patched = True + ArgumentParser.parse_args = _patched_call(ArgumentParser.parse_args, PatchJsonArgParse._parse_args) + + @classmethod + def _update_task_args(cls): + if running_remotely() or not cls._main_task or not cls._args: + return + args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} + args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()} + cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type) @staticmethod - def _patched_parse_known_args(self, *args, **kwargs): - if not PatchJsonArgParse._task: - return PatchJsonArgParse._original_parse_call(self, *args, **kwargs) - - try: - from argparse import SUPPRESS - from jsonargparse.typehints import ActionTypeHint - from jsonargparse.actions import ActionConfigFile, _ActionSubCommands, \ - _ActionConfigLoad, filter_default_actions # noqa - from jsonargparse.util import get_key_value_from_flat_dict, update_key_value_in_flat_dict, \ - namespace_to_dict, _dict_to_flat_namespace # noqa - except ImportError: - # something happened, let's just call the original - return PatchJsonArgParse._original_parse_call(self, *args, **kwargs) - - def cleanup_actions(cfg, actions, prefix='', skip_none=False, cast_value=False): - for action in filter_default_actions(actions): - action_dest = prefix + action.dest - if (action.help == SUPPRESS and not isinstance(action, _ActionConfigLoad)) or \ - isinstance(action, ActionConfigFile) or \ - (skip_none and action_dest in cfg and cfg[action_dest] is None): - cfg.pop(action_dest, None) - elif isinstance(action, _ActionSubCommands): - for key, subparser in action.choices.items(): - cleanup_actions(cfg, subparser._actions, prefix=prefix+key+'.', - skip_none=skip_none, cast_value=cast_value) - elif cast_value and isinstance(action, ActionTypeHint): - value = get_key_value_from_flat_dict(cfg, action_dest) - if value is not None and value != {}: - if value: - parsed_value = action._check_type(value) - else: - try: - parsed_value = action._check_type(None) - except TypeError: - # try with original empty text - parsed_value = action._check_type(value) - - update_key_value_in_flat_dict(cfg, action_dest, parsed_value) - elif cast_value and hasattr(action, 'type') and not isinstance(action, _ActionConfigLoad): - value = get_key_value_from_flat_dict(cfg, action_dest) - try: - parsed_value = action.type(value or None) if action.type != str else str(value) - update_key_value_in_flat_dict(cfg, action_dest, parsed_value) - except (ValueError, TypeError): - pass - - if not running_remotely(): - ret = PatchJsonArgParse._original_parse_call(self, *args, **kwargs) - - # noinspection PyBroadException + def _parse_args(original_fn, obj, *args, **kwargs): + if len(args) == 1: + kwargs["args"] = args[0] + args = [] + if len(args) > 1: + return original_fn(obj, *args, **kwargs) + if running_remotely(): try: - cfg_dict = ret if isinstance(ret, dict) else namespace_to_dict(ret) - cfg_dict = namespace_to_dict(_dict_to_flat_namespace(cfg_dict)) - cleanup_actions(cfg_dict, actions=self._actions, skip_none=False, cast_value=False) + PatchJsonArgParse._load_task_params() + params = PatchJsonArgParse.__remote_task_params_dict + for k, v in params.items(): + if v == '': + v = None + # noinspection PyBroadException + try: + v = ast.literal_eval(v) + except Exception: + pass + params[k] = v + params = PatchJsonArgParse.__unflatten_dict(params) + params = PatchJsonArgParse.__nested_dict_to_namespace(params) + return params except Exception: - cfg_dict = None - - # store / sync arguments - if cfg_dict is not None: - PatchJsonArgParse._task.connect(cfg_dict, name='Args') - else: - cfg_dict = PatchJsonArgParse._task.get_parameters_as_dict().get('Args', None) - if cfg_dict is not None: - if 'cfg' in kwargs: - cleanup_actions(cfg_dict, actions=self._actions, skip_none=False, cast_value=True) - kwargs['cfg'].update(cfg_dict) + return original_fn(obj, **kwargs) + orig_parsed_args = original_fn(obj, **kwargs) + # noinspection PyBroadException + try: + parsed_args = vars(copy.deepcopy(orig_parsed_args)) + for ns_name, ns_val in parsed_args.items(): + if not isinstance(ns_val, (Namespace, list)): + PatchJsonArgParse._args[ns_name] = str(ns_val) + if ns_name == PatchJsonArgParse._command_name: + PatchJsonArgParse._args_type[ns_name] = PatchJsonArgParse._command_type else: - print('Warning failed applying jsonargparse configuration') + ns_val = PatchJsonArgParse.__nested_namespace_to_dict(ns_val) + ns_val = PatchJsonArgParse.__flatten_dict(ns_val, parent_name=ns_name) + for k, v in ns_val.items(): + PatchJsonArgParse._args[k] = str(v) + PatchJsonArgParse._update_task_args() + except Exception: + pass + return orig_parsed_args - ret = PatchJsonArgParse._original_parse_call(self, *args, **kwargs) + @staticmethod + def _load_task_params(): + if not PatchJsonArgParse.__remote_task_params: + from clearml import Task - return ret + t = Task.get_task(task_id=get_remote_task_id()) + # noinspection PyProtectedMember + PatchJsonArgParse.__remote_task_params = t._get_task_property("hyperparams") or {} + params_dict = t.get_parameters(backwards_compatibility=False) + skip = len(PatchJsonArgParse._section_name) + 1 + PatchJsonArgParse.__remote_task_params_dict = { + k[skip:]: v + for k, v in params_dict.items() + if k.startswith(PatchJsonArgParse._section_name + PatchJsonArgParse._args_sep) + } + + @staticmethod + def __nested_namespace_to_dict(namespace): + if isinstance(namespace, list): + return [PatchJsonArgParse.__nested_namespace_to_dict(n) for n in namespace] + if not isinstance(namespace, Namespace): + return namespace + namespace = vars(namespace) + for k, v in namespace.items(): + namespace[k] = PatchJsonArgParse.__nested_namespace_to_dict(v) + return namespace + + @staticmethod + def __nested_dict_to_namespace(dict_): + if isinstance(dict_, list): + return [PatchJsonArgParse.__nested_dict_to_namespace(d) for d in dict_] + if not isinstance(dict_, dict): + return dict_ + for k, v in dict_.items(): + dict_[k] = PatchJsonArgParse.__nested_dict_to_namespace(v) + return Namespace(**dict_) + + @staticmethod + def __flatten_dict(dict_, parent_name=None): + if isinstance(dict_, list): + if parent_name: + return {parent_name: [PatchJsonArgParse.__flatten_dict(d) for d in dict_]} + return [PatchJsonArgParse.__flatten_dict(d) for d in dict_] + if not isinstance(dict_, dict): + if parent_name: + return {parent_name: dict_} + return dict_ + result = {} + for k, v in dict_.items(): + v = PatchJsonArgParse.__flatten_dict(v, parent_name=k) + if isinstance(v, dict): + for flattened_k, flattened_v in v.items(): + if parent_name: + result[parent_name + PatchJsonArgParse._commands_sep + flattened_k] = flattened_v + else: + result[flattened_k] = flattened_v + else: + result[k] = v + return result + + @staticmethod + def __unflatten_dict(dict_): + if isinstance(dict_, list): + return [PatchJsonArgParse.__unflatten_dict(d) for d in dict_] + if not isinstance(dict_, dict): + return dict_ + result = {} + for k, v in dict_.items(): + keys = k.split(PatchJsonArgParse._commands_sep) + current_dict = result + for k_part in keys[:-1]: + if k_part not in current_dict: + current_dict[k_part] = {} + current_dict = current_dict[k_part] + current_dict[keys[-1]] = PatchJsonArgParse.__unflatten_dict(v) + return result diff --git a/clearml/task.py b/clearml/task.py index ac65f967..2d4f9aee 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -372,9 +372,9 @@ class Task(_Task): auto_connect_frameworks={ 'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True, - 'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True, - 'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True, - 'megengine': True, 'jsonargparse': True, 'catboost': True + 'xgboost': True, 'scikit': True, 'fastai': True, + 'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True, + 'joblib': True, 'megengine': True, 'catboost': True } :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots @@ -561,8 +561,6 @@ class Task(_Task): is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict) if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True): PatchHydra.update_current_task(task) - if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('jsonargparse', True): - PatchJsonArgParse.update_current_task(task) if is_auto_connect_frameworks_bool or ( auto_connect_frameworks.get('scikit', True) and auto_connect_frameworks.get('joblib', True)): @@ -609,9 +607,9 @@ class Task(_Task): # Patch ArgParser to be aware of the current task argparser_update_currenttask(Task.__main_task) - # Patch Click and Fire PatchClick.patch(Task.__main_task) PatchFire.patch(Task.__main_task) + PatchJsonArgParse.patch(Task.__main_task) # set excluded arguments if isinstance(auto_connect_arg_parser, dict): diff --git a/examples/frameworks/jsonargparse/jsonargparser_command.py b/examples/frameworks/jsonargparse/jsonargparser_command.py new file mode 100644 index 00000000..7b61abe5 --- /dev/null +++ b/examples/frameworks/jsonargparse/jsonargparser_command.py @@ -0,0 +1,15 @@ +from jsonargparse import CLI +from clearml import Task + + +class Main: + def __init__(self, prize: int = 100): + self.prize = prize + + def person(self, name: str): + return "{} won {}!".format(name, self.prize) + + +if __name__ == "__main__": + Task.init(project_name="examples", task_name="jsonargparse command", auto_connect_frameworks={"pytorch_lightning": False}) + print(CLI(Main)) diff --git a/examples/frameworks/jsonargparse/jsonargparser_nested_namespaces.py b/examples/frameworks/jsonargparse/jsonargparser_nested_namespaces.py new file mode 100644 index 00000000..ddc4b259 --- /dev/null +++ b/examples/frameworks/jsonargparse/jsonargparser_nested_namespaces.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from jsonargparse import ArgumentParser +from clearml import Task + + +@dataclass +class Arg2: + opt1: str = "from default 1" + opt2: str = "from default 2" + + +if __name__ == "__main__": + Task.init(project_name="examples", task_name="jsonargparse nested namespaces", auto_connect_frameworks={"pytorch-lightning": False}) + parser = ArgumentParser() + parser.add_argument("--arg1.opt1", default="from default 1") + parser.add_argument("--arg1.opt2", default="from default 2") + parser.add_argument("--arg2", type=Arg2, default=Arg2()) + parser.add_argument("--not-nested") + print(parser.parse_args()) diff --git a/examples/frameworks/jsonargparse/requirements.txt b/examples/frameworks/jsonargparse/requirements.txt new file mode 100644 index 00000000..e14e72be --- /dev/null +++ b/examples/frameworks/jsonargparse/requirements.txt @@ -0,0 +1,2 @@ +clearml +jsonargparse \ No newline at end of file