mirror of
https://github.com/clearml/clearml
synced 2025-03-13 07:08:24 +00:00
Ignore None
keys in click args (#903)
This commit is contained in:
parent
b1f17db657
commit
269f5e7974
1
.gitignore
vendored
1
.gitignore
vendored
@ -16,3 +16,4 @@ dist/
|
||||
examples/runs/
|
||||
examples/*_data
|
||||
examples/frameworks/data
|
||||
venv/
|
@ -14,8 +14,8 @@ class PatchClick:
|
||||
_args_desc = {}
|
||||
_args_type = {}
|
||||
_num_commands = 0
|
||||
_command_type = 'click.Command'
|
||||
_section_name = 'Args'
|
||||
_command_type = "click.Command"
|
||||
_section_name = "Args"
|
||||
_current_task = None
|
||||
__remote_task_params = None
|
||||
__remote_task_params_dict = {}
|
||||
@ -33,21 +33,38 @@ class PatchClick:
|
||||
if not cls.__patched:
|
||||
cls.__patched = True
|
||||
Command.__init__ = _patched_call(Command.__init__, PatchClick._command_init)
|
||||
Command.parse_args = _patched_call(Command.parse_args, PatchClick._parse_args)
|
||||
Command.parse_args = _patched_call(
|
||||
Command.parse_args, PatchClick._parse_args
|
||||
)
|
||||
Context.__init__ = _patched_call(Context.__init__, PatchClick._context_init)
|
||||
|
||||
@classmethod
|
||||
def args(cls):
|
||||
# ignore None keys
|
||||
cls._args = {k: v for k, v in cls._args.items() if k is not None}
|
||||
|
||||
# remove prefix and main command
|
||||
if cls._num_commands == 1:
|
||||
cmd = sorted(cls._args.keys())[0]
|
||||
skip = len(cmd)+1
|
||||
skip = len(cmd) + 1
|
||||
else:
|
||||
skip = 0
|
||||
|
||||
_args = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args.items() if k[skip:]}
|
||||
_args_type = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_type.items() if k[skip:]}
|
||||
_args_desc = {cls._section_name+'/'+k[skip:]: v for k, v in cls._args_desc.items() if k[skip:]}
|
||||
_args = {
|
||||
cls._section_name + "/" + k[skip:]: v
|
||||
for k, v in cls._args.items()
|
||||
if k[skip:]
|
||||
}
|
||||
_args_type = {
|
||||
cls._section_name + "/" + k[skip:]: v
|
||||
for k, v in cls._args_type.items()
|
||||
if k[skip:]
|
||||
}
|
||||
_args_desc = {
|
||||
cls._section_name + "/" + k[skip:]: v
|
||||
for k, v in cls._args_desc.items()
|
||||
if k[skip:]
|
||||
}
|
||||
|
||||
return _args, _args_type, _args_desc
|
||||
|
||||
@ -61,36 +78,45 @@ class PatchClick:
|
||||
param_val,
|
||||
__update=True,
|
||||
__parameters_descriptions=param_desc,
|
||||
__parameters_types=param_types
|
||||
__parameters_types=param_types,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _command_init(original_fn, self, *args, **kwargs):
|
||||
if isinstance(self, (Command, Group)) and 'name' in kwargs:
|
||||
if isinstance(self, (Command, Group)) and "name" in kwargs:
|
||||
if isinstance(self, Command):
|
||||
PatchClick._num_commands += 1
|
||||
if not running_remotely():
|
||||
name = kwargs['name']
|
||||
name = kwargs["name"]
|
||||
if name:
|
||||
PatchClick._args[name] = False
|
||||
if isinstance(self, Command):
|
||||
PatchClick._args_type[name] = PatchClick._command_type
|
||||
# maybe we should take it post initialization
|
||||
if kwargs.get('help'):
|
||||
PatchClick._args_desc[name] = str(kwargs.get('help'))
|
||||
if kwargs.get("help"):
|
||||
PatchClick._args_desc[name] = str(kwargs.get("help"))
|
||||
|
||||
for option in kwargs.get('params') or []:
|
||||
if not option or not isinstance(option, (Option, Argument)) or \
|
||||
not getattr(option, 'expose_value', True):
|
||||
for option in kwargs.get("params") or []:
|
||||
if (
|
||||
not option
|
||||
or not isinstance(option, (Option, Argument))
|
||||
or not getattr(option, "expose_value", True)
|
||||
):
|
||||
continue
|
||||
# store default value
|
||||
PatchClick._args[name+'/'+option.name] = str(option.default or '')
|
||||
PatchClick._args[name + "/" + option.name] = str(
|
||||
option.default or ""
|
||||
)
|
||||
# store value type
|
||||
if option.type is not None:
|
||||
PatchClick._args_type[name+'/'+option.name] = str(option.type)
|
||||
PatchClick._args_type[name + "/" + option.name] = str(
|
||||
option.type
|
||||
)
|
||||
# store value help
|
||||
if getattr(option, 'help', None):
|
||||
PatchClick._args_desc[name+'/'+option.name] = str(option.help)
|
||||
if getattr(option, "help", None):
|
||||
PatchClick._args_desc[name + "/" + option.name] = str(
|
||||
option.help
|
||||
)
|
||||
|
||||
return original_fn(self, *args, **kwargs)
|
||||
|
||||
@ -99,30 +125,38 @@ class PatchClick:
|
||||
if running_remotely() and isinstance(self, Command) and isinstance(self, Group):
|
||||
command = PatchClick._load_task_params()
|
||||
if command:
|
||||
init_args = kwargs['args'] if 'args' in kwargs else args[1]
|
||||
init_args = kwargs["args"] if "args" in kwargs else args[1]
|
||||
init_args = [command] + (init_args[1:] if init_args else [])
|
||||
if 'args' in kwargs:
|
||||
kwargs['args'] = init_args
|
||||
if "args" in kwargs:
|
||||
kwargs["args"] = init_args
|
||||
else:
|
||||
args = (args[0], init_args) + args[2:]
|
||||
|
||||
ret = original_fn(self, *args, **kwargs)
|
||||
|
||||
if isinstance(self, Command):
|
||||
ctx = kwargs.get('ctx') or args[0]
|
||||
ctx = kwargs.get("ctx") or args[0]
|
||||
if running_remotely():
|
||||
PatchClick._load_task_params()
|
||||
for p in self.params:
|
||||
name = '{}/{}'.format(self.name, p.name) if PatchClick._num_commands > 1 else p.name
|
||||
name = (
|
||||
"{}/{}".format(self.name, p.name)
|
||||
if PatchClick._num_commands > 1
|
||||
else p.name
|
||||
)
|
||||
value = PatchClick.__remote_task_params_dict.get(name)
|
||||
ctx.params[p.name] = p.process_value(
|
||||
ctx, cast_str_to_bool(value, strip=True) if isinstance(p.type, BoolParamType) else value)
|
||||
ctx,
|
||||
cast_str_to_bool(value, strip=True)
|
||||
if isinstance(p.type, BoolParamType)
|
||||
else value,
|
||||
)
|
||||
else:
|
||||
if not isinstance(self, Group):
|
||||
PatchClick._args[self.name] = True
|
||||
for k, v in ctx.params.items():
|
||||
# store passed value
|
||||
PatchClick._args[self.name + '/' + str(k)] = str(v or '')
|
||||
PatchClick._args[self.name + "/" + str(k)] = str(v or "")
|
||||
|
||||
PatchClick._update_task_args()
|
||||
return ret
|
||||
@ -130,29 +164,34 @@ class PatchClick:
|
||||
@staticmethod
|
||||
def _context_init(original_fn, self, *args, **kwargs):
|
||||
if running_remotely():
|
||||
kwargs['resilient_parsing'] = True
|
||||
kwargs["resilient_parsing"] = True
|
||||
return original_fn(self, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _load_task_params():
|
||||
if not PatchClick.__remote_task_params:
|
||||
from clearml import Task
|
||||
|
||||
t = Task.get_task(task_id=get_remote_task_id())
|
||||
# noinspection PyProtectedMember
|
||||
PatchClick.__remote_task_params = t._get_task_property('hyperparams') or {}
|
||||
PatchClick.__remote_task_params = t._get_task_property("hyperparams") or {}
|
||||
params_dict = t.get_parameters(backwards_compatibility=False)
|
||||
skip = len(PatchClick._section_name)+1
|
||||
skip = len(PatchClick._section_name) + 1
|
||||
PatchClick.__remote_task_params_dict = {
|
||||
k[skip:]: v for k, v in params_dict.items()
|
||||
if k.startswith(PatchClick._section_name+'/')
|
||||
k[skip:]: v
|
||||
for k, v in params_dict.items()
|
||||
if k.startswith(PatchClick._section_name + "/")
|
||||
}
|
||||
|
||||
params = PatchClick.__remote_task_params
|
||||
if not params:
|
||||
return None
|
||||
command = [
|
||||
p.name for p in params['Args'].values()
|
||||
if p.type == PatchClick._command_type and cast_str_to_bool(p.value, strip=True)]
|
||||
p.name
|
||||
for p in params["Args"].values()
|
||||
if p.type == PatchClick._command_type
|
||||
and cast_str_to_bool(p.value, strip=True)
|
||||
]
|
||||
return command[0] if command else None
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user