Fix argparse support to store consistent str representation of custom objects. Avoid changing default value if remote value matches.

Fix argsparse type as function
This commit is contained in:
allegroai 2021-01-10 12:53:15 +02:00
parent 3cb42acd12
commit 365c79326a
2 changed files with 100 additions and 51 deletions

View File

@ -1,9 +1,9 @@
import yaml
from inspect import isfunction
from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS # noqa
from copy import copy
import types
from ...backend_api import Session
from ...utilities.args import call_original_argparser
@ -95,9 +95,7 @@ class _Arguments(object):
else:
args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
defaults_ = {
a.dest: args_dict.get(a.dest) if (
args_dict.get(a.dest) is not None and not callable(args_dict.get(a.dest))
) else '' for a in actions
a.dest: cls.__cast_arg(args_dict.get(a.dest)) for a in actions
}
except Exception:
# don't crash us if we failed parsing the inputs
@ -294,6 +292,11 @@ class _Arguments(object):
elif var_type == type(None): # noqa: E721 - do not change!
# because isinstance(var_type, type(None)) === False
var_type = str
elif var_type not in (bool, int, float, str, dict, list, tuple) and \
str(self.__cast_arg(current_action.default)) == str(v):
# there is nothing we can Do, just leave with the default
continue
# now we should try and cast the value if we can
# noinspection PyBroadException
try:
@ -327,7 +330,7 @@ class _Arguments(object):
arg_parser_arguments[k] = v
elif current_action and current_action.type:
# if we have an action type and value (v) is None, and cannot be casted, leave as is
if isinstance(current_action.type, types.FunctionType) and not v: # noqa
if isfunction(current_action.type) and not v: # noqa
# noinspection PyBroadException
try:
v = current_action.type(v)
@ -349,7 +352,10 @@ class _Arguments(object):
try:
if current_action.default is None and current_action.type != str and not v:
arg_parser_arguments[k] = v = None
elif current_action.default == current_action.type(v):
elif (not isfunction(current_action.type)
and current_action.default == current_action.type(v)) \
or (isfunction(current_action.type) and
str(self.__cast_arg(current_action.default)) == str(v)):
# this will make sure that if we have type float and default value int,
# we will keep the type as int, just like the original argparser
arg_parser_arguments[k] = v = current_action.default
@ -523,3 +529,12 @@ class _Arguments(object):
if not isinstance(dictionary, self._ProxyDictReadOnly):
return self._ProxyDictReadOnly(self, prefix, **dictionary)
return dictionary
@classmethod
def __cast_arg(cls, arg):
if arg is None or callable(arg):
return ''
# If this an instance, just store the type
if str(hex(id(arg))) in str(arg):
return str(type(arg))
return arg

View File

