Ignore None keys in click args (#903)

This commit is contained in:
Zegang Cheng 2023-02-07 15:53:57 -05:00 committed by GitHub
parent b1f17db657
commit 269f5e7974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 33 deletions

1
.gitignore vendored
View File

@ -16,3 +16,4 @@ dist/
examples/runs/ examples/runs/
examples/*_data examples/*_data
examples/frameworks/data examples/frameworks/data
venv/

View File

@ -14,8 +14,8 @@ class PatchClick:
_args_desc = {} _args_desc = {}
_args_type = {} _args_type = {}
_num_commands = 0 _num_commands = 0
_command_type = 'click.Command' _command_type = "click.Command"
_section_name = 'Args' _section_name = "Args"
_current_task = None _current_task = None
__remote_task_params = None __remote_task_params = None
__remote_task_params_dict = {} __remote_task_params_dict = {}
@ -33,21 +33,38 @@ class PatchClick:
if not cls.__patched: if not cls.__patched:
cls.__patched = True cls.__patched = True
Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init) Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init)
Command.parse_args = _patched_call(Command.parse_args, PatchClick._parse_args) Command.parse_args = _patched_call(
Command.parse_args, PatchClick._parse_args
)
Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init) Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init)
@classmethod @classmethod
def args(cls): def args(cls):
# ignore None keys
cls._args = {k: v for k, v in cls._args.items() if k is not None}
# remove prefix and main command # remove prefix and main command
if cls._num_commands == 1: if cls._num_commands == 1:
cmd = sorted(cls._args.keys())[0] cmd = sorted(cls._args.keys())[0]
skip = len(cmd)+1 skip = len(cmd) + 1
else: else:
skip = 0 skip = 0
_args = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args.items() if k[skip:]} _args = {
_args_type = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_type.items() if k[skip:]} cls._section_name + "/" + k[skip:]: v
_args_desc = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_desc.items() if k[skip:]} for k, v in cls._args.items()
if k[skip:]
}
_args_type = {
cls._section_name + "/" + k[skip:]: v
for k, v in cls._args_type.items()
if k[skip:]
}
_args_desc = {
cls._section_name + "/" + k[skip:]: v
for k, v in cls._args_desc.items()
if k[skip:]
}
return _args, _args_type, _args_desc return _args, _args_type, _args_desc
@ -61,36 +78,45 @@ class PatchClick:
param_val, param_val,
__update=True, __update=True,
__parameters_descriptions=param_desc, __parameters_descriptions=param_desc,
__parameters_types=param_types __parameters_types=param_types,
) )
@staticmethod @staticmethod
def _command_init(original_fn, self, *args, **kwargs): def _command_init(original_fn, self, *args, **kwargs):
if isinstance(self, (Command, Group)) and 'name' in kwargs: if isinstance(self, (Command, Group)) and "name" in kwargs:
if isinstance(self, Command): if isinstance(self, Command):
PatchClick._num_commands += 1 PatchClick._num_commands += 1
if not running_remotely(): if not running_remotely():
name = kwargs['name'] name = kwargs["name"]
if name: if name:
PatchClick._args[name] = False PatchClick._args[name] = False
if isinstance(self, Command): if isinstance(self, Command):
PatchClick._args_type[name] = PatchClick._command_type PatchClick._args_type[name] = PatchClick._command_type
# maybe we should take it post initialization # maybe we should take it post initialization
if kwargs.get('help'): if kwargs.get("help"):
PatchClick._args_desc[name] = str(kwargs.get('help')) PatchClick._args_desc[name] = str(kwargs.get("help"))
for option in kwargs.get('params') or []: for option in kwargs.get("params") or []:
if not option or not isinstance(option, (Option, Argument)) or \ if (
not getattr(option, 'expose_value', True): not option
or not isinstance(option, (Option, Argument))
or not getattr(option, "expose_value", True)
):
continue continue
# store default value # store default value
PatchClick._args[name+'/'+option.name] = str(option.default or '') PatchClick._args[name + "/" + option.name] = str(
option.default or ""
)
# store value type # store value type
if option.type is not None: if option.type is not None:
PatchClick._args_type[name+'/'+option.name] = str(option.type) PatchClick._args_type[name + "/" + option.name] = str(
option.type
)
# store value help # store value help
if getattr(option, 'help', None): if getattr(option, "help", None):
PatchClick._args_desc[name+'/'+option.name] = str(option.help) PatchClick._args_desc[name + "/" + option.name] = str(
option.help
)
return original_fn(self, *args, **kwargs) return original_fn(self, *args, **kwargs)
@ -99,30 +125,38 @@ class PatchClick:
if running_remotely() and isinstance(self, Command) and isinstance(self, Group): if running_remotely() and isinstance(self, Command) and isinstance(self, Group):
command = PatchClick._load_task_params() command = PatchClick._load_task_params()
if command: if command:
init_args = kwargs['args'] if 'args' in kwargs else args[1] init_args = kwargs["args"] if "args" in kwargs else args[1]
init_args = [command] + (init_args[1:] if init_args else []) init_args = [command] + (init_args[1:] if init_args else [])
if 'args' in kwargs: if "args" in kwargs:
kwargs['args'] = init_args kwargs["args"] = init_args
else: else:
args = (args[0], init_args) + args[2:] args = (args[0], init_args) + args[2:]
ret = original_fn(self, *args, **kwargs) ret = original_fn(self, *args, **kwargs)
if isinstance(self, Command): if isinstance(self, Command):
ctx = kwargs.get('ctx') or args[0] ctx = kwargs.get("ctx") or args[0]
if running_remotely(): if running_remotely():
PatchClick._load_task_params() PatchClick._load_task_params()
for p in self.params: for p in self.params:
name = '{}/{}'.format(self.name, p.name) if PatchClick._num_commands > 1 else p.name name = (
"{}/{}".format(self.name, p.name)
if PatchClick._num_commands > 1
else p.name
)
value = PatchClick.__remote_task_params_dict.get(name) value = PatchClick.__remote_task_params_dict.get(name)
ctx.params[p.name] = p.process_value( ctx.params[p.name] = p.process_value(
ctx, cast_str_to_bool(value, strip=True) if isinstance(p.type, BoolParamType) else value) ctx,
cast_str_to_bool(value, strip=True)
if isinstance(p.type, BoolParamType)
else value,
)
else: else:
if not isinstance(self, Group): if not isinstance(self, Group):
PatchClick._args[self.name] = True PatchClick._args[self.name] = True
for k, v in ctx.params.items(): for k, v in ctx.params.items():
# store passed value # store passed value
PatchClick._args[self.name + '/' + str(k)] = str(v or '') PatchClick._args[self.name + "/" + str(k)] = str(v or "")
PatchClick._update_task_args() PatchClick._update_task_args()
return ret return ret
@ -130,29 +164,34 @@ class PatchClick:
@staticmethod @staticmethod
def _context_init(original_fn, self, *args, **kwargs): def _context_init(original_fn, self, *args, **kwargs):
if running_remotely(): if running_remotely():
kwargs['resilient_parsing'] = True kwargs["resilient_parsing"] = True
return original_fn(self, *args, **kwargs) return original_fn(self, *args, **kwargs)
@staticmethod @staticmethod
def _load_task_params(): def _load_task_params():
if not PatchClick.__remote_task_params: if not PatchClick.__remote_task_params:
from clearml import Task from clearml import Task
t = Task.get_task(task_id=get_remote_task_id()) t = Task.get_task(task_id=get_remote_task_id())
# noinspection PyProtectedMember # noinspection PyProtectedMember
PatchClick.__remote_task_params = t._get_task_property('hyperparams') or {} PatchClick.__remote_task_params = t._get_task_property("hyperparams") or {}
params_dict = t.get_parameters(backwards_compatibility=False) params_dict = t.get_parameters(backwards_compatibility=False)
skip = len(PatchClick._section_name)+1 skip = len(PatchClick._section_name) + 1
PatchClick.__remote_task_params_dict = { PatchClick.__remote_task_params_dict = {
k[skip:]: v for k, v in params_dict.items() k[skip:]: v
if k.startswith(PatchClick._section_name+'/') for k, v in params_dict.items()
if k.startswith(PatchClick._section_name + "/")
} }
params = PatchClick.__remote_task_params params = PatchClick.__remote_task_params
if not params: if not params:
return None return None
command = [ command = [
p.name for p in params['Args'].values() p.name
if p.type == PatchClick._command_type and cast_str_to_bool(p.value, strip=True)] for p in params["Args"].values()
if p.type == PatchClick._command_type
and cast_str_to_bool(p.value, strip=True)
]
return command[0] if command else None return command[0] if command else None