Add Click support (issue #386)

This commit is contained in:
allegroai 2021-07-04 09:31:47 +03:00
parent c42e32e137
commit 226c68330e
7 changed files with 234 additions and 4 deletions

View File

@ -6,7 +6,7 @@ from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsers
from copy import copy
from ...backend_api import Session
from ...utilities.args import call_original_argparser
from ...binding.args import call_original_argparser
class _Arguments(object):

View File

@ -223,7 +223,7 @@ def patch_argparse():
sys.modules['argparse'].ArgumentParser._get_value = PatchArgumentParser._get_value
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
# Notice! we are patching argparser, so we know if someone parsed arguments before connecting to task
patch_argparse()

155
clearml/binding/click.py Normal file
View File

@ -0,0 +1,155 @@
try:
from click.core import Command, Option, Argument, OptionParser, Group, Context
except ImportError:
Command = None
from .frameworks import _patched_call
from ..config import running_remotely, get_remote_task_id
from ..utilities.dicts import cast_str_to_bool
class PatchClick:
_args = {}
_args_desc = {}
_args_type = {}
_num_commands = 0
_command_type = 'click.Command'
_section_name = 'Args'
_main_task = None
__remote_task_params = None
__remote_task_params_dict = {}
__patched = False
@classmethod
def patch(cls, task=None):
if Command is None:
return
if task:
cls._main_task = task
PatchClick._update_task_args()
if not cls.__patched:
cls.__patched = True
Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init)
Command.parse_args = _patched_call(Command.parse_args, PatchClick._parse_args)
Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init)
@classmethod
def args(cls):
# remove prefix and main command
if cls._num_commands == 1:
cmd = sorted(cls._args.keys())[0]
skip = len(cmd)+1
else:
skip = 0
_args = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args.items() if k[skip:]}
_args_type = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_type.items() if k[skip:]}
_args_desc = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_desc.items() if k[skip:]}
return _args, _args_type, _args_desc
@classmethod
def _update_task_args(cls):
if running_remotely() or not cls._main_task or not cls._args:
return
param_val, param_types, param_desc = cls.args()
# noinspection PyProtectedMember
cls._main_task._set_parameters(
param_val,
__update=True,
__parameters_descriptions=param_desc,
__parameters_types=param_types
)
@staticmethod
def _command_init(original_fn, self, *args, **kwargs):
if self and isinstance(self, Command) and 'name' in kwargs:
PatchClick._num_commands += 1
if running_remotely():
pass
else:
name = kwargs['name']
if name:
PatchClick._args[name] = False
PatchClick._args_type[name] = PatchClick._command_type
# maybe we should take it post initialization
if kwargs.get('help'):
PatchClick._args_desc[name] = str(kwargs.get('help'))
for option in kwargs.get('params') or []:
if not option or not isinstance(option, (Option, Argument)) or \
not getattr(option, 'expose_value', True):
continue
# store default value
PatchClick._args[name+'/'+option.name] = str(option.default or '')
# store value type
if option.type is not None:
PatchClick._args_type[name+'/'+option.name] = str(option.type)
# store value help
if option.help:
PatchClick._args_desc[name+'/'+option.name] = str(option.help)
return original_fn(self, *args, **kwargs)
@staticmethod
def _parse_args(original_fn, self, *args, **kwargs):
if running_remotely() and isinstance(self, Command) and isinstance(self, Group):
command = PatchClick._load_task_params()
if command:
init_args = kwargs['args'] if 'args' in kwargs else args[1]
init_args = [command] + (init_args[1:] if init_args else [])
if 'args' in kwargs:
kwargs['args'] = init_args
else:
args = (args[0], init_args, *args[2:])
ret = original_fn(self, *args, **kwargs)
if isinstance(self, Command) and not isinstance(self, Group):
ctx = kwargs.get('ctx') or args[0]
if running_remotely():
PatchClick._load_task_params()
for p in self.params:
name = '{}/{}'.format(self.name, p.name) if PatchClick._num_commands > 1 else p.name
value = PatchClick.__remote_task_params_dict.get(name)
ctx.params[p.name] = p.process_value(ctx, value)
else:
PatchClick._args[self.name] = True
for k, v in ctx.params.items():
# store passed value
PatchClick._args[self.name + '/' + str(k)] = str(v or '')
PatchClick._update_task_args()
return ret
@staticmethod
def _context_init(original_fn, self, *args, **kwargs):
if running_remotely():
kwargs['resilient_parsing'] = True
return original_fn(self, *args, **kwargs)
@staticmethod
def _load_task_params():
if not PatchClick.__remote_task_params:
from clearml import Task
t = Task.get_task(task_id=get_remote_task_id())
# noinspection PyProtectedMember
PatchClick.__remote_task_params = t._get_task_property('hyperparams') or {}
params_dict = t.get_parameters(backwards_compatibility=False)
skip = len(PatchClick._section_name)+1
PatchClick.__remote_task_params_dict = {
k[skip:]: v for k, v in params_dict.items()
if k.startswith(PatchClick._section_name+'/')
}
params = PatchClick.__remote_task_params
command = [
p.name for p in params['Args'].values()
if p.type == PatchClick._command_type and cast_str_to_bool(p.value, strip=True)]
return command[0] if command else None
# patch click before anything
PatchClick.patch()

