Ignore None keys in click args (#903)

This commit is contained in:
Zegang Cheng 2023-02-07 15:53:57 -05:00 committed by GitHub
parent b1f17db657
commit 269f5e7974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 33 deletions

1
.gitignore vendored
View File

@ -16,3 +16,4 @@ dist/
examples/runs/
examples/*_data
examples/frameworks/data
venv/

View File

@ -14,8 +14,8 @@ class PatchClick:
_args_desc = {}
_args_type = {}
_num_commands = 0
_command_type = 'click.Command'
_section_name = 'Args'
_command_type = "click.Command"
_section_name = "Args"
_current_task = None
__remote_task_params = None
__remote_task_params_dict = {}
@ -33,21 +33,38 @@ class PatchClick:
if not cls.__patched:
cls.__patched = True
Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init)
Command.parse_args = _patched_call(Command.parse_args, PatchClick._parse_args)
Command.parse_args = _patched_call(
Command.parse_args, PatchClick._parse_args
)
Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init)
@classmethod
def args(cls):
# ignore None keys
cls._args = {k: v for k, v in cls._args.items() if k is not None}
# remove prefix and main command
if cls._num_commands == 1:
cmd = sorted(cls._args.keys())[0]
skip = len(cmd)+1
skip = len(cmd) + 1
else:
skip = 0
_args = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args.items() if k[skip:]}
_args_type = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_type.items() if k[skip:]}
_args_desc = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_desc.items() if k[skip:]}
_args = {
cls._section_name + "/" + k[skip:]: v
for k, v in cls._args.items()
if k[skip:]
}
_args_type = {
cls._section_name + "/" + k[skip:]: v
for k, v in cls._args_type.items()
if k[skip:]
}
_args_desc = {
cls._section_name + "/" + k[skip:]: v
for k, v in cls._args_desc.items()
if k[skip:]
}
return _args, _args_type, _args_desc
@ -61,36 +78,45 @@ class PatchClick:
param_val,
__update=True,
__parameters_descriptions=param_desc,
__parameters_types=param_types
__parameters_types=param_types,
)
@staticmethod
def _command_init(original_fn, self, *args, **kwargs):
if isinstance(self, (Command, Group)) and 'name' in kwargs:
if isinstance(self, (Command, Group)) and "name" in kwargs:
if isinstance(self, Command):
PatchClick._num_commands += 1
if not running_remotely():
name = kwargs['name']
name = kwargs["name"]
if name:
PatchClick._args[name] = False
if isinstance(self, Command):
PatchClick._args_type[name] = PatchClick._command_type
# maybe we should take it post initialization
if kwargs.get('help'):
PatchClick._args_desc[name] = str(kwargs.get('help'))
if kwargs.get("help"):
PatchClick._args_desc[name] = str(kwargs.get("help"))
for option in kwargs.get('params') or []:
if not option or not isinstance(option, (Option, Argument)) or \
not getattr(option, 'expose_value', True):
for option in kwargs.get("params") or []:
if (
not option
or not isinstance(option, (Option, Argument))
or not getattr(option, "expose_value", True)
):
continue
# store default value
PatchClick._args[name+'/'+option.name] = str(option.default or '')
PatchClick._args[name + "/" + option.name] = str(
option.default or ""
)
# store value type
if option.type is not None:
PatchClick._args_type[name+'/'+option.name] = str(option.type)
PatchClick._args_type[name + "/" + option.name] = str(
option.type
)
# store value help
if getattr(option, 'help', None):
PatchClick._args_desc[name+'/'+option.name] = str(option.help)
if getattr(option, "help", None):
PatchClick._args_desc[name + "/" + option.name] = str(
option.help
)
return original_fn(self, *args, **kwargs)
@ -99,30 +125,38 @@ class PatchClick:
if running_remotely() and isinstance(self, Command) and isinstance(self, Group):
command = PatchClick._load_task_params()
if command:
init_args = kwargs['args'] if 'args' in kwargs else args[1]
init_args = kwargs["args"] if "args" in kwargs else args[1]
init_args = [command] + (init_args[1:] if init_args else [])
if 'args' in kwargs:
kwargs['args'] = init_args
if "args" in kwargs:
kwargs["args"] = init_args
else:
args = (args[0], init_args) + args[2:]
ret = original_fn(self, *args, **kwargs)
if isinstance(self, Command):
ctx = kwargs.get('ctx') or args[0]
ctx = kwargs.get("ctx") or args[0]
if running_remotely():
PatchClick._load_task_params()
for p in self.params:
name = '{}/{}'.format(self.name, p.name) if PatchClick._num_commands > 1 else p.name
name = (
"{}/{}".format(self.name, p.name)
if PatchClick._num_commands > 1
else p.name
)
value = PatchClick.__remote_task_params_dict.get(name)
ctx.params[p.name] = p.process_value(
ctx, cast_str_to_bool(value, strip=True) if isinstance(p.type, BoolParamType) else value)
ctx,
cast_str_to_bool(value, strip=True)
if isinstance(p.type, BoolParamType)
else value,
)
else:
if not isinstance(self, Group):
PatchClick._args[self.name] = True
for k, v in ctx.params.items():
# store passed value
PatchClick._args[self.name + '/' + str(k)] = str(v or '')
PatchClick._args[self.name + "/" + str(k)] = str(v or "")
PatchClick._update_task_args()
return ret
@ -130,29 +164,34 @@ class PatchClick:
@staticmethod
def _context_init(original_fn, self, *args, **kwargs):
if running_remotely():
kwargs['resilient_parsing'] = True
kwargs["resilient_parsing"] = True
return original_fn(self, *args, **kwargs)
@staticmethod
def _load_task_params():
if not PatchClick.__remote_task_params:
from clearml import Task
t = Task.get_task(task_id=get_remote_task_id())
# noinspection PyProtectedMember
PatchClick.__remote_task_params = t._get_task_property('hyperparams') or {}
PatchClick.__remote_task_params = t._get_task_property("hyperparams") or {}
params_dict = t.get_parameters(backwards_compatibility=False)
skip = len(PatchClick._section_name)+1
skip = len(PatchClick._section_name) + 1
PatchClick.__remote_task_params_dict = {
k[skip:]: v for k, v in params_dict.items()
if k.startswith(PatchClick._section_name+'/')
k[skip:]: v
for k, v in params_dict.items()
if k.startswith(PatchClick._section_name + "/")
}
params = PatchClick.__remote_task_params
if not params:
return None
command = [
p.name for p in params['Args'].values()
if p.type == PatchClick._command_type and cast_str_to_bool(p.value, strip=True)]
p.name
for p in params["Args"].values()
if p.type == PatchClick._command_type
and cast_str_to_bool(p.value, strip=True)
]
return command[0] if command else None