mirror of
https://github.com/clearml/clearml
synced 2025-04-18 13:24:41 +00:00
Fix task.connect(dict) value casting, if None is the default value, use backend stored type
This commit is contained in:
parent
35fe1c418c
commit
e59d8bdac3
@ -464,7 +464,15 @@ class _Arguments(object):
|
|||||||
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
|
prefix = prefix.strip(self._prefix_sep) + self._prefix_sep
|
||||||
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
|
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
|
||||||
if k.startswith(prefix)])
|
if k.startswith(prefix)])
|
||||||
|
# noinspection PyProtectedMember
|
||||||
|
parameters_type = {
|
||||||
|
k: p.type
|
||||||
|
for k, p in ((self._task._get_task_property('hyperparams', raise_on_error=False) or {}).get(
|
||||||
|
prefix[:-len(self._prefix_sep)]) or {}).items()
|
||||||
|
if p.type
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
|
parameters_type = {}
|
||||||
parameters = dict([(k, v) for k, v in self._task.get_parameters().items()
|
parameters = dict([(k, v) for k, v in self._task.get_parameters().items()
|
||||||
if not k.startswith(self._prefix_tf_defines)])
|
if not k.startswith(self._prefix_tf_defines)])
|
||||||
|
|
||||||
@ -478,10 +486,20 @@ class _Arguments(object):
|
|||||||
param = parameters.get(k, None)
|
param = parameters.get(k, None)
|
||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
v_type = type(v)
|
|
||||||
|
# if default value is not specified, allow casting based on what we have on the Task
|
||||||
|
if v is not None:
|
||||||
|
v_type = type(v)
|
||||||
|
elif parameters_type.get(k):
|
||||||
|
v_type_str = parameters_type.get(k)
|
||||||
|
v_type = next((t for t in (bool, int, float, str, list, tuple) if t.__name__ == v_type_str), str)
|
||||||
|
else:
|
||||||
|
# this will be type(None), we deal with it later
|
||||||
|
v_type = type(v)
|
||||||
|
|
||||||
# assume more general purpose type int -> float
|
# assume more general purpose type int -> float
|
||||||
if v_type == int:
|
if v_type == int:
|
||||||
if int(v) != float(v):
|
if v is not None and int(v) != float(v):
|
||||||
v_type = float
|
v_type = float
|
||||||
elif v_type == bool:
|
elif v_type == bool:
|
||||||
# cast based on string or int
|
# cast based on string or int
|
||||||
|
@ -1016,9 +1016,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
|
|
||||||
parameters.update(new_parameters)
|
parameters.update(new_parameters)
|
||||||
|
|
||||||
# force cast all variables to strings (so that we can later edit them in UI)
|
|
||||||
parameters = {k: stringify(v) for k, v in parameters.items()}
|
|
||||||
|
|
||||||
if use_hyperparams:
|
if use_hyperparams:
|
||||||
# build nested dict from flat parameters dict:
|
# build nested dict from flat parameters dict:
|
||||||
org_hyperparams = self.data.hyperparams or {}
|
org_hyperparams = self.data.hyperparams or {}
|
||||||
@ -1032,7 +1029,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
|
if org_legacy_section.get(k, tasks.ParamsItem()).type == 'legacy':
|
||||||
section = hyperparams.get(legacy_name, dict())
|
section = hyperparams.get(legacy_name, dict())
|
||||||
section[k] = copy(org_legacy_section[k])
|
section[k] = copy(org_legacy_section[k])
|
||||||
section[k].value = str(v) if v else v
|
section[k].value = stringify(v)
|
||||||
description = descriptions.get(k)
|
description = descriptions.get(k)
|
||||||
if description:
|
if description:
|
||||||
section[k].description = description
|
section[k].description = description
|
||||||
@ -1045,13 +1042,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
section_name, key = k.split('/', 1)
|
section_name, key = k.split('/', 1)
|
||||||
section = hyperparams.get(section_name, dict())
|
section = hyperparams.get(section_name, dict())
|
||||||
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
|
org_param = org_hyperparams.get(section_name, dict()).get(key, tasks.ParamsItem())
|
||||||
param_type = params_types[org_k] if org_k in params_types else org_param.type
|
param_type = params_types[org_k] if org_k in params_types else (
|
||||||
|
org_param.type or (type(v) if v is not None else None)
|
||||||
|
)
|
||||||
if param_type and not isinstance(param_type, str):
|
if param_type and not isinstance(param_type, str):
|
||||||
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
|
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
|
||||||
|
|
||||||
section[key] = tasks.ParamsItem(
|
section[key] = tasks.ParamsItem(
|
||||||
section=section_name, name=key,
|
section=section_name, name=key,
|
||||||
value=str(v) if v else v,
|
value=stringify(v),
|
||||||
description=descriptions[org_k] if org_k in descriptions else org_param.description,
|
description=descriptions[org_k] if org_k in descriptions else org_param.description,
|
||||||
type=param_type,
|
type=param_type,
|
||||||
)
|
)
|
||||||
@ -1060,6 +1059,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
self._edit(hyperparams=hyperparams)
|
self._edit(hyperparams=hyperparams)
|
||||||
self.data.hyperparams = hyperparams
|
self.data.hyperparams = hyperparams
|
||||||
else:
|
else:
|
||||||
|
# force cast all variables to strings (so that we can later edit them in UI)
|
||||||
|
parameters = {k: stringify(v) for k, v in parameters.items()}
|
||||||
|
|
||||||
execution = self.data.execution
|
execution = self.data.execution
|
||||||
if execution is None:
|
if execution is None:
|
||||||
execution = tasks.Execution(
|
execution = tasks.Execution(
|
||||||
|
Loading…
Reference in New Issue
Block a user