mirror of
https://github.com/clearml/clearml
synced 2025-01-31 17:17:00 +00:00
177 lines
7.9 KiB
Python
177 lines
7.9 KiB
Python
""" Argparse utilities"""
|
|
import sys
|
|
from six import PY2
|
|
from argparse import ArgumentParser, _SubParsersAction
|
|
|
|
|
|
class PatchArgumentParser:
    """Replacement implementations for ``ArgumentParser`` methods plus their shared state.

    The static methods below take the parser instance as an explicit ``self``
    argument because they are assigned onto ``argparse.ArgumentParser`` and are
    therefore invoked as regular methods of the parser.
    """

    # original (un-patched) ArgumentParser methods, filled in when argparse is patched
    _original_parse_args = None
    _original_parse_known_args = None
    _original_add_subparsers = None
    # running counter used to generate a unique 'dest' for anonymous subparser groups
    _add_subparsers_counter = 0
    # task object the parsed arguments are synced with (None until one is registered)
    _current_task = None
    # re-entrancy guard: True while we are in the middle of syncing with the task
    _calling_current_task = False
    # result of the last parse call, and the parser instance it was called on
    _last_parsed_args = None
    _last_arg_parser = None

    @staticmethod
    def add_subparsers(self, **kwargs):
        """Call the original ``add_subparsers`` making sure a ``dest`` is always set.

        When the caller did not pass ``dest`` it is derived from ``title``
        (``'/<title>'``) or, if there is no title, from a running counter
        (``'/subparser<N>'``).

        :param self: the parser instance the method was called on
        :param kwargs: keyword arguments forwarded to the original ``add_subparsers``
        :return: whatever the original ``add_subparsers`` returns
        """
        if 'dest' not in kwargs:
            if kwargs.get('title'):
                kwargs['dest'] = '/' + kwargs['title']
            else:
                PatchArgumentParser._add_subparsers_counter += 1
                kwargs['dest'] = '/subparser%d' % PatchArgumentParser._add_subparsers_counter
        return PatchArgumentParser._original_add_subparsers(self, **kwargs)

    @staticmethod
    def parse_args(self, args=None, namespace=None):
        """Patched ``parse_args``: delegate to `_patched_parse_args` with the original ``parse_args``."""
        return PatchArgumentParser._patched_parse_args(
            PatchArgumentParser._original_parse_args, self, args=args, namespace=namespace)

    @staticmethod
    def parse_known_args(self, args=None, namespace=None):
        """Patched ``parse_known_args``: delegate to `_patched_parse_args` with the original ``parse_known_args``."""
        return PatchArgumentParser._patched_parse_args(
            PatchArgumentParser._original_parse_known_args, self, args=args, namespace=namespace)

    @staticmethod
    def _get_sub_parsers_defaults(subparser):
        """Recursively collect the objects under *subparser* that carry a truthy ``default``.

        :param subparser: a parser, or a ``_SubParsersAction`` whose choices are parsers
        :return: flat list containing *subparser* itself (if it has a truthy ``default``)
            followed by every nested ``_SubParsersAction`` with a truthy ``default``
        """
        import itertools

        if isinstance(subparser, _SubParsersAction):
            actions_grp = [a._actions for a in subparser.choices.values()]
        else:
            actions_grp = [subparser._actions]

        sub_parsers_defaults = [[subparser]] if hasattr(subparser, 'default') and subparser.default else []
        for actions in actions_grp:
            sub_parsers_defaults += [
                PatchArgumentParser._get_sub_parsers_defaults(a)
                for a in actions
                if isinstance(a, _SubParsersAction) and hasattr(a, 'default') and a.default
            ]
        return list(itertools.chain.from_iterable(sub_parsers_defaults))

    @staticmethod
    def _patched_parse_args(original_parse_fn, self, args=None, namespace=None):
        """Parse arguments with *original_parse_fn* and sync them with the current task.

        :param original_parse_fn: the un-patched parse function to delegate to (may be None)
        :param self: the parser instance the patched method was called on
        :param args: argument list forwarded to the parse function (None means default)
        :param namespace: namespace object forwarded to the parse function
        :return: the value returned by *original_parse_fn*; ``{}`` when there is nothing
            to call, or the previously stored result when called re-entrantly without
            a parse function
        """
        from ..config import running_remotely

        # if we are running remotely, we always have a task id, so we better patch the argparser as soon as possible.
        if not PatchArgumentParser._current_task and running_remotely():
            # this will cause the current_task() to set PatchArgumentParser._current_task
            from trains import Task
            # noinspection PyBroadException
            try:
                Task.init()
            except Exception:
                # deliberate best-effort: failing to create the task must not break parsing
                pass

        if not PatchArgumentParser._current_task:
            # no task to sync with: store last instance and result, plain parse
            PatchArgumentParser._last_arg_parser = self
            PatchArgumentParser._last_parsed_args = {} if not original_parse_fn else \
                original_parse_fn(self, args=args, namespace=namespace)
            return PatchArgumentParser._last_parsed_args

        # automatically connect to current task:
        if PatchArgumentParser._calling_current_task:
            # re-entrant call (we are already syncing with the task):
            # if we are here and running remotely by now we should try to parse the arguments
            if original_parse_fn:
                PatchArgumentParser._last_parsed_args = \
                    original_parse_fn(self, args=args, namespace=namespace)
            return PatchArgumentParser._last_parsed_args

        PatchArgumentParser._calling_current_task = True
        try:
            # Store last instance and result
            PatchArgumentParser._last_arg_parser = self
            parsed_args = None
            # parse if we are running in dev mode
            if not running_remotely() and original_parse_fn:
                parsed_args = original_parse_fn(self, args=args, namespace=namespace)
                PatchArgumentParser._last_parsed_args = parsed_args

            # noinspection PyBroadException
            try:
                # sync to/from task (parse_known_args returns a tuple, only its first item is passed on)
                PatchArgumentParser._current_task._connect_argparse(
                    self, args=args, namespace=namespace,
                    parsed_args=parsed_args[0] if isinstance(parsed_args, tuple) else parsed_args)
            except Exception:
                # deliberate best-effort: a failed sync must not break argument parsing
                pass

            # sync back and parse
            if running_remotely() and original_parse_fn:
                # if we are running python2 check if we have subparsers,
                # if we do we need to patch the args, because there is no default subparser
                if PY2:
                    sub_parsers_defaults = PatchArgumentParser._get_sub_parsers_defaults(self)
                    if sub_parsers_defaults:
                        if args is None:
                            # args default to the system args
                            import sys as _sys
                            args = _sys.argv[1:]
                        else:
                            args = list(args)
                        # make sure we append the subparsers
                        for a in sub_parsers_defaults:
                            if a.default not in args:
                                args.append(a.default)

                PatchArgumentParser._last_parsed_args = original_parse_fn(self, args=args, namespace=namespace)
            else:
                PatchArgumentParser._last_parsed_args = parsed_args or {}
        finally:
            # always release the re-entrancy guard, even when parsing raised
            # (argparse raises SystemExit on invalid arguments / --help),
            # otherwise every later parse call would skip the task sync
            PatchArgumentParser._calling_current_task = False

        return PatchArgumentParser._last_parsed_args
