mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Fix jsonargparse and pytorch lightning integration broken for remote execution (#403)
This commit is contained in:
parent
bf37df61aa
commit
17dfa2b92f
@ -1,18 +1,25 @@
|
|||||||
import ast
|
import json
|
||||||
import six
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from jsonargparse import ArgumentParser
|
from jsonargparse import ArgumentParser
|
||||||
from jsonargparse.namespace import Namespace
|
from jsonargparse.namespace import Namespace
|
||||||
|
from jsonargparse.util import Path
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ArgumentParser = None
|
ArgumentParser = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jsonargparse.typehints as jsonargparse_typehints
|
||||||
|
except ImportError:
|
||||||
|
jsonargparse_typehints = None
|
||||||
|
|
||||||
from ..config import running_remotely, get_remote_task_id
|
from ..config import running_remotely, get_remote_task_id
|
||||||
from .frameworks import _patched_call # noqa
|
from .frameworks import _patched_call # noqa
|
||||||
from ..utilities.proxy_object import flatten_dictionary
|
from ..utilities.proxy_object import verify_basic_type
|
||||||
|
|
||||||
|
|
||||||
class PatchJsonArgParse(object):
|
class PatchJsonArgParse(object):
|
||||||
|
namespace_type = "jsonargparse_namespace"
|
||||||
|
path_type = "jsonargparse_path"
|
||||||
_args = {}
|
_args = {}
|
||||||
_current_task = None
|
_current_task = None
|
||||||
_args_sep = "/"
|
_args_sep = "/"
|
||||||
@ -35,21 +42,49 @@ class PatchJsonArgParse(object):
|
|||||||
def patch(cls, task):
|
def patch(cls, task):
|
||||||
if ArgumentParser is None:
|
if ArgumentParser is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
PatchJsonArgParse._update_task_args()
|
PatchJsonArgParse._update_task_args()
|
||||||
|
|
||||||
if not cls.__patched:
|
if not cls.__patched:
|
||||||
cls.__patched = True
|
cls.__patched = True
|
||||||
ArgumentParser.parse_args = _patched_call(ArgumentParser.parse_args, PatchJsonArgParse._parse_args)
|
ArgumentParser.parse_args = _patched_call(ArgumentParser.parse_args, PatchJsonArgParse._parse_args)
|
||||||
|
if jsonargparse_typehints:
|
||||||
|
jsonargparse_typehints.adapt_typehints = _patched_call(
|
||||||
|
jsonargparse_typehints.adapt_typehints, PatchJsonArgParse._adapt_typehints
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _update_task_args(cls):
|
def _update_task_args(cls):
|
||||||
if running_remotely() or not cls._current_task or not cls._args:
|
if running_remotely() or not cls._current_task or not cls._args:
|
||||||
return
|
return
|
||||||
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
args = {}
|
||||||
args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
|
args_type = {}
|
||||||
|
for k, v in cls._args.items():
|
||||||
|
key_with_section = cls._section_name + cls._args_sep + k
|
||||||
|
args[key_with_section] = v
|
||||||
|
if k in cls._args_type:
|
||||||
|
args_type[key_with_section] = cls._args_type[k]
|
||||||
|
continue
|
||||||
|
if not verify_basic_type(v) and v:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
|
||||||
|
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
|
||||||
|
args_type[key_with_section] = PatchJsonArgParse.namespace_type
|
||||||
|
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
||||||
|
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v))
|
||||||
|
args_type[key_with_section] = PatchJsonArgParse.path_type
|
||||||
|
else:
|
||||||
|
args[key_with_section] = str(v)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _adapt_typehints(original_fn, val, *args, **kwargs):
|
||||||
|
if not PatchJsonArgParse._current_task or not running_remotely():
|
||||||
|
return original_fn(val, *args, **kwargs)
|
||||||
|
return original_fn(val, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||||
if not PatchJsonArgParse._current_task:
|
if not PatchJsonArgParse._current_task:
|
||||||
@ -65,14 +100,7 @@ class PatchJsonArgParse(object):
|
|||||||
params = PatchJsonArgParse.__remote_task_params_dict
|
params = PatchJsonArgParse.__remote_task_params_dict
|
||||||
params_namespace = Namespace()
|
params_namespace = Namespace()
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if v == "":
|
params_namespace[k] = v
|
||||||
v = None
|
|
||||||
# noinspection PyBroadException
|
|
||||||
try:
|
|
||||||
v = ast.literal_eval(v)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
params_namespace[k] = PatchJsonArgParse.__namespace_eval(v)
|
|
||||||
return params_namespace
|
return params_namespace
|
||||||
except Exception:
|
except Exception:
|
||||||
return original_fn(obj, **kwargs)
|
return original_fn(obj, **kwargs)
|
||||||
@ -97,7 +125,7 @@ class PatchJsonArgParse(object):
|
|||||||
)
|
)
|
||||||
del PatchJsonArgParse._args[subcommand]
|
del PatchJsonArgParse._args[subcommand]
|
||||||
PatchJsonArgParse._args.update(subcommand_args)
|
PatchJsonArgParse._args.update(subcommand_args)
|
||||||
PatchJsonArgParse._args = {k: str(v) for k, v in PatchJsonArgParse._args.items()}
|
PatchJsonArgParse._args = {k: v for k, v in PatchJsonArgParse._args.items()}
|
||||||
PatchJsonArgParse._update_task_args()
|
PatchJsonArgParse._update_task_args()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@ -111,7 +139,18 @@ class PatchJsonArgParse(object):
|
|||||||
t = Task.get_task(task_id=get_remote_task_id())
|
t = Task.get_task(task_id=get_remote_task_id())
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
PatchJsonArgParse.__remote_task_params = t._get_task_property("hyperparams") or {}
|
PatchJsonArgParse.__remote_task_params = t._get_task_property("hyperparams") or {}
|
||||||
params_dict = t.get_parameters(backwards_compatibility=False)
|
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
|
||||||
|
for key, section_param in PatchJsonArgParse.__remote_task_params[PatchJsonArgParse._section_name].items():
|
||||||
|
if section_param.type == PatchJsonArgParse.namespace_type:
|
||||||
|
params_dict[
|
||||||
|
"{}/{}".format(PatchJsonArgParse._section_name, key)
|
||||||
|
] = PatchJsonArgParse._get_namespace_from_json(section_param.value)
|
||||||
|
elif section_param.type == PatchJsonArgParse.path_type:
|
||||||
|
params_dict[
|
||||||
|
"{}/{}".format(PatchJsonArgParse._section_name, key)
|
||||||
|
] = PatchJsonArgParse._get_path_from_json(section_param.value)
|
||||||
|
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
|
||||||
|
params_dict["{}/{}".format(PatchJsonArgParse._section_name, key)] = None
|
||||||
skip = len(PatchJsonArgParse._section_name) + 1
|
skip = len(PatchJsonArgParse._section_name) + 1
|
||||||
PatchJsonArgParse.__remote_task_params_dict = {
|
PatchJsonArgParse.__remote_task_params_dict = {
|
||||||
k[skip:]: v
|
k[skip:]: v
|
||||||
@ -120,15 +159,27 @@ class PatchJsonArgParse(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __namespace_eval(val):
|
def _handle_namespace(value):
|
||||||
if isinstance(val, six.string_types) and val.startswith("Namespace(") and val[-1] == ")":
|
if isinstance(value, list):
|
||||||
val = val[len("Namespace("):]
|
return [PatchJsonArgParse._handle_namespace(sub_value) for sub_value in value]
|
||||||
val = val[:-1]
|
return value.as_dict()
|
||||||
return Namespace(PatchJsonArgParse.__namespace_eval(ast.literal_eval("{" + val + "}")))
|
|
||||||
if isinstance(val, list):
|
@staticmethod
|
||||||
return [PatchJsonArgParse.__namespace_eval(v) for v in val]
|
def _handle_path(value):
|
||||||
if isinstance(val, dict):
|
if isinstance(value, list):
|
||||||
for k, v in val.items():
|
return [PatchJsonArgParse._handle_path(sub_value) for sub_value in value]
|
||||||
val[k] = PatchJsonArgParse.__namespace_eval(v)
|
return {"path": str(value.rel_path), "mode": value.mode, "cwd": None, "skip_check": value.skip_check}
|
||||||
return val
|
|
||||||
return val
|
@staticmethod
|
||||||
|
def _get_namespace_from_json(json_):
|
||||||
|
json_ = json.loads(json_)
|
||||||
|
if isinstance(json_, list):
|
||||||
|
return [Namespace(dict_) for dict_ in json_]
|
||||||
|
return Namespace(json_)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_path_from_json(json_):
|
||||||
|
json_ = json.loads(json_)
|
||||||
|
if isinstance(json_, list):
|
||||||
|
return [Path(**dict_) for dict_ in json_]
|
||||||
|
return Path(**json_)
|
||||||
|
Loading…
Reference in New Issue
Block a user