mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Add support for jsonargpraser (issue #403)
This commit is contained in:
parent
b94f345d2b
commit
4be4ba1a9a
102
clearml/binding/jsonargs_bind.py
Normal file
102
clearml/binding/jsonargs_bind.py
Normal file
@ -0,0 +1,102 @@
|
||||
""" jsonargparse binding utility functions """
|
||||
from ..config import running_remotely
|
||||
|
||||
|
||||
class PatchJsonArgParse(object):
|
||||
_original_parse_call = None
|
||||
_task = None
|
||||
|
||||
@classmethod
|
||||
def update_current_task(cls, current_task):
|
||||
cls._task = current_task
|
||||
cls._patch_jsonargparse()
|
||||
|
||||
@classmethod
|
||||
def _patch_jsonargparse(cls):
|
||||
# already patched
|
||||
if cls._original_parse_call:
|
||||
return
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from jsonargparse import ArgumentParser # noqa
|
||||
cls._original_parse_call = ArgumentParser._parse_common # noqa
|
||||
ArgumentParser._parse_common = cls._patched_parse_known_args
|
||||
except Exception:
|
||||
# there is no jsonargparse
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _patched_parse_known_args(self, *args, **kwargs):
|
||||
if not PatchJsonArgParse._task:
|
||||
return PatchJsonArgParse._original_parse_call(self, *args, **kwargs)
|
||||
|
||||
try:
|
||||
from argparse import SUPPRESS
|
||||
from jsonargparse.typehints import ActionTypeHint
|
||||
from jsonargparse.actions import ActionConfigFile, _ActionSubCommands, \
|
||||
_ActionConfigLoad, filter_default_actions # noqa
|
||||
from jsonargparse.util import get_key_value_from_flat_dict, update_key_value_in_flat_dict, \
|
||||
namespace_to_dict, _dict_to_flat_namespace # noqa
|
||||
except ImportError:
|
||||
# something happened, let's just call the original
|
||||
return PatchJsonArgParse._original_parse_call(self, *args, **kwargs)
|
||||
|
||||
def cleanup_actions(cfg, actions, prefix='', skip_none=False, cast_value=False):
|
||||
for action in filter_default_actions(actions):
|
||||
action_dest = prefix + action.dest
|
||||
if (action.help == SUPPRESS and not isinstance(action, _ActionConfigLoad)) or \
|
||||
isinstance(action, ActionConfigFile) or \
|
||||
(skip_none and action_dest in cfg and cfg[action_dest] is None):
|
||||
cfg.pop(action_dest, None)
|
||||
elif isinstance(action, _ActionSubCommands):
|
||||
for key, subparser in action.choices.items():
|
||||
cleanup_actions(cfg, subparser._actions, prefix=prefix+key+'.',
|
||||
skip_none=skip_none, cast_value=cast_value)
|
||||
elif cast_value and isinstance(action, ActionTypeHint):
|
||||
value = get_key_value_from_flat_dict(cfg, action_dest)
|
||||
if value is not None and value != {}:
|
||||
if value:
|
||||
parsed_value = action._check_type(value)
|
||||
else:
|
||||
try:
|
||||
parsed_value = action._check_type(None)
|
||||
except TypeError:
|
||||
# try with original empty text
|
||||
parsed_value = action._check_type(value)
|
||||
|
||||
update_key_value_in_flat_dict(cfg, action_dest, parsed_value)
|
||||
elif cast_value and hasattr(action, 'type') and not isinstance(action, _ActionConfigLoad):
|
||||
value = get_key_value_from_flat_dict(cfg, action_dest)
|
||||
try:
|
||||
parsed_value = action.type(value or None) if action.type != str else str(value)
|
||||
update_key_value_in_flat_dict(cfg, action_dest, parsed_value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if not running_remotely():
|
||||
ret = PatchJsonArgParse._original_parse_call(self, *args, **kwargs)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cfg_dict = ret if isinstance(ret, dict) else namespace_to_dict(ret)
|
||||
cfg_dict = namespace_to_dict(_dict_to_flat_namespace(cfg_dict))
|
||||
cleanup_actions(cfg_dict, actions=self._actions, skip_none=False, cast_value=False)
|
||||
except Exception:
|
||||
cfg_dict = None
|
||||
|
||||
# store / sync arguments
|
||||
if cfg_dict is not None:
|
||||
PatchJsonArgParse._task.connect(cfg_dict, name='Args')
|
||||
else:
|
||||
cfg_dict = PatchJsonArgParse._task.get_parameters_as_dict().get('Args', None)
|
||||
if cfg_dict is not None:
|
||||
if 'cfg' in kwargs:
|
||||
cleanup_actions(cfg_dict, actions=self._actions, skip_none=False, cast_value=True)
|
||||
kwargs['cfg'].update(cfg_dict)
|
||||
else:
|
||||
print('Warning failed applying jsonargparse configuration')
|
||||
|
||||
ret = PatchJsonArgParse._original_parse_call(self, *args, **kwargs)
|
||||
|
||||
return ret
|
@ -49,6 +49,7 @@ from .binding.joblib_bind import PatchedJoblib
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .binding.hydra_bind import PatchHydra
|
||||
from .binding.click_bind import PatchClick
|
||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||
from .config import (
|
||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||
deferred_config, TASK_SET_ITERATION_OFFSET, )
|
||||
@ -369,7 +370,7 @@ class Task(_Task):
|
||||
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
||||
'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True,
|
||||
'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True,
|
||||
'megengine': True,
|
||||
'megengine': True, 'jsonargparse': True,
|
||||
}
|
||||
|
||||
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots
|
||||
@ -557,6 +558,8 @@ class Task(_Task):
|
||||
is_auto_connect_frameworks_bool = not isinstance(auto_connect_frameworks, dict)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('hydra', True):
|
||||
PatchHydra.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or auto_connect_frameworks.get('jsonargparse', True):
|
||||
PatchJsonArgParse.update_current_task(task)
|
||||
if is_auto_connect_frameworks_bool or (
|
||||
auto_connect_frameworks.get('scikit', True) and
|
||||
auto_connect_frameworks.get('joblib', True)):
|
||||
|
Loading…
Reference in New Issue
Block a user