|
|
|
|
|
|
def patch_argparse():
    """Replace ``ArgumentParser`` parse/subparser methods with the PatchArgumentParser versions.

    Does nothing when the ``argparse`` module is not loaded or when the
    ``ArgumentParser`` class already carries the ``_parse_args_patched`` marker,
    so the patch is applied at most once.
    """
    argparse_module = sys.modules.get('argparse')
    # make sure we only patch once
    if not argparse_module or hasattr(argparse_module.ArgumentParser, '_parse_args_patched'):
        return

    parser_cls = argparse_module.ArgumentParser
    # mark patched argparse
    parser_cls._parse_args_patched = True
    # patch argparser: keep the original method, then install our replacement
    for method_name in ('parse_args', 'parse_known_args', 'add_subparsers'):
        setattr(PatchArgumentParser, '_original_' + method_name, getattr(parser_cls, method_name))
        setattr(parser_cls, method_name, getattr(PatchArgumentParser, method_name))
|
|
|
|
|
|
# Notice! we are patching argparser, so we know if someone parsed arguments before connecting to task
patch_argparse()
|
|
|
|
|
|
def call_original_argparser(self, args=None, namespace=None):
    """Invoke the saved, un-patched ``parse_args`` on the given parser.

    :param self: the parser instance to parse with
    :param args: argument list forwarded to the original ``parse_args``
    :param namespace: namespace object forwarded to the original ``parse_args``
    :return: the result of the original ``parse_args``, or None when no original was saved
    """
    original_parse_args = PatchArgumentParser._original_parse_args
    if not original_parse_args:
        return None
    return original_parse_args(self, args=args, namespace=namespace)
|
|
|
|
|
|
def argparser_parseargs_called():
    """Return True when a parser instance has been recorded by a patched parse call."""
    last_parser = PatchArgumentParser._last_arg_parser
    return last_parser is not None
|
|
|
|
|
|
def argparser_update_currenttask(task):
    """Register *task* as the task the patched argument parser syncs with.

    :param task: the task object to store on ``PatchArgumentParser._current_task``
    """
    PatchArgumentParser._current_task = task
|
|
|
|
|
|
def get_argparser_last_args():
    """Return the last recorded parser and its parsed arguments.

    :return: tuple ``(parser, parsed_args)``; when the stored result is itself a
        tuple only its first element is returned as ``parsed_args``
    """
    last_parser = PatchArgumentParser._last_arg_parser
    last_parsed = PatchArgumentParser._last_parsed_args
    if isinstance(last_parsed, tuple):
        last_parsed = last_parsed[0]
    return last_parser, last_parsed
|
|
|
|
|
|
def add_params_to_parser(parser, params):
    """Add one optional ``--<name>`` argument to *parser* for every entry of *params*.

    The argument type is inferred from the value: the first of ``int``,
    ``float``, ``str`` that can convert the value is used, and the converted
    value becomes the argument default.

    Note: an empty-string key is removed from *params* in place.

    :param parser: the ``ArgumentParser`` to extend
    :param params: dict mapping argument name to its default value
    :return: the same *parser* object
    :raises AssertionError: if *parser* is not an ArgumentParser or *params* is not a dict
    """
    # validation is kept as assertions so the exception type seen by callers does not change
    assert isinstance(parser, ArgumentParser)
    assert isinstance(params, dict)

    def get_type_details(v):
        """Return ``(type, converted_value)`` for a single parameter value."""
        if v is None:
            # no value to infer from: plain string argument without a default
            return str, None
        if isinstance(v, float):
            # a real float must not go through int(): int(1.5) silently truncates to 1
            return float, v
        for t in (int, float, str):
            try:
                return t, t(v)
            except (ValueError, TypeError):
                # TypeError covers values int()/float() cannot accept at all (e.g. a list)
                continue
        # nothing could convert the value, keep it as is so the caller can always unpack
        return str, v

    # AJB temporary protection from ui problems sending empty dicts
    params.pop('', None)

    for param, value in params.items():
        arg_type, default_value = get_type_details(value)
        parser.add_argument('--%s' % param, type=arg_type, default=default_value)
    return parser
|