@ -1,7 +1,10 @@
""" Argparse utilities"""
import sys
from copy import copy
from six import PY2
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
try:
from argparse import _SubParsersAction
except ImportError:
@ -12,6 +15,7 @@ class PatchArgumentParser:
_original_parse_args = None
_original_parse_known_args = None
_original_add_subparsers = None
_original_get_value = None
_add_subparsers_counter = 0
_current_task = None
_calling_current_task = False
@ -58,18 +62,20 @@ class PatchArgumentParser:
if PatchArgumentParser._calling_current_task:
# if we are here and running remotely by now we should try to parse the arguments
parsed_args = None
if original_parse_fn:
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace))
return PatchArgumentParser._last_parsed_args[-1]
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
return parsed_args or PatchArgumentParser._last_parsed_args[-1]
PatchArgumentParser._calling_current_task = True
# Store last instance and result
PatchArgumentParser._add_last_arg_parser(self)
parsed_args = None
parsed_args = parsed_args_str = None
# parse if we are running in dev mode
if not running_remotely() and original_parse_fn:
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(parsed_args)
parsed_args_str = PatchArgumentParser._add_last_parsed_args(self, parsed_args)
# noinspection PyBroadException
try:
@ -77,63 +83,88 @@ class PatchArgumentParser:
# noinspection PyProtectedMember
current_task._connect_argparse(
self, args=args, namespace=namespace,
parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args
parsed_args=parsed_args_str[0] if isinstance(parsed_args_str, tuple) else parsed_args_str
)
except Exception:
pass
# sync back and parse
if running_remotely() and original_parse_fn:
# if we are running python2 check if we have subparsers,
# if we do we need to patch the args, because there is no default subparser
if PY2:
import itertools
if running_remotely():
if original_parse_fn:
# if we are running python2 check if we have subparsers,
# if we do we need to patch the args, because there is no default subparser
if PY2:
import itertools
def _get_sub_parsers_defaults(subparser, prev=[]):
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(
subparser, _SubParsersAction) else [subparser._actions]
sub_parsers_defaults = [[subparser]] if hasattr(
subparser, 'default') and subparser.default else []
for actions in actions_grp:
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev)
for a in actions if isinstance(a, _SubParsersAction) and
hasattr(a, 'default') and a.default]
def _get_sub_parsers_defaults(subparser, prev=[]):
actions_grp = [v._actions for v in subparser.choices.values()] if isinstance(
subparser, _SubParsersAction) else [subparser._actions]
_sub_parsers_defaults = [[subparser]] if hasattr(
subparser, 'default') and subparser.default else []
for actions in actions_grp:
_sub_parsers_defaults += [_get_sub_parsers_defaults(v, prev)
for v in actions if isinstance(v, _SubParsersAction) and
hasattr(v, 'default') and v.default]
return list(itertools.chain.from_iterable(sub_parsers_defaults))
sub_parsers_defaults = _get_sub_parsers_defaults(self)
if sub_parsers_defaults:
if args is None:
# args default to the system args
import sys as _sys
args = _sys.argv[1:]
else:
args = list(args)
# make sure we append the subparsers
for a in sub_parsers_defaults:
if a.default not in args:
args.append(a.default)
return list(itertools.chain.from_iterable(_sub_parsers_defaults))
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace))
else:
PatchArgumentParser._add_last_parsed_args(parsed_args or {})
sub_parsers_defaults = _get_sub_parsers_defaults(self)
if sub_parsers_defaults:
if args is None:
# args default to the system args
import sys as _sys
args = _sys.argv[1:]
else:
args = list(args)
# make sure we append the subparsers
for a in sub_parsers_defaults:
if a.default not in args:
args.append(a.default)
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
else:
# we should never get here
parsed_args = parsed_args_str or {}
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
PatchArgumentParser._calling_current_task = False
return PatchArgumentParser._last_parsed_args[-1]
return parsed_args
# Store last instance and result
PatchArgumentParser._add_last_arg_parser(self)
PatchArgumentParser._add_last_parsed_args(
{} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace))
return PatchArgumentParser._last_parsed_args[-1]
parsed_args = {} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
return parsed_args
@staticmethod
def _add_last_parsed_args(parsed_args):
def _add_last_parsed_args(parser, parsed_args):
if hasattr(parser, '_parsed_arg_string_lookup'):
if isinstance(parsed_args, tuple):
parsed_args_namespace = copy(parsed_args[0])
parsed_args = (parsed_args_namespace, parsed_args[1])
else:
parsed_args = parsed_args_namespace = copy(parsed_args)
if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace):
for k, v in parser._parsed_arg_string_lookup.items(): # noqa
if hasattr(parsed_args_namespace, k):
setattr(parsed_args_namespace, k, v)
PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args]
return parsed_args
@staticmethod
def _add_last_arg_parser(a_argparser):
PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser]
@staticmethod
def _get_value(self, action, arg_string):
if not hasattr(self, '_parsed_arg_string_lookup'):
setattr(self, '_parsed_arg_string_lookup', dict())
self._parsed_arg_string_lookup[str(action.dest)] = str(arg_string)
return PatchArgumentParser._original_get_value(self, action, arg_string)
def patch_argparse():
# make sure we only patch once
@ -148,6 +179,9 @@ def patch_argparse():
sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args
sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args
sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers
if hasattr(sys.modules['argparse'].ArgumentParser, '_get_value'):
PatchArgumentParser._original_get_value = sys.modules['argparse'].ArgumentParser._get_value
sys.modules['argparse'].ArgumentParser._get_value = PatchArgumentParser._get_value
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
@ -182,8 +216,8 @@ def add_params_to_parser(parser, params):
def get_type_details(v):
for t in (int, float, str):
try:
value = t(v)
return t, value
_value = t(v)
return t, _value
except ValueError:
continue
@ -191,6 +225,6 @@ def add_params_to_parser(parser, params):
params.pop('', None)
for param, value in params.items():
type, type_value = get_type_details(value)
parser.add_argument('--%s' % param, type=type, default=type_value)
_type, type_value = get_type_details(value)
parser.add_argument('--%s' % param, type=_type, default=type_value)
return parser