diff --git a/trains/utilities/args.py b/trains/utilities/args.py index 7258b384..0649d220 100644 --- a/trains/utilities/args.py +++ b/trains/utilities/args.py @@ -54,26 +54,29 @@ class PatchArgumentParser: if PatchArgumentParser._calling_current_task: # if we are here and running remotely by now we should try to parse the arguments if original_parse_fn: - PatchArgumentParser._last_parsed_args = \ - original_parse_fn(self, args=args, namespace=namespace) - return PatchArgumentParser._last_parsed_args + PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) + return PatchArgumentParser._last_parsed_args[-1] PatchArgumentParser._calling_current_task = True # Store last instance and result - PatchArgumentParser._last_arg_parser = self + PatchArgumentParser._add_last_arg_parser(self) parsed_args = None # parse if we are running in dev mode if not running_remotely() and original_parse_fn: parsed_args = original_parse_fn(self, args=args, namespace=namespace) - PatchArgumentParser._last_parsed_args = parsed_args + PatchArgumentParser._add_last_parsed_args(parsed_args) + # noinspection PyBroadException try: # sync to/from task - PatchArgumentParser._current_task._connect_argparse(self, args=args, namespace=namespace, - parsed_args=parsed_args[0] - if isinstance(parsed_args, tuple) else parsed_args) + # noinspection PyProtectedMember + PatchArgumentParser._current_task._connect_argparse( + self, args=args, namespace=namespace, + parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args + ) except Exception: pass + # sync back and parse if running_remotely() and original_parse_fn: # if we are running python2 check if we have subparsers, @@ -105,18 +108,26 @@ class PatchArgumentParser: if a.default not in args: args.append(a.default) - PatchArgumentParser._last_parsed_args = original_parse_fn(self, args=args, namespace=namespace) + PatchArgumentParser._add_last_parsed_args(original_parse_fn(self, args=args, namespace=namespace)) else: - PatchArgumentParser._last_parsed_args = parsed_args or {} + PatchArgumentParser._add_last_parsed_args(parsed_args or {}) PatchArgumentParser._calling_current_task = False - return PatchArgumentParser._last_parsed_args + return PatchArgumentParser._last_parsed_args[-1] # Store last instance and result - PatchArgumentParser._last_arg_parser = self - PatchArgumentParser._last_parsed_args = {} if not original_parse_fn else \ - original_parse_fn(self, args=args, namespace=namespace) - return PatchArgumentParser._last_parsed_args + PatchArgumentParser._add_last_arg_parser(self) + PatchArgumentParser._add_last_parsed_args( + {} if not original_parse_fn else original_parse_fn(self, args=args, namespace=namespace)) + return PatchArgumentParser._last_parsed_args[-1] + + @staticmethod + def _add_last_parsed_args(parsed_args): + PatchArgumentParser._last_parsed_args = (PatchArgumentParser._last_parsed_args or []) + [parsed_args] + + @staticmethod + def _add_last_arg_parser(a_argparser): + PatchArgumentParser._last_arg_parser = (PatchArgumentParser._last_arg_parser or []) + [a_argparser] def patch_argparse(): @@ -152,9 +163,11 @@ def argparser_update_currenttask(task): def get_argparser_last_args(): - return (PatchArgumentParser._last_arg_parser, - PatchArgumentParser._last_parsed_args[0] if isinstance(PatchArgumentParser._last_parsed_args, tuple) else - PatchArgumentParser._last_parsed_args) + if not PatchArgumentParser._last_arg_parser or not PatchArgumentParser._last_parsed_args: + return [] + + return [(parser, args[0] if isinstance(args, tuple) else args) + for parser, args in zip(PatchArgumentParser._last_arg_parser, PatchArgumentParser._last_parsed_args)] def add_params_to_parser(parser, params):