mirror of
https://github.com/clearml/clearml
synced 2025-04-03 20:41:07 +00:00
Add Click support (issue #386)
This commit is contained in:
parent
c42e32e137
commit
226c68330e
@ -6,7 +6,7 @@ from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsers
|
||||
from copy import copy
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...utilities.args import call_original_argparser
|
||||
from ...binding.args import call_original_argparser
|
||||
|
||||
|
||||
class _Arguments(object):
|
||||
|
@ -223,7 +223,7 @@ def patch_argparse():
|
||||
sys.modules['argparse'].ArgumentParser._get_value = PatchArgumentParser._get_value
|
||||
|
||||
|
||||
# Notice! we are patching argparser, sop we know if someone parsed arguments before connecting to task
|
||||
# Notice! we are patching argparser, so we know if someone parsed arguments before connecting to task
|
||||
patch_argparse()
|
||||
|
||||
|
155
clearml/binding/click.py
Normal file
155
clearml/binding/click.py
Normal file
@ -0,0 +1,155 @@
|
||||
try:
|
||||
from click.core import Command, Option, Argument, OptionParser, Group, Context
|
||||
except ImportError:
|
||||
Command = None
|
||||
|
||||
from .frameworks import _patched_call
|
||||
from ..config import running_remotely, get_remote_task_id
|
||||
from ..utilities.dicts import cast_str_to_bool
|
||||
|
||||
|
||||
class PatchClick:
|
||||
_args = {}
|
||||
_args_desc = {}
|
||||
_args_type = {}
|
||||
_num_commands = 0
|
||||
_command_type = 'click.Command'
|
||||
_section_name = 'Args'
|
||||
_main_task = None
|
||||
__remote_task_params = None
|
||||
__remote_task_params_dict = {}
|
||||
__patched = False
|
||||
|
||||
@classmethod
|
||||
def patch(cls, task=None):
|
||||
if Command is None:
|
||||
return
|
||||
|
||||
if task:
|
||||
cls._main_task = task
|
||||
PatchClick._update_task_args()
|
||||
|
||||
if not cls.__patched:
|
||||
cls.__patched = True
|
||||
Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init)
|
||||
Command.parse_args = _patched_call(Command.parse_args, PatchClick._parse_args)
|
||||
Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init)
|
||||
|
||||
@classmethod
|
||||
def args(cls):
|
||||
# remove prefix and main command
|
||||
if cls._num_commands == 1:
|
||||
cmd = sorted(cls._args.keys())[0]
|
||||
skip = len(cmd)+1
|
||||
else:
|
||||
skip = 0
|
||||
|
||||
_args = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args.items() if k[skip:]}
|
||||
_args_type = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_type.items() if k[skip:]}
|
||||
_args_desc = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_desc.items() if k[skip:]}
|
||||
|
||||
return _args, _args_type, _args_desc
|
||||
|
||||
@classmethod
|
||||
def _update_task_args(cls):
|
||||
if running_remotely() or not cls._main_task or not cls._args:
|
||||
return
|
||||
param_val, param_types, param_desc = cls.args()
|
||||
# noinspection PyProtectedMember
|
||||
cls._main_task._set_parameters(
|
||||
param_val,
|
||||
__update=True,
|
||||
__parameters_descriptions=param_desc,
|
||||
__parameters_types=param_types
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _command_init(original_fn, self, *args, **kwargs):
|
||||
if self and isinstance(self, Command) and 'name' in kwargs:
|
||||
PatchClick._num_commands += 1
|
||||
if running_remotely():
|
||||
pass
|
||||
else:
|
||||
name = kwargs['name']
|
||||
if name:
|
||||
PatchClick._args[name] = False
|
||||
PatchClick._args_type[name] = PatchClick._command_type
|
||||
# maybe we should take it post initialization
|
||||
if kwargs.get('help'):
|
||||
PatchClick._args_desc[name] = str(kwargs.get('help'))
|
||||
|
||||
for option in kwargs.get('params') or []:
|
||||
if not option or not isinstance(option, (Option, Argument)) or \
|
||||
not getattr(option, 'expose_value', True):
|
||||
continue
|
||||
# store default value
|
||||
PatchClick._args[name+'/'+option.name] = str(option.default or '')
|
||||
# store value type
|
||||
if option.type is not None:
|
||||
PatchClick._args_type[name+'/'+option.name] = str(option.type)
|
||||
# store value help
|
||||
if option.help:
|
||||
PatchClick._args_desc[name+'/'+option.name] = str(option.help)
|
||||
|
||||
return original_fn(self, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(original_fn, self, *args, **kwargs):
|
||||
if running_remotely() and isinstance(self, Command) and isinstance(self, Group):
|
||||
command = PatchClick._load_task_params()
|
||||
if command:
|
||||
init_args = kwargs['args'] if 'args' in kwargs else args[1]
|
||||
init_args = [command] + (init_args[1:] if init_args else [])
|
||||
if 'args' in kwargs:
|
||||
kwargs['args'] = init_args
|
||||
else:
|
||||
args = (args[0], init_args, *args[2:])
|
||||
|
||||
ret = original_fn(self, *args, **kwargs)
|
||||
|
||||
if isinstance(self, Command) and not isinstance(self, Group):
|
||||
ctx = kwargs.get('ctx') or args[0]
|
||||
if running_remotely():
|
||||
PatchClick._load_task_params()
|
||||
for p in self.params:
|
||||
name = '{}/{}'.format(self.name, p.name) if PatchClick._num_commands > 1 else p.name
|
||||
value = PatchClick.__remote_task_params_dict.get(name)
|
||||
ctx.params[p.name] = p.process_value(ctx, value)
|
||||
else:
|
||||
PatchClick._args[self.name] = True
|
||||
for k, v in ctx.params.items():
|
||||
# store passed value
|
||||
PatchClick._args[self.name + '/' + str(k)] = str(v or '')
|
||||
|
||||
PatchClick._update_task_args()
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _context_init(original_fn, self, *args, **kwargs):
|
||||
if running_remotely():
|
||||
kwargs['resilient_parsing'] = True
|
||||
return original_fn(self, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _load_task_params():
|
||||
if not PatchClick.__remote_task_params:
|
||||
from clearml import Task
|
||||
t = Task.get_task(task_id=get_remote_task_id())
|
||||
# noinspection PyProtectedMember
|
||||
PatchClick.__remote_task_params = t._get_task_property('hyperparams') or {}
|
||||
params_dict = t.get_parameters(backwards_compatibility=False)
|
||||
skip = len(PatchClick._section_name)+1
|
||||
PatchClick.__remote_task_params_dict = {
|
||||
k[skip:]: v for k, v in params_dict.items()
|
||||
if k.startswith(PatchClick._section_name+'/')
|
||||
}
|
||||
|
||||
params = PatchClick.__remote_task_params
|
||||
command = [
|
||||
p.name for p in params['Args'].values()
|
||||
if p.type == PatchClick._command_type and cast_str_to_bool(p.value, strip=True)]
|
||||
return command[0] if command else None
|
||||
|
||||
|
||||
# patch click before anything
|
||||
PatchClick.patch()
|
@ -47,6 +47,7 @@ from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||
from .binding.joblib_bind import PatchedJoblib
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .binding.hydra_bind import PatchHydra
|
||||
from .binding.click import PatchClick
|
||||
from .config import (
|
||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, PROC_MASTER_ID_ENV_VAR,
|
||||
DEV_DEFAULT_OUTPUT_URI, )
|
||||
@ -58,7 +59,7 @@ from .logger import Logger
|
||||
from .model import Model, InputModel, OutputModel
|
||||
from .task_parameters import TaskParameters
|
||||
from .utilities.config import verify_basic_value
|
||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
from .binding.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||
argparser_update_currenttask
|
||||
from .utilities.dicts import ReadOnlyDict, merge_dicts
|
||||
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \
|
||||
@ -578,6 +579,8 @@ class Task(_Task):
|
||||
|
||||
# Patch ArgParser to be aware of the current task
|
||||
argparser_update_currenttask(Task.__main_task)
|
||||
# Patch Click
|
||||
PatchClick.patch(Task.__main_task)
|
||||
|
||||
# set excluded arguments
|
||||
if isinstance(auto_connect_arg_parser, dict):
|
||||
@ -2894,7 +2897,11 @@ class Task(_Task):
|
||||
return
|
||||
# shutdown will clear the main, so we have to store it before.
|
||||
# is_main = self.is_main_task()
|
||||
self.__shutdown()
|
||||
# fix debugger signal in the middle
|
||||
try:
|
||||
self.__shutdown()
|
||||
except:
|
||||
pass
|
||||
# In rare cases we might need to forcefully shutdown the process, currently we should avoid it.
|
||||
# if is_main:
|
||||
# # we have to forcefully shutdown if we have forked processes, sometimes they will get stuck
|
||||
|
@ -1,4 +1,5 @@
|
||||
""" Utilities """
|
||||
from typing import Optional, Any
|
||||
|
||||
_epsilon = 0.00001
|
||||
|
||||
@ -152,3 +153,18 @@ def hocon_unquote_key(a_dict):
|
||||
else:
|
||||
new_dict[k] = hocon_unquote_key(v)
|
||||
return new_dict
|
||||
|
||||
|
||||
def cast_str_to_bool(value, strip=True):
|
||||
# type: (Any, bool) -> Optional[bool]
|
||||
a_strip_v = value if not strip else str(value).lower().strip()
|
||||
if a_strip_v == 'false' or not a_strip_v:
|
||||
return False
|
||||
elif a_strip_v == 'true':
|
||||
return True
|
||||
else:
|
||||
# first try to cast to integer
|
||||
try:
|
||||
return bool(int(a_strip_v))
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
34
examples/frameworks/click/click_multi_cmd.py
Normal file
34
examples/frameworks/click/click_multi_cmd.py
Normal file
@ -0,0 +1,34 @@
|
||||
import click
|
||||
from clearml import Task
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
task = Task.init(project_name='examples', task_name='click multi command')
|
||||
print('done')
|
||||
|
||||
|
||||
@cli.command('hello', help='test help')
|
||||
@click.option('--count', default=1, help='Number of greetings.')
|
||||
@click.option('--name', prompt='Your name', help='The person to greet.')
|
||||
def hello(count, name):
|
||||
"""Simple program that greets NAME for a total of COUNT times."""
|
||||
for x in range(count):
|
||||
click.echo(f"Hello {name}!")
|
||||
print('done')
|
||||
|
||||
|
||||
CONTEXT_SETTINGS = dict(
|
||||
default_map={'runserver': {'port': 5000}}
|
||||
)
|
||||
|
||||
|
||||
@cli.command('runserver')
|
||||
@click.option('--port', default=8000)
|
||||
@click.option('--name', help='service name')
|
||||
def runserver(port, name):
|
||||
click.echo(f"Serving on http://127.0.0.1:{port} {name}/")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
18
examples/frameworks/click/click_single_cmd.py
Normal file
18
examples/frameworks/click/click_single_cmd.py
Normal file
@ -0,0 +1,18 @@
|
||||
import click
|
||||
from clearml import Task
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--count', default=1, help='Number of greetings.')
|
||||
@click.option('--name', prompt='Your name',
|
||||
help='The person to greet.')
|
||||
def hello(count, name):
|
||||
task = Task.init(project_name='examples', task_name='click single command')
|
||||
|
||||
"""Simple program that greets NAME for a total of COUNT times."""
|
||||
for x in range(count):
|
||||
click.echo(f"Hello {name}!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
hello()
|
Loading…
Reference in New Issue
Block a user