""" Argparse utilities""" import sys from six import PY2 from argparse import ArgumentParser, _SubParsersAction class PatchArgumentParser: _original_parse_args = None _original_parse_known_args = None _original_add_subparsers = None _add_subparsers_counter = 0 _current_task = None _calling_current_task = False _last_parsed_args = None _last_arg_parser = None @staticmethod def add_subparsers(self, **kwargs): if 'dest' not in kwargs: if kwargs.get('title'): kwargs['dest'] = '/' + kwargs['title'] else: PatchArgumentParser._add_subparsers_counter += 1 kwargs['dest'] = '/subparser%d' % PatchArgumentParser._add_subparsers_counter return PatchArgumentParser._original_add_subparsers(self, **kwargs) @staticmethod def parse_args(self, args=None, namespace=None): return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_args, self, args=args, namespace=namespace) @staticmethod def parse_known_args(self, args=None, namespace=None): return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_known_args, self, args=args, namespace=namespace) @staticmethod def _patched_parse_args(original_parse_fn, self, args=None, namespace=None): # if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible. if not PatchArgumentParser._current_task: from ..config import running_remotely if running_remotely(): # this will cause the current_task() to set PatchArgumentParser._current_task from trains import Task # noinspection PyBroadException try: Task.init() except Exception: pass # automatically connect to current task: if PatchArgumentParser._current_task: from ..config import running_remotely if PatchArgumentParser._calling_current_task: # if we are here and running remotely by now we should try to parse the arguments if original_parse_fn: PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) return PatchArgumentParser._last_parsed_args[-1] PatchArgumentParser._calling_current_task = True # Store last instance and result PatchArgumentParser._add_last_arg_parser(self) parsed_args = None # parse if we are running in dev mode if not running_remotely() and original_parse_fn: parsed_args = original_parse_fn(self, args=args, namespace=namespace) PatchArgumentParser._add_last_parsed_args(parsed_args) # noinspection PyBroadException try: # sync to/from task # noinspection PyProtectedMember PatchArgumentParser._current_task._connect_argparse( self, args=args, namespace=namespace, parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args ) except Exception: pass # sync back and parse if running_remotely() and original_parse_fn: # if we are running python2 check if we have subparsers, # if we do we need to patch the args, because there is no default subparser if PY2: import itertools def _get_sub_parsers_defaults(subparser, prev=[]): actions_grp = [a._actions for a in subparser.choices.values()] if isinstance( subparser, _SubParsersAction) else [subparser._actions] sub_parsers_defaults = [[subparser]] if hasattr( subparser, 'default') and subparser.default else [] for actions in actions_grp: sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev) for a in actions if isinstance(a, _SubParsersAction) and hasattr(a, 'default') and a.default] return list(itertools.chain.from_iterable(sub_parsers_defaults)) sub_parsers_defaults = _get_sub_parsers_defaults(self) if sub_parsers_defaults: if args is None: # args default to the system args import sys as _sys args = _sys.argv[1:] else: args = list(args) # make sure we append the subparsers for a in sub_parsers_defaults: if a.default not in args: args.append(a.default) PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) else: PatchArgumentParser._add_last_parsed_args(parsed_args or {}) PatchArgumentParser._calling_current_task = False return PatchArgumentParser._last_parsed_args[-1] # Store last instance and result PatchArgumentParser._add_last_arg_parser(self) PatchArgumentParser._add_last_parsed_args( {} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)) return PatchArgumentParser._last_parsed_args[-1] @staticmethod def _add_last_parsed_args(parsed_args): PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args] @staticmethod def _add_last_arg_parser(a_argparser): PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser] def patch_argparse(): # make sure we only patch once if not sys.modules.get('argparse') or hasattr(sys.modules['argparse'].ArgumentParser, '_parse_args_patched'): return # mark patched argparse sys.modules['argparse'].ArgumentParser._parse_args_patched = True # patch argparser PatchArgumentParser._original_parse_args = sys.modules['argparse'].ArgumentParser.parse_args PatchArgumentParser._original_parse_known_args = sys.modules['argparse'].ArgumentParser.parse_known_args PatchArgumentParser._original_add_subparsers = sys.modules['argparse'].ArgumentParser.add_subparsers sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers # Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task patch_argparse() def call_original_argparser(self, args=None, namespace=None): if PatchArgumentParser._original_parse_args: return PatchArgumentParser._original_parse_args(self, args=args, namespace=namespace) def argparser_parseargs_called(): return PatchArgumentParser._last_arg_parser is not None def argparser_update_currenttask(task): PatchArgumentParser._current_task = task def get_argparser_last_args(): if not PatchArgumentParser._last_arg_parser or not PatchArgumentParser._last_parsed_args: return [] return [(parser, args[0] if isinstance(args, tuple) else args) for parser, args in zip(PatchArgumentParser._last_arg_parser, PatchArgumentParser._last_parsed_args)] def add_params_to_parser(parser, params): assert isinstance(parser, ArgumentParser) assert isinstance(params, dict) def get_type_details(v): for t in (int, float, str): try: value = t(v) return t, value except ValueError: continue # AJB temporary protection from ui problems sending empty dicts params.pop('', None) for param, value in params.items(): type, type_value = get_type_details(value) parser.add_argument('--%s' % param, type=type, default=type_value) return parser