diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py index 7affbf1d..5a9d64f5 100644 --- a/clearml/binding/jsonargs_bind.py +++ b/clearml/binding/jsonargs_bind.py @@ -1,17 +1,35 @@ import json +import copy import logging try: + # public import capabilities of namespace, util, actions will be deprecated + # import from "protected" instead + + from jsonargparse._namespace import Namespace + # noinspection PyProtectedMember + from jsonargparse._util import Path + # noinspection PyProtectedMember from jsonargparse import ArgumentParser - from jsonargparse.namespace import Namespace - from jsonargparse.util import Path, change_to_path_dir except ImportError: - ArgumentParser = None + try: + from jsonargparse.namespace import Namespace + from jsonargparse.util import Path + from jsonargparse import ArgumentParser + except ImportError: + ArgumentParser = None try: - import jsonargparse.typehints as jsonargparse_typehints + # public import capabilities of jsonargparse_typehints will be deprecated + # import from "protected" instead + + # noinspection PyProtectedMember + import jsonargparse._typehints as jsonargparse_typehints except ImportError: - jsonargparse_typehints = None + try: + import jsonargparse.typehints as jsonargparse_typehints + except ImportError: + jsonargparse_typehints = None from ..config import running_remotely, get_remote_task_id from .frameworks import _patched_call # noqa @@ -122,17 +140,19 @@ class PatchJsonArgParse(object): try: PatchJsonArgParse._load_task_params(parser=obj) params = PatchJsonArgParse.__remote_task_params_dict - params_namespace = Namespace() - for k, v in params.items(): - params_namespace[k] = v allow_jsonargparse_overrides_value = True if PatchJsonArgParse._allow_jsonargparse_overrides in params: allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides) if PatchJsonArgParse._ignore_ui_overrides in params: allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides) + params_namespace = Namespace() + for k, v in params.items(): + params_namespace[k] = v if not allow_jsonargparse_overrides_value: params_namespace = PatchJsonArgParse.__restore_args( - obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name) + obj, + params_namespace, + subcommand=params_namespace.get(PatchJsonArgParse._command_name) ) if PatchJsonArgParse._allow_jsonargparse_overrides in params_namespace: del params_namespace[PatchJsonArgParse._allow_jsonargparse_overrides] @@ -210,27 +230,37 @@ class PatchJsonArgParse(object): @staticmethod def __get_paths_from_dict(dict_): - paths = [path for path in dict_.values() if isinstance(path, Path)] - for subargs in dict_.values(): + paths = [(path_key, path) for path_key, path in dict_.items() if isinstance(path, Path)] + for subargs_key, subargs in dict_.items(): if isinstance(subargs, list) and all(isinstance(path, Path) for path in subargs): - paths.extend(subargs) + paths.extend((subargs_key, path) for path in subargs) return paths @staticmethod def __get_args_from_path(parser, path, subcommand=None): - with change_to_path_dir(path): - parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False) - if subcommand: - parsed_cfg = { - ( - (subcommand + PatchJsonArgParse._commands_sep) - if k not in PatchJsonArgParse._special_fields - else "" - ) - + k: v - for k, v in parsed_cfg.items() - } - return parsed_cfg + try: + # make sure no side effects happen in parser + parser = copy.deepcopy(parser) + argument = path[0] + if subcommand and argument.startswith(subcommand + PatchJsonArgParse._commands_sep): + argument = argument[len(subcommand + PatchJsonArgParse._commands_sep):] + result = parser.parse_args( + [subcommand, parser.prefix_chars[0] * 2 + argument, path[1].rel_path], + _skip_check=True, + defaults=False, + ) + if PatchJsonArgParse._command_name in result: + del result[PatchJsonArgParse._command_name] + else: + result = parser.parse_args( + [parser.prefix_chars[0] * 2 + argument, path[1].rel_path], _skip_check=True, defaults=False + ) + if argument in result: + del result[argument] + return result + except Exception as e: + logging.getLogger(__file__).warning("Failed parsing jsonargparse config: {}".format(e)) + return Namespace() @staticmethod def _handle_namespace(value):