Task.init argument auto_connect_arg_parser can accept a dictionary disabling specific keys from the argparser

This commit is contained in:
allegroai 2020-05-13 20:42:33 +03:00
parent d2c9523769
commit cb8887da72
2 changed files with 28 additions and 7 deletions

View File

@ -40,6 +40,10 @@ class _Arguments(object):
def __init__(self, task):
super(_Arguments, self).__init__()
self._task = task
self._exclude_parser_args = {}
def exclude_parser_args(self, excluded_args):
self._exclude_parser_args = excluded_args or {}
def set_defaults(self, *dicts, **kwargs):
self._task.set_parameters(*dicts, **kwargs)
@ -126,9 +130,11 @@ class _Arguments(object):
task_defaults[k] = str(v)
except Exception:
del task_defaults[k]
# Add prefix, TODO: add argparse prefix
# task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()])
task_defaults = dict([(k, v) for k, v in task_defaults.items()])
# Skip excluded arguments, Add prefix, TODO: add argparse prefix
# task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()
# if k not in self._exclude_parser_args])
task_defaults = dict([(k, v) for k, v in task_defaults.items() if self._exclude_parser_args.get(k, True)])
# Store to task
self._task.update_parameters(task_defaults)
@ -154,7 +160,7 @@ class _Arguments(object):
# task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
# if k.startswith(self._prefix_args)])
task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items()
if not k.startswith(self._prefix_tf_defines)])
if not k.startswith(self._prefix_tf_defines) and self._exclude_parser_args.get(k, True)])
arg_parser_argeuments = {}
for k, v in task_arguments.items():
# python2 unicode support

View File

@ -169,7 +169,7 @@ class Task(_Task):
task_type=TaskTypes.training, # type: Task.TaskTypes
reuse_last_task_id=True, # type: bool
output_uri=None, # type: Optional[str]
auto_connect_arg_parser=True, # type: bool
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, # type: bool
):
@ -236,12 +236,22 @@ class Task(_Task):
`Trains Python Client Extras <./references/trains_extras_storage/>`_ in the "Trains Python Client
Reference" section.
:param bool auto_connect_arg_parser: Automatically connect an argparse object to the Task?
:param auto_connect_arg_parser: Automatically connect an argparse object to the Task?
The values are:
- ``True`` - Automatically connect. (Default)
- ``False`` - Do not automatically connect.
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
arguments. The dictionary keys are argparse variable names and the values are booleans,
False value will exclude the specified argument from the Task's parameter section.
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
For example:
.. code-block:: py
auto_connect_arg_parser={'do_not_include_me': False, }
.. note::
To manually connect an argparse, use :meth:`Task.connect`.
@ -256,6 +266,7 @@ class Task(_Task):
- ``False`` - Do not automatically connect
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
frameworks. The dictionary keys are frameworks and the values are booleans.
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
For example:
@ -264,7 +275,6 @@ class Task(_Task):
auto_connect_frameworks={'matplotlib': True, 'tensorflow': True, 'pytorch': True,
'xgboost': True, 'scikit': True}
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
:type auto_connect_frameworks: bool or dict
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots? These plots appear in
in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, with a title of **:resource monitor:**.
@ -440,6 +450,11 @@ class Task(_Task):
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
# set excluded arguments
if isinstance(auto_connect_arg_parser, dict):
task._arguments.exclude_parser_args(auto_connect_arg_parser)
# Check if parse args already called. If so, sync task parameters with parser
if argparser_parseargs_called():
parser, parsed_args = get_argparser_last_args()