Fix argparse support to store consistent str representation of custom objects. Avoid changing default value if remote value matches.

Fix argsparse type as function
This commit is contained in:
allegroai 2021-01-10 12:53:15 +02:00
parent 3cb42acd12
commit 365c79326a
2 changed files with 100 additions and 51 deletions

View File

@ -1,9 +1,9 @@
import yaml import yaml
from inspect import isfunction
from six import PY2 from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS # noqa from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS # noqa
from copy import copy from copy import copy
import types
from ...backend_api import Session from ...backend_api import Session
from ...utilities.args import call_original_argparser from ...utilities.args import call_original_argparser
@ -95,9 +95,7 @@ class _Arguments(object):
else: else:
args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__ args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
defaults_ = { defaults_ = {
a.dest: args_dict.get(a.dest) if ( a.dest: cls.__cast_arg(args_dict.get(a.dest)) for a in actions
args_dict.get(a.dest) is not None and not callable(args_dict.get(a.dest))
) else '' for a in actions
} }
except Exception: except Exception:
# don't crash us if we failed parsing the inputs # don't crash us if we failed parsing the inputs
@ -294,6 +292,11 @@ class _Arguments(object):
elif var_type == type(None): # noqa: E721 - do not change! elif var_type == type(None): # noqa: E721 - do not change!
# because isinstance(var_type, type(None)) === False # because isinstance(var_type, type(None)) === False
var_type = str var_type = str
elif var_type not in (bool, int, float, str, dict, list, tuple) and \
str(self.__cast_arg(current_action.default)) == str(v):
# there is nothing we can Do, just leave with the default
continue
# now we should try and cast the value if we can # now we should try and cast the value if we can
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -327,7 +330,7 @@ class _Arguments(object):
arg_parser_arguments[k] = v arg_parser_arguments[k] = v
elif current_action and current_action.type: elif current_action and current_action.type:
# if we have an action type and value (v) is None, and cannot be casted, leave as is # if we have an action type and value (v) is None, and cannot be casted, leave as is
if isinstance(current_action.type, types.FunctionType) and not v: # noqa if isfunction(current_action.type) and not v: # noqa
# noinspection PyBroadException # noinspection PyBroadException
try: try:
v = current_action.type(v) v = current_action.type(v)
@ -349,7 +352,10 @@ class _Arguments(object):
try: try:
if current_action.default is None and current_action.type != str and not v: if current_action.default is None and current_action.type != str and not v:
arg_parser_arguments[k] = v = None arg_parser_arguments[k] = v = None
elif current_action.default == current_action.type(v): elif (not isfunction(current_action.type)
and current_action.default == current_action.type(v)) \
or (isfunction(current_action.type) and
str(self.__cast_arg(current_action.default)) == str(v)):
# this will make sure that if we have type float and default value int, # this will make sure that if we have type float and default value int,
# we will keep the type as int, just like the original argparser # we will keep the type as int, just like the original argparser
arg_parser_arguments[k] = v = current_action.default arg_parser_arguments[k] = v = current_action.default
@ -523,3 +529,12 @@ class _Arguments(object):
if not isinstance(dictionary, self._ProxyDictReadOnly): if not isinstance(dictionary, self._ProxyDictReadOnly):
return self._ProxyDictReadOnly(self, prefix, **dictionary) return self._ProxyDictReadOnly(self, prefix, **dictionary)
return dictionary return dictionary
@classmethod
def __cast_arg(cls, arg):
if arg is None or callable(arg):
return ''
# If this an instance, just store the type
if str(hex(id(arg))) in str(arg):
return str(type(arg))
return arg

View File

