Fix task.connect(dict) value casting, if None is the default value, use backend stored type

This commit is contained in:
allegroai 2021-10-02 21:43:36 +03:00
parent 35fe1c418c
commit e59d8bdac3
2 changed files with 28 additions and 8 deletions

View File

@ -464,7 +464,15 @@ class _Arguments(object):
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
if k.startswith(prefix)])
# noinspection PyProtectedMember
parameters_type = {
k: p.type
for k, p in ((self._task._get_task_property('hyperparams', raise_on_error=False) or {}).get(
prefix[:-len(self._prefix_sep)]) or {}).items()
if p.type
}
else:
parameters_type = {}
parameters = dict([(k, v) for k, v in self._task.get_parameters().items()
if not k.startswith(self._prefix_tf_defines)])
@ -478,10 +486,20 @@ class _Arguments(object):
param = parameters.get(k, None)
if param is None:
continue
v_type = type(v)
# if default value is not specified, allow casting based on what we have on the Task
if v is not None:
v_type = type(v)
elif parameters_type.get(k):
v_type_str = parameters_type.get(k)
v_type = next((t for t in (bool, int, float, str, list, tuple) if t.__name__ == v_type_str), str)
else:
# this will be type(None), we deal with it later
v_type = type(v)
# assume more general purpose type int -> float
if v_type == int:
if int(v) != float(v):
if v is not None and int(v) != float(v):
v_type = float
elif v_type == bool:
# cast based on string or int

View File

@ -1016,9 +1016,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
parameters.update(new_parameters)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: stringify(v) for k, v in parameters.items()}
if use_hyperparams:
# build nested dict from flat parameters dict:
org_hyperparams = self.data.hyperparams or {}
@ -1032,7 +1029,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
section = hyperparams.get(legacy_name, dict())
section[k] = copy(org_legacy_section[k])
section[k].value = str(v) if v else v
section[k].value = stringify(v)
description = descriptions.get(k)
if description:
section[k].description = description
@ -1045,13 +1042,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
section_name, key = k.split('/', 1)
section = hyperparams.get(section_name, dict())
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
param_type = params_types[org_k] if org_k in params_types else org_param.type
param_type = params_types[org_k] if org_k in params_types else (
org_param.type or (type(v) if v is not None else None)
)
if param_type and not isinstance(param_type, str):
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
section[key] = tasks.ParamsItem(
section=section_name, name=key,
value=str(v) if v else v,
value=stringify(v),
description=descriptions[org_k] if org_k in descriptions else org_param.description,
type=param_type,
)
@ -1060,6 +1059,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(hyperparams=hyperparams)
self.data.hyperparams = hyperparams
else:
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: stringify(v) for k, v in parameters.items()}
execution = self.data.execution
if execution is None:
execution = tasks.Execution(