From e59d8bdac3d01272d7aaca6dc6e06eba1c6ab7e9 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 2 Oct 2021 21:43:36 +0300 Subject: [PATCH] Fix task.connect(dict) value casting, if None is the default value, use backend stored type --- clearml/backend_interface/task/args.py | 22 ++++++++++++++++++++-- clearml/backend_interface/task/task.py | 14 ++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/clearml/backend_interface/task/args.py b/clearml/backend_interface/task/args.py index 45eeb317..1a8eabab 100644 --- a/clearml/backend_interface/task/args.py +++ b/clearml/backend_interface/task/args.py @@ -464,7 +464,15 @@ class _Arguments(object): prefix = prefix.strip(self._prefix_sep) + self._prefix_sep parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items() if k.startswith(prefix)]) + # noinspection PyProtectedMember + parameters_type = { + k: p.type + for k, p in ((self._task._get_task_property('hyperparams', raise_on_error=False) or {}).get( + prefix[:-len(self._prefix_sep)]) or {}).items() + if p.type + } else: + parameters_type = {} parameters = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(self._prefix_tf_defines)]) @@ -478,10 +486,20 @@ class _Arguments(object): param = parameters.get(k, None) if param is None: continue - v_type = type(v) + + # if default value is not specified, allow casting based on what we have on the Task + if v is not None: + v_type = type(v) + elif parameters_type.get(k): + v_type_str = parameters_type.get(k) + v_type = next((t for t in (bool, int, float, str, list, tuple) if t.__name__ == v_type_str), str) + else: + # this will be type(None), we deal with it later + v_type = type(v) + # assume more general purpose type int -> float if v_type == int: - if int(v) != float(v): + if v is not None and int(v) != float(v): v_type = float elif v_type == bool: # cast based on string or int diff --git a/clearml/backend_interface/task/task.py b/clearml/backend_interface/task/task.py index e1edf216..a1cb4893 100644 --- a/clearml/backend_interface/task/task.py +++ b/clearml/backend_interface/task/task.py @@ -1016,9 +1016,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): parameters.update(new_parameters) - # force cast all variables to strings (so that we can later edit them in UI) - parameters = {k: stringify(v) for k, v in parameters.items()} - if use_hyperparams: # build nested dict from flat parameters dict: org_hyperparams = self.data.hyperparams or {} @@ -1032,7 +1029,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy': section = hyperparams.get(legacy_name, dict()) section[k] = copy(org_legacy_section[k]) - section[k].value = str(v) if v else v + section[k].value = stringify(v) description = descriptions.get(k) if description: section[k].description = description @@ -1045,13 +1042,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): section_name, key = k.split('/', 1) section = hyperparams.get(section_name, dict()) org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem()) - param_type = params_types[org_k] if org_k in params_types else org_param.type + param_type = params_types[org_k] if org_k in params_types else ( + org_param.type or (type(v) if v is not None else None) + ) if param_type and not isinstance(param_type, str): param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type) section[key] = tasks.ParamsItem( section=section_name, name=key, - value=str(v) if v else v, + value=stringify(v), description=descriptions[org_k] if org_k in descriptions else org_param.description, type=param_type, ) @@ -1060,6 +1059,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._edit(hyperparams=hyperparams) self.data.hyperparams = hyperparams else: + # force cast all variables to strings (so that we can later edit them in UI) + parameters = {k: stringify(v) for k, v in parameters.items()} + execution = self.data.execution if execution is None: execution = tasks.Execution(