clearml/trains/utilities/args.py
2019-06-10 20:02:11 +03:00

177 lines
7.9 KiB
Python

""" Argparse utilities"""
import sys
from six import PY2
from argparse import ArgumentParser, _SubParsersAction
class PatchArgumentParser:
_original_parse_args = None
_original_parse_known_args = None
_original_add_subparsers = None
_add_subparsers_counter = 0
_current_task = None
_calling_current_task = False
_last_parsed_args = None
_last_arg_parser = None
@staticmethod
def add_subparsers(self, **kwargs):
if 'dest' not in kwargs:
if kwargs.get('title'):
kwargs['dest'] = '/'+kwargs['title']
else:
PatchArgumentParser._add_subparsers_counter += 1
kwargs['dest'] = '/subparser%d' % PatchArgumentParser._add_subparsers_counter
return PatchArgumentParser._original_add_subparsers(self, **kwargs)
@staticmethod
def parse_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_args,
self, args=args, namespace=namespace)
@staticmethod
def parse_known_args(self, args=None, namespace=None):
return PatchArgumentParser._patched_parse_args(PatchArgumentParser._original_parse_known_args,
self, args=args, namespace=namespace)
@staticmethod
def _patched_parse_args(original_parse_fn, self, args=None, namespace=None):
# if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible.
if not PatchArgumentParser._current_task:
from ..config import running_remotely
if running_remotely():
# this will cause the current_task() to set PatchArgumentParser._current_task
from trains import Task
# noinspection PyBroadException
try:
Task.init()
except Exception:
pass
# automatically connect to current task:
if PatchArgumentParser._current_task:
from ..config import running_remotely
if PatchArgumentParser._calling_current_task:
# if we are here and running remotely by now we should try to parse the arguments
if original_parse_fn:
PatchArgumentParser._last_parsed_args = \
original_parse_fn(self, args=args, namespace=namespace)
return PatchArgumentParser._last_parsed_args
PatchArgumentParser._calling_current_task = True
# Store last instance and result
PatchArgumentParser._last_arg_parser = self
parsed_args = None
# parse if we are running in dev mode
if not running_remotely() and original_parse_fn:
parsed_args = original_parse_fn(self, args=args, namespace=namespace)
PatchArgumentParser._last_parsed_args = parsed_args
# noinspection PyBroadException
try:
# sync to/from task
PatchArgumentParser._current_task._connect_argparse(self, args=args, namespace=namespace,
parsed_args=parsed_args[0]
if isinstance(parsed_args, tuple) else parsed_args)
except Exception:
pass
# sync back and parse
if running_remotely() and original_parse_fn:
# if we are running python2 check if we have subparsers,
# if we do we need to patch the args, because there is no default subparser
if PY2:
import itertools
def _get_sub_parsers_defaults(subparser, prev=[]):
actions_grp = [a._actions for a in subparser.choices.values()] if isinstance(subparser, _SubParsersAction) else \
[subparser._actions]
sub_parsers_defaults = [[subparser]] if hasattr(subparser, 'default') and subparser.default else []
for actions in actions_grp:
sub_parsers_defaults += [_get_sub_parsers_defaults(a, prev)
for a in actions if isinstance(a, _SubParsersAction) and
hasattr(a, 'default') and a.default]
return list(itertools.chain.from_iterable(sub_parsers_defaults))
sub_parsers_defaults = _get_sub_parsers_defaults(self)
if sub_parsers_defaults:
if args is None:
# args default to the system args
import sys as _sys
args = _sys.argv[1:]
else:
args = list(args)
# make sure we append the subparsers
for a in sub_parsers_defaults:
if a.default not in args:
args.append(a.default)
PatchArgumentParser._last_parsed_args = original_parse_fn(self, args=args, namespace=namespace)
else:
PatchArgumentParser._last_parsed_args = parsed_args or {}
PatchArgumentParser._calling_current_task = False
return PatchArgumentParser._last_parsed_args
# Store last instance and result
PatchArgumentParser._last_arg_parser = self
PatchArgumentParser._last_parsed_args = {} if not original_parse_fn else \
original_parse_fn(self, args=args, namespace=namespace)
return PatchArgumentParser._last_parsed_args
def patch_argparse():
# make sure we only patch once
if not sys.modules.get('argparse') or hasattr(sys.modules['argparse'].ArgumentParser, '_parse_args_patched'):
return
# mark patched argparse
sys.modules['argparse'].ArgumentParser._parse_args_patched = True
# patch argparser
PatchArgumentParser._original_parse_args = sys.modules['argparse'].ArgumentParser.parse_args
PatchArgumentParser._original_parse_known_args = sys.modules['argparse'].ArgumentParser.parse_known_args
PatchArgumentParser._original_add_subparsers = sys.modules['argparse'].ArgumentParser.add_subparsers
sys.modules['argparse'].ArgumentParser.parse_args = PatchArgumentParser.parse_args
sys.modules['argparse'].ArgumentParser.parse_known_args = PatchArgumentParser.parse_known_args
sys.modules['argparse'].ArgumentParser.add_subparsers = PatchArgumentParser.add_subparsers
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
patch_argparse()
def call_original_argparser(self, args=None, namespace=None):
if PatchArgumentParser._original_parse_args:
return PatchArgumentParser._original_parse_args(self, args=args, namespace=namespace)
def argparser_parseargs_called():
return PatchArgumentParser._last_arg_parser is not None
def argparser_update_currenttask(task):
PatchArgumentParser._current_task = task
def get_argparser_last_args():
return (PatchArgumentParser._last_arg_parser,
PatchArgumentParser._last_parsed_args[0] if isinstance(PatchArgumentParser._last_parsed_args, tuple) else
PatchArgumentParser._last_parsed_args)
def add_params_to_parser(parser, params):
assert isinstance(parser, ArgumentParser)
assert isinstance(params, dict)
def get_type_details(v):
for t in (int, float, str):
try:
value = t(v)
return t, value
except ValueError:
continue
# AJB temporary protection from ui problems sending empty dicts
params.pop('', None)
for param, value in params.items():
type, type_value = get_type_details(value)
parser.add_argument('--%s' % param, type=type, default=type_value)
return parser