From 17dfa2b92f871287d1071118a7bb5ebf6c805fac Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 22 Dec 2022 21:47:00 +0200 Subject: [PATCH] Fix jsonargparse and pytorch lightning integration broken for remote execution (#403) --- clearml/binding/jsonargs_bind.py | 107 +++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 28 deletions(-) diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py index eeba275b..0a27e786 100644 --- a/clearml/binding/jsonargs_bind.py +++ b/clearml/binding/jsonargs_bind.py @@ -1,18 +1,25 @@ -import ast -import six +import json try: from jsonargparse import ArgumentParser from jsonargparse.namespace import Namespace + from jsonargparse.util import Path except ImportError: ArgumentParser = None +try: + import jsonargparse.typehints as jsonargparse_typehints +except ImportError: + jsonargparse_typehints = None + from ..config import running_remotely, get_remote_task_id from .frameworks import _patched_call # noqa -from ..utilities.proxy_object import flatten_dictionary +from ..utilities.proxy_object import verify_basic_type class PatchJsonArgParse(object): + namespace_type = "jsonargparse_namespace" + path_type = "jsonargparse_path" _args = {} _current_task = None _args_sep = "/" @@ -35,21 +42,49 @@ class PatchJsonArgParse(object): def patch(cls, task): if ArgumentParser is None: return - PatchJsonArgParse._update_task_args() if not cls.__patched: cls.__patched = True ArgumentParser.parse_args = _patched_call(ArgumentParser.parse_args, PatchJsonArgParse._parse_args) + if jsonargparse_typehints: + jsonargparse_typehints.adapt_typehints = _patched_call( + jsonargparse_typehints.adapt_typehints, PatchJsonArgParse._adapt_typehints + ) @classmethod def _update_task_args(cls): if running_remotely() or not cls._current_task or not cls._args: return - args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} - args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()} + args = {} + args_type = {} + for k, v in cls._args.items(): + key_with_section = cls._section_name + cls._args_sep + k + args[key_with_section] = v + if k in cls._args_type: + args_type[key_with_section] = cls._args_type[k] + continue + if not verify_basic_type(v) and v: + # noinspection PyBroadException + try: + if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)): + args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v)) + args_type[key_with_section] = PatchJsonArgParse.namespace_type + elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)): + args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v)) + args_type[key_with_section] = PatchJsonArgParse.path_type + else: + args[key_with_section] = str(v) + except Exception: + pass cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type) + @staticmethod + def _adapt_typehints(original_fn, val, *args, **kwargs): + if not PatchJsonArgParse._current_task or not running_remotely(): + return original_fn(val, *args, **kwargs) + return original_fn(val, *args, **kwargs) + @staticmethod def _parse_args(original_fn, obj, *args, **kwargs): if not PatchJsonArgParse._current_task: @@ -65,14 +100,7 @@ class PatchJsonArgParse(object): params = PatchJsonArgParse.__remote_task_params_dict params_namespace = Namespace() for k, v in params.items(): - if v == "": - v = None - # noinspection PyBroadException - try: - v = ast.literal_eval(v) - except Exception: - pass - params_namespace[k] = PatchJsonArgParse.__namespace_eval(v) + params_namespace[k] = v return params_namespace except Exception: return original_fn(obj, **kwargs) @@ -97,7 +125,7 @@ class PatchJsonArgParse(object): ) del PatchJsonArgParse._args[subcommand] PatchJsonArgParse._args.update(subcommand_args) - PatchJsonArgParse._args = {k: str(v) for k, v in PatchJsonArgParse._args.items()} + PatchJsonArgParse._args = {k: v for k, v in PatchJsonArgParse._args.items()} PatchJsonArgParse._update_task_args() except Exception: pass @@ -111,7 +139,18 @@ class PatchJsonArgParse(object): t = Task.get_task(task_id=get_remote_task_id()) # noinspection PyProtectedMember PatchJsonArgParse.__remote_task_params = t._get_task_property("hyperparams") or {} - params_dict = t.get_parameters(backwards_compatibility=False) + params_dict = t.get_parameters(backwards_compatibility=False, cast=True) + for key, section_param in PatchJsonArgParse.__remote_task_params[PatchJsonArgParse._section_name].items(): + if section_param.type == PatchJsonArgParse.namespace_type: + params_dict[ + "{}/{}".format(PatchJsonArgParse._section_name, key) + ] = PatchJsonArgParse._get_namespace_from_json(section_param.value) + elif section_param.type == PatchJsonArgParse.path_type: + params_dict[ + "{}/{}".format(PatchJsonArgParse._section_name, key) + ] = PatchJsonArgParse._get_path_from_json(section_param.value) + elif (not section_param.type or section_param.type == "NoneType") and not section_param.value: + params_dict["{}/{}".format(PatchJsonArgParse._section_name, key)] = None skip = len(PatchJsonArgParse._section_name) + 1 PatchJsonArgParse.__remote_task_params_dict = { k[skip:]: v @@ -120,15 +159,27 @@ class PatchJsonArgParse(object): } @staticmethod - def __namespace_eval(val): - if isinstance(val, six.string_types) and val.startswith("Namespace(") and val[-1] == ")": - val = val[len("Namespace("):] - val = val[:-1] - return Namespace(PatchJsonArgParse.__namespace_eval(ast.literal_eval("{" + val + "}"))) - if isinstance(val, list): - return [PatchJsonArgParse.__namespace_eval(v) for v in val] - if isinstance(val, dict): - for k, v in val.items(): - val[k] = PatchJsonArgParse.__namespace_eval(v) - return val - return val + def _handle_namespace(value): + if isinstance(value, list): + return [PatchJsonArgParse._handle_namespace(sub_value) for sub_value in value] + return value.as_dict() + + @staticmethod + def _handle_path(value): + if isinstance(value, list): + return [PatchJsonArgParse._handle_path(sub_value) for sub_value in value] + return {"path": str(value.rel_path), "mode": value.mode, "cwd": None, "skip_check": value.skip_check} + + @staticmethod + def _get_namespace_from_json(json_): + json_ = json.loads(json_) + if isinstance(json_, list): + return [Namespace(dict_) for dict_ in json_] + return Namespace(json_) + + @staticmethod + def _get_path_from_json(json_): + json_ = json.loads(json_) + if isinstance(json_, list): + return [Path(**dict_) for dict_ in json_] + return Path(**json_)