@ -1,7 +1,10 @@
""" Argparse utilities""" """ Argparse utilities"""
import sys import sys
from copy import copy
from six import PY2 from six import PY2
from argparse import ArgumentParser from argparse import ArgumentParser, Namespace
try: try:
from argparse import _SubParsersAction from argparse import _SubParsersAction
except ImportError: except ImportError:
@ -12,6 +15,7 @@ class PatchArgumentParser:
_original_parse_args = None _original_parse_args = None
_original_parse_known_args = None _original_parse_known_args = None
_original_add_subparsers = None _original_add_subparsers = None
_original_get_value = None
_add_subparsers_counter = 0 _add_subparsers_counter = 0
_current_task = None _current_task = None
_calling_current_task = False _calling_current_task = False
@ -58,18 +62,20 @@ class PatchArgumentParser:
if PatchArgumentParser._calling_current_task: if PatchArgumentParser._calling_current_task:
# if we are here and running remotely by now we should try to parse the arguments # if we are here and running remotely by now we should try to parse the arguments
parsed_args = None
if original_parse_fn: if original_parse_fn:
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) parsed_args = original_parse_fn(self, args=args, namespace=namespace)
return PatchArgumentParser._last_parsed_args[-1] PatchArgumentParser._add_last_parsed_args(self, parsed_args)
return parsed_args or PatchArgumentParser._last_parsed_args[-1]
PatchArgumentParser._calling_current_task = True PatchArgumentParser._calling_current_task = True
# Store last instance and result # Store last instance and result
PatchArgumentParser._add_last_arg_parser(self) PatchArgumentParser._add_last_arg_parser(self)
parsed_args = None parsed_args = parsed_args_str = None
# parse if we are running in dev mode # parse if we are running in dev mode
if not running_remotely() and original_parse_fn: if not running_remotely() and original_parse_fn:
parsed_args = original_parse_fn(self, args=args, namespace=namespace) parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(parsed_args) parsed_args_str = PatchArgumentParser._add_last_parsed_args(self, parsed_args)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -77,63 +83,88 @@ class PatchArgumentParser:
# noinspection PyProtectedMember # noinspection PyProtectedMember
current_task._connect_argparse( current_task._connect_argparse(
self, args=args, namespace=namespace, self, args=args, namespace=namespace,
parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args parsed_args=parsed_args_str[0] if isinstance(parsed_args_str, tuple) else parsed_args_str
) )
except Exception: except Exception:
pass pass
# sync back and parse # sync back and parse
if running_remotely() and original_parse_fn: if running_remotely():
# if we are running python2 check if we have subparsers, if original_parse_fn:
# if we do we need to patch the args, because there is no default subparser # if we are running python2 check if we have subparsers,
if PY2: # if we do we need to patch the args, because there is no default subparser
import itertools if PY2:
import itertools
def _get_sub_parsers_defaults(subparser, prev=[]): def _get_sub_parsers_defaults(subparser, prev=[]):
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance( actions_grp = [v._actions for v in subparser.choices.values()] if isinstance(
subparser, _SubParsersAction) else [subparser._actions] subparser, _SubParsersAction) else [subparser._actions]
sub_parsers_defaults = [[subparser]] if hasattr( _sub_parsers_defaults = [[subparser]] if hasattr(
subparser, 'default') and subparser.default else [] subparser, 'default') and subparser.default else []
for actions in actions_grp: for actions in actions_grp:
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev) _sub_parsers_defaults += [_get_sub_parsers_defaults(v, prev)
for a in actions if isinstance(a, _SubParsersAction) and for v in actions if isinstance(v, _SubParsersAction) and
hasattr(a, 'default') and a.default] hasattr(v, 'default') and v.default]
return list(itertools.chain.from_iterable(sub_parsers_defaults)) return list(itertools.chain.from_iterable(_sub_parsers_defaults))
sub_parsers_defaults = _get_sub_parsers_defaults(self)
if sub_parsers_defaults:
if args is None:
# args default to the system args
import sys as _sys
args = _sys.argv[1:]
else:
args = list(args)
# make sure we append the subparsers
for a in sub_parsers_defaults:
if a.default not in args:
args.append(a.default)
PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) sub_parsers_defaults = _get_sub_parsers_defaults(self)
else: if sub_parsers_defaults:
PatchArgumentParser._add_last_parsed_args(parsed_args or {}) if args is None:
# args default to the system args
import sys as _sys
args = _sys.argv[1:]
else:
args = list(args)
# make sure we append the subparsers
for a in sub_parsers_defaults:
if a.default not in args:
args.append(a.default)
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
else:
# we should never get here
parsed_args = parsed_args_str or {}
PatchArgumentParser._add_last_parsed_args(self, parsed_args)
PatchArgumentParser._calling_current_task = False PatchArgumentParser._calling_current_task = False
return PatchArgumentParser._last_parsed_args[-1] return parsed_args
# Store last instance and result # Store last instance and result
PatchArgumentParser._add_last_arg_parser(self) PatchArgumentParser._add_last_arg_parser(self)
PatchArgumentParser._add_last_parsed_args( parsed_args = {} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)
{} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)) PatchArgumentParser._add_last_parsed_args(self, parsed_args)
return PatchArgumentParser._last_parsed_args[-1] return parsed_args
@staticmethod @staticmethod
def _add_last_parsed_args(parsed_args): def _add_last_parsed_args(parser, parsed_args):
if hasattr(parser, '_parsed_arg_string_lookup'):
if isinstance(parsed_args, tuple):
parsed_args_namespace = copy(parsed_args[0])
parsed_args = (parsed_args_namespace, parsed_args[1])
else:
parsed_args = parsed_args_namespace = copy(parsed_args)
if parsed_args_namespace and isinstance(parsed_args_namespace, Namespace):
for k, v in parser._parsed_arg_string_lookup.items(): # noqa
if hasattr(parsed_args_namespace, k):
setattr(parsed_args_namespace, k, v)
PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args] PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args]
return parsed_args
@staticmethod @staticmethod
def _add_last_arg_parser(a_argparser): def _add_last_arg_parser(a_argparser):
PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser] PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser]
@staticmethod
def _get_value(self, action, arg_string):
if not hasattr(self, '_parsed_arg_string_lookup'):
setattr(self, '_parsed_arg_string_lookup', dict())
self._parsed_arg_string_lookup[str(action.dest)] = str(arg_string)
return PatchArgumentParser._original_get_value(self, action, arg_string)
def patch_argparse(): def patch_argparse():
# make sure we only patch once # make sure we only patch once
@ -148,6 +179,9 @@ def patch_argparse():
sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args
sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args
sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers
if hasattr(sys.modules['argparse'].ArgumentParser, '_get_value'):
PatchArgumentParser._original_get_value = sys.modules['argparse'].ArgumentParser._get_value
sys.modules['argparse'].ArgumentParser._get_value = PatchArgumentParser._get_value
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task # Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
@ -182,8 +216,8 @@ def add_params_to_parser(parser, params):
def get_type_details(v): def get_type_details(v):
for t in (int, float, str): for t in (int, float, str):
try: try:
value = t(v) _value = t(v)
return t, value return t, _value
except ValueError: except ValueError:
continue continue
@ -191,6 +225,6 @@ def add_params_to_parser(parser, params):
params.pop('', None) params.pop('', None)
for param, value in params.items(): for param, value in params.items():
type, type_value = get_type_details(value) _type, type_value = get_type_details(value)
parser.add_argument('--%s' % param, type=type, default=type_value) parser.add_argument('--%s' % param, type=_type, default=type_value)
return parser return parser