mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Fix argparse support to store consistent str representation of custom objects. Avoid changing default value if remote value matches.
Fix argsparse type as function
This commit is contained in:
parent
3cb42acd12
commit
365c79326a
@ -1,9 +1,9 @@
|
||||
import yaml
|
||||
|
||||
from inspect import isfunction
|
||||
from six import PY2
|
||||
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS # noqa
|
||||
from copy import copy
|
||||
import types
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...utilities.args import call_original_argparser
|
||||
@ -95,9 +95,7 @@ class _Arguments(object):
|
||||
else:
|
||||
args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
|
||||
defaults_ = {
|
||||
a.dest: args_dict.get(a.dest) if (
|
||||
args_dict.get(a.dest) is not None and not callable(args_dict.get(a.dest))
|
||||
) else '' for a in actions
|
||||
a.dest: cls.__cast_arg(args_dict.get(a.dest)) for a in actions
|
||||
}
|
||||
except Exception:
|
||||
# don't crash us if we failed parsing the inputs
|
||||
@ -294,6 +292,11 @@ class _Arguments(object):
|
||||
elif var_type == type(None): # noqa: E721 - do not change!
|
||||
# because isinstance(var_type, type(None)) === False
|
||||
var_type = str
|
||||
elif var_type not in (bool, int, float, str, dict, list, tuple) and \
|
||||
str(self.__cast_arg(current_action.default)) == str(v):
|
||||
# there is nothing we can Do, just leave with the default
|
||||
continue
|
||||
|
||||
# now we should try and cast the value if we can
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -327,7 +330,7 @@ class _Arguments(object):
|
||||
arg_parser_arguments[k] = v
|
||||
elif current_action and current_action.type:
|
||||
# if we have an action type and value (v) is None, and cannot be casted, leave as is
|
||||
if isinstance(current_action.type, types.FunctionType) and not v: # noqa
|
||||
if isfunction(current_action.type) and not v: # noqa
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
v = current_action.type(v)
|
||||
@ -349,7 +352,10 @@ class _Arguments(object):
|
||||
try:
|
||||
if current_action.default is None and current_action.type != str and not v:
|
||||
arg_parser_arguments[k] = v = None
|
||||
elif current_action.default == current_action.type(v):
|
||||
elif (not isfunction(current_action.type)
|
||||
and current_action.default == current_action.type(v)) \
|
||||
or (isfunction(current_action.type) and
|
||||
str(self.__cast_arg(current_action.default)) == str(v)):
|
||||
# this will make sure that if we have type float and default value int,
|
||||
# we will keep the type as int, just like the original argparser
|
||||
arg_parser_arguments[k] = v = current_action.default
|
||||
@ -523,3 +529,12 @@ class _Arguments(object):
|
||||
if not isinstance(dictionary, self._ProxyDictReadOnly):
|
||||
return self._ProxyDictReadOnly(self, prefix, **dictionary)
|
||||
return dictionary
|
||||
|
||||
@classmethod
|
||||
def __cast_arg(cls, arg):
|
||||
if arg is None or callable(arg):
|
||||
return ''
|
||||
# If this an instance, just store the type
|
||||
if str(hex(id(arg))) in str(arg):
|
||||
return str(type(arg))
|
||||
return arg
|
||||
|
@ -1,7 +1,10 @@
|
||||
""" Argparse utilities"""
|
||||
import sys
|
||||
from copy import copy
|
||||
|
||||
from six import PY2
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
try:
|
||||
from argparse import _SubParsersAction
|
||||
except ImportError:
|
||||
@ -12,6 +15,7 @@ class PatchArgumentParser:
|
||||
_original_parse_args = None
|
||||
_original_parse_known_args = None
|
||||
_original_add_subparsers = None
|
||||
_original_get_value = None
|
||||
_add_subparsers_counter = 0
|
||||
_current_task = None
|
||||
_calling_current_task = False
|
||||
@ -58,18 +62,20 @@ class PatchArgumentParser:
|
||||
|
||||
if PatchArgumentParser._calling_current_task:
|
||||
# if we are here and running remotely by now we should try to parse the arguments
|
||||
parsed_args = None
|
||||
if original_parse_fn:
|
||||
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace))
|
||||
return PatchArgumentParser._last_parsed_args[-1]
|
||||
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
|
||||
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
|
||||
return parsed_args or PatchArgumentParser._last_parsed_args[-1]
|
||||
|
||||
PatchArgumentParser._calling_current_task = True
|
||||
# Store last instance and result
|
||||
PatchArgumentParser._add_last_arg_parser(self)
|
||||
parsed_args = None
|
||||
parsed_args = parsed_args_str = None
|
||||
# parse if we are running in dev mode
|
||||
if not running_remotely() and original_parse_fn:
|
||||
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
|
||||
PatchArgumentParser._add_last_parsed_args(parsed_args)
|
||||
parsed_args_str = PatchArgumentParser._add_last_parsed_args(self, parsed_args)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -77,63 +83,88 @@ class PatchArgumentParser:
|
||||
# noinspection PyProtectedMember
|
||||
current_task._connect_argparse(
|
||||
self, args=args, namespace=namespace,
|
||||
parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args
|
||||
parsed_args=parsed_args_str[0] if isinstance(parsed_args_str, tuple) else parsed_args_str
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# sync back and parse
|
||||
if running_remotely() and original_parse_fn:
|
||||
# if we are running python2 check if we have subparsers,
|
||||
# if we do we need to patch the args, because there is no default subparser
|
||||
if PY2:
|
||||
import itertools
|
||||
if running_remotely():
|
||||
if original_parse_fn:
|
||||
# if we are running python2 check if we have subparsers,
|
||||
# if we do we need to patch the args, because there is no default subparser
|
||||
if PY2:
|
||||
import itertools
|
||||
|
||||
def _get_sub_parsers_defaults(subparser, prev=[]):
|
||||
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(
|
||||
subparser, _SubParsersAction) else [subparser._actions]
|
||||
sub_parsers_defaults = [[subparser]] if hasattr(
|
||||
subparser, 'default') and subparser.default else []
|
||||
for actions in actions_grp:
|
||||
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev)
|
||||
for a in actions if isinstance(a, _SubParsersAction) and
|
||||
hasattr(a, 'default') and a.default]
|
||||
def _get_sub_parsers_defaults(subparser, prev=[]):
|
||||
actions_grp = [v._actions for v in subparser.choices.values()] if isinstance(
|
||||
subparser, _SubParsersAction) else [subparser._actions]
|
||||
_sub_parsers_defaults = [[subparser]] if hasattr(
|
||||
subparser, 'default') and subparser.default else []
|
||||
for actions in actions_grp:
|
||||
_sub_parsers_defaults += [_get_sub_parsers_defaults(v, prev)
|
||||
for v in actions if isinstance(v, _SubParsersAction) and
|
||||
hasattr(v, 'default') and v.default]
|
||||
|
||||
return list(itertools.chain.from_iterable(sub_parsers_defaults))
|
||||
sub_parsers_defaults = _get_sub_parsers_defaults(self)
|
||||
if sub_parsers_defaults:
|
||||
if args is None:
|
||||
# args default to the system args
|
||||
import sys as _sys
|
||||
args = _sys.argv[1:]
|
||||
else:
|
||||
args = list(args)
|
||||
# make sure we append the subparsers
|
||||
for a in sub_parsers_defaults:
|
||||
if a.default not in args:
|
||||
args.append(a.default)
|
||||
return list(itertools.chain.from_iterable(_sub_parsers_defaults))
|
||||
|
||||
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace))
|
||||
else:
|
||||
PatchArgumentParser._add_last_parsed_args(parsed_args or {})
|
||||
sub_parsers_defaults = _get_sub_parsers_defaults(self)
|
||||
if sub_parsers_defaults:
|
||||
if args is None:
|
||||
# args default to the system args
|
||||
import sys as _sys
|
||||
args = _sys.argv[1:]
|
||||
else:
|
||||
args = list(args)
|
||||
# make sure we append the subparsers
|
||||
for a in sub_parsers_defaults:
|
||||
if a.default not in args:
|
||||
args.append(a.default)
|
||||
|
||||
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
|
||||
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
|
||||
else:
|
||||
# we should never get here
|
||||
parsed_args = parsed_args_str or {}
|
||||
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
|
||||
|
||||
PatchArgumentParser._calling_current_task = False
|
||||
return PatchArgumentParser._last_parsed_args[-1]
|
||||
return parsed_args
|
||||
|
||||
# Store last instance and result
|
||||
PatchArgumentParser._add_last_arg_parser(self)
|
||||
PatchArgumentParser._add_last_parsed_args(
|
||||
{} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace))
|
||||
return PatchArgumentParser._last_parsed_args[-1]
|
||||
parsed_args = {} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)
|
||||
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
|
||||
return parsed_args
|
||||
|
||||
@staticmethod
|
||||
def _add_last_parsed_args(parsed_args):
|
||||
def _add_last_parsed_args(parser, parsed_args):
|
||||
if hasattr(parser, '_parsed_arg_string_lookup'):
|
||||
if isinstance(parsed_args, tuple):
|
||||
parsed_args_namespace = copy(parsed_args[0])
|
||||
parsed_args = (parsed_args_namespace, parsed_args[1])
|
||||
else:
|
||||
parsed_args = parsed_args_namespace = copy(parsed_args)
|
||||
|
||||
if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace):
|
||||
for k, v in parser._parsed_arg_string_lookup.items(): # noqa
|
||||
if hasattr(parsed_args_namespace, k):
|
||||
setattr(parsed_args_namespace, k, v)
|
||||
|
||||
PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args]
|
||||
return parsed_args
|
||||
|
||||
@staticmethod
|
||||
def _add_last_arg_parser(a_argparser):
|
||||
PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser]
|
||||
|
||||
@staticmethod
|
||||
def _get_value(self, action, arg_string):
|
||||
if not hasattr(self, '_parsed_arg_string_lookup'):
|
||||
setattr(self, '_parsed_arg_string_lookup', dict())
|
||||
self._parsed_arg_string_lookup[str(action.dest)] = str(arg_string)
|
||||
return PatchArgumentParser._original_get_value(self, action, arg_string)
|
||||
|
||||
|
||||
def patch_argparse():
|
||||
# make sure we only patch once
|
||||
@ -148,6 +179,9 @@ def patch_argparse():
|
||||
sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args
|
||||
sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args
|
||||
sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers
|
||||
if hasattr(sys.modules['argparse'].ArgumentParser, '_get_value'):
|
||||
PatchArgumentParser._original_get_value = sys.modules['argparse'].ArgumentParser._get_value
|
||||
sys.modules['argparse'].ArgumentParser._get_value = PatchArgumentParser._get_value
|
||||
|
||||
|
||||
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
|
||||
@ -182,8 +216,8 @@ def add_params_to_parser(parser, params):
|
||||
def get_type_details(v):
|
||||
for t in (int, float, str):
|
||||
try:
|
||||
value = t(v)
|
||||
return t, value
|
||||
_value = t(v)
|
||||
return t, _value
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
@ -191,6 +225,6 @@ def add_params_to_parser(parser, params):
|
||||
params.pop('', None)
|
||||
|
||||
for param, value in params.items():
|
||||
type, type_value = get_type_details(value)
|
||||
parser.add_argument('--%s' % param, type=type, default=type_value)
|
||||
_type, type_value = get_type_details(value)
|
||||
parser.add_argument('--%s' % param, type=_type, default=type_value)
|
||||
return parser
|
||||
|
Loading…
Reference in New Issue
Block a user