View File

@ -47,6 +47,7 @@ from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.joblib_bind import PatchedJoblib
from .binding.matplotlib_bind import PatchedMatplotlib
from .binding.hydra_bind import PatchHydra
from .binding.click import PatchClick
from .config import (
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, PROC_MASTER_ID_ENV_VAR,
DEV_DEFAULT_OUTPUT_URI, )
@ -58,7 +59,7 @@ from .logger import Logger
from .model import Model, InputModel, OutputModel
from .task_parameters import TaskParameters
from .utilities.config import verify_basic_value
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
from .binding.args import argparser_parseargs_called, get_argparser_last_args, \
argparser_update_currenttask
from .utilities.dicts import ReadOnlyDict, merge_dicts
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \
@ -578,6 +579,8 @@ class Task(_Task):
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
# Patch Click
PatchClick.patch(Task.__main_task)
# set excluded arguments
if isinstance(auto_connect_arg_parser, dict):
@ -2894,7 +2897,11 @@ class Task(_Task):
return
# shutdown will clear the main, so we have to store it before.
# is_main = self.is_main_task()
self.__shutdown()
# fix debugger signal in the middle
try:
self.__shutdown()
except:
pass
# In rare cases we might need to forcefully shutdown the process, currently we should avoid it.
# if is_main:
# # we have to forcefully shutdown if we have forked processes, sometimes they will get stuck

View File

@ -1,4 +1,5 @@
""" Utilities """
from typing import Optional, Any
_epsilon = 0.00001
@ -152,3 +153,18 @@ def hocon_unquote_key(a_dict):
else:
new_dict[k] = hocon_unquote_key(v)
return new_dict
def cast_str_to_bool(value, strip=True):
# type: (Any, bool) -> Optional[bool]
a_strip_v = value if not strip else str(value).lower().strip()
if a_strip_v == 'false' or not a_strip_v:
return False
elif a_strip_v == 'true':
return True
else:
# first try to cast to integer
try:
return bool(int(a_strip_v))
except (ValueError, TypeError):
return None

View File

@ -0,0 +1,34 @@
import click
from clearml import Task
@click.group()
def cli():
task = Task.init(project_name='examples', task_name='click multi command')
print('done')
@cli.command('hello', help='test help')
@click.option('--count', default=1, help='Number of greetings.')
@click.option('--name', prompt='Your name', help='The person to greet.')
def hello(count, name):
"""Simple program that greets NAME for a total of COUNT times."""
for x in range(count):
click.echo(f"Hello {name}!")
print('done')
CONTEXT_SETTINGS = dict(
default_map={'runserver': {'port': 5000}}
)
@cli.command('runserver')
@click.option('--port', default=8000)
@click.option('--name', help='service name')
def runserver(port, name):
click.echo(f"Serving on http://127.0.0.1:{port} {name}/")
if __name__ == '__main__':
cli()

View File

@ -0,0 +1,18 @@
import click
from clearml import Task
@click.command()
@click.option('--count', default=1, help='Number of greetings.')
@click.option('--name', prompt='Your name',
help='The person to greet.')
def hello(count, name):
task = Task.init(project_name='examples', task_name='click single command')
"""Simple program that greets NAME for a total of COUNT times."""
for x in range(count):
click.echo(f"Hello {name}!")
if __name__ == '__main__':
hello()