mirror of
https://github.com/clearml/clearml
synced 2025-04-09 23:24:31 +00:00
Add support for connecting Enum values as parameters
This commit is contained in:
parent
d0db6ea919
commit
e55d113258
@ -1,5 +1,6 @@
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from six import PY2
|
from six import PY2
|
||||||
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa
|
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa
|
||||||
@ -521,6 +522,11 @@ class _Arguments(object):
|
|||||||
# this will be type(None), we deal with it later
|
# this will be type(None), we deal with it later
|
||||||
v_type = type(v)
|
v_type = type(v)
|
||||||
|
|
||||||
|
def warn_failed_parsing():
|
||||||
|
self._task.log.warning(
|
||||||
|
"Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v)
|
||||||
|
)
|
||||||
|
|
||||||
# assume more general purpose type int -> float
|
# assume more general purpose type int -> float
|
||||||
if v_type == int:
|
if v_type == int:
|
||||||
if v is not None and int(v) != float(v):
|
if v is not None and int(v) != float(v):
|
||||||
@ -533,8 +539,7 @@ class _Arguments(object):
|
|||||||
try:
|
try:
|
||||||
param = str(param).lower().strip() == 'true'
|
param = str(param).lower().strip() == 'true'
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
|
warn_failed_parsing()
|
||||||
(str(k), str(param), str(k), str(v)))
|
|
||||||
continue
|
continue
|
||||||
elif v_type == list:
|
elif v_type == list:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -542,8 +547,7 @@ class _Arguments(object):
|
|||||||
p = str(param).strip()
|
p = str(param).strip()
|
||||||
param = yaml.load(p, Loader=FloatSafeLoader)
|
param = yaml.load(p, Loader=FloatSafeLoader)
|
||||||
except Exception:
|
except Exception:
|
||||||
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
|
warn_failed_parsing()
|
||||||
(str(k), str(param), str(k), str(v)))
|
|
||||||
continue
|
continue
|
||||||
elif v_type == tuple:
|
elif v_type == tuple:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -551,8 +555,7 @@ class _Arguments(object):
|
|||||||
p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1]
|
p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1]
|
||||||
param = tuple(yaml.load(p, Loader=FloatSafeLoader))
|
param = tuple(yaml.load(p, Loader=FloatSafeLoader))
|
||||||
except Exception:
|
except Exception:
|
||||||
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
|
warn_failed_parsing()
|
||||||
(str(k), str(param), str(k), str(v)))
|
|
||||||
continue
|
continue
|
||||||
elif v_type == dict:
|
elif v_type == dict:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -560,8 +563,14 @@ class _Arguments(object):
|
|||||||
p = str(param).strip()
|
p = str(param).strip()
|
||||||
param = yaml.load(p, Loader=FloatSafeLoader)
|
param = yaml.load(p, Loader=FloatSafeLoader)
|
||||||
except Exception:
|
except Exception:
|
||||||
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
|
warn_failed_parsing()
|
||||||
(str(k), str(param), str(k), str(v)))
|
elif issubclass(v_type, Enum):
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
param = getattr(v_type, param).value
|
||||||
|
except Exception:
|
||||||
|
warn_failed_parsing()
|
||||||
|
continue
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -573,8 +582,7 @@ class _Arguments(object):
|
|||||||
else:
|
else:
|
||||||
dictionary[k] = None if param == '' else v_type(param)
|
dictionary[k] = None if param == '' else v_type(param)
|
||||||
except Exception:
|
except Exception:
|
||||||
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
|
warn_failed_parsing()
|
||||||
(str(k), str(param), str(k), str(v)))
|
|
||||||
continue
|
continue
|
||||||
# add missing parameters to dictionary
|
# add missing parameters to dictionary
|
||||||
# for k, v in parameters.items():
|
# for k, v in parameters.items():
|
||||||
@ -593,7 +601,7 @@ class _Arguments(object):
|
|||||||
:param as_str: if True return string cast of the types
|
:param as_str: if True return string cast of the types
|
||||||
:return: List of type objects supported for auto casting (serializing to string)
|
:return: List of type objects supported for auto casting (serializing to string)
|
||||||
"""
|
"""
|
||||||
supported_types = (int, float, bool, str, list, tuple)
|
supported_types = (int, float, bool, str, list, tuple, Enum)
|
||||||
if as_str:
|
if as_str:
|
||||||
return tuple([str(t) for t in supported_types])
|
return tuple([str(t) for t in supported_types])
|
||||||
|
|
||||||
|
@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
self._curr_label_stats = {}
|
self._curr_label_stats = {}
|
||||||
self._raise_on_validation_errors = raise_on_validation_errors
|
self._raise_on_validation_errors = raise_on_validation_errors
|
||||||
self._parameters_allowed_types = tuple(set(
|
self._parameters_allowed_types = tuple(set(
|
||||||
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
|
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None), Enum)
|
||||||
))
|
))
|
||||||
self._app_server = None
|
self._app_server = None
|
||||||
self._files_server = None
|
self._files_server = None
|
||||||
@ -1056,6 +1056,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
# remove the class name
|
||||||
|
return str_value.partition(".")[2]
|
||||||
|
|
||||||
return str_value
|
return str_value
|
||||||
|
|
||||||
if not all(isinstance(x, (dict, Iterable)) for x in args):
|
if not all(isinstance(x, (dict, Iterable)) for x in args):
|
||||||
@ -1081,11 +1085,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
}
|
}
|
||||||
if not_allowed:
|
if not_allowed:
|
||||||
self.log.warning(
|
self.log.warning(
|
||||||
"Skipping parameter: {}, only builtin types are supported ({})".format(
|
"Parameters must be of builtin type ({})".format(
|
||||||
', '.join('%s[%s]' % p for p in not_allowed.items()),
|
", ".join("%s[%s]" % p for p in not_allowed.items()),
|
||||||
', '.join(t.__name__ for t in self._parameters_allowed_types))
|
)
|
||||||
)
|
)
|
||||||
new_parameters = {k: v for k, v in new_parameters.items() if k not in not_allowed}
|
|
||||||
|
|
||||||
use_hyperparams = Session.check_min_api_version('2.9')
|
use_hyperparams = Session.check_min_api_version('2.9')
|
||||||
|
|
||||||
@ -1135,12 +1138,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
if param_type and not isinstance(param_type, str):
|
if param_type and not isinstance(param_type, str):
|
||||||
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
|
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
|
||||||
|
|
||||||
|
def create_description():
|
||||||
|
if org_param:
|
||||||
|
return org_param.description
|
||||||
|
created_description = ""
|
||||||
|
if org_k in descriptions:
|
||||||
|
created_description = descriptions[org_k]
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
# append enum values to description
|
||||||
|
created_description += "Values:\n" + ",\n".join(
|
||||||
|
[enum_key for enum_key in type(v).__dict__.keys() if not enum_key.startswith("_")]
|
||||||
|
)
|
||||||
|
return created_description
|
||||||
|
|
||||||
section[key] = tasks.ParamsItem(
|
section[key] = tasks.ParamsItem(
|
||||||
section=section_name, name=key,
|
section=section_name,
|
||||||
|
name=key,
|
||||||
value=stringify(v),
|
value=stringify(v),
|
||||||
description=descriptions[org_k] if org_k in descriptions else (
|
description=create_description(),
|
||||||
org_param.description if org_param is not None else None
|
|
||||||
),
|
|
||||||
type=param_type,
|
type=param_type,
|
||||||
)
|
)
|
||||||
hyperparams[section_name] = section
|
hyperparams[section_name] = section
|
||||||
|
@ -6,10 +6,22 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
|
|
||||||
|
|
||||||
|
class StringEnumClass(Enum):
|
||||||
|
A = 'a'
|
||||||
|
B = 'b'
|
||||||
|
|
||||||
|
|
||||||
|
class IntEnumClass(Enum):
|
||||||
|
C = 1
|
||||||
|
D = 2
|
||||||
|
|
||||||
|
|
||||||
# Connecting ClearML with the current process,
|
# Connecting ClearML with the current process,
|
||||||
# from here on everything is logged automatically
|
# from here on everything is logged automatically
|
||||||
task = Task.init(project_name='examples', task_name='Hyper-parameters example')
|
task = Task.init(project_name='examples', task_name='Hyper-parameters example')
|
||||||
@ -21,6 +33,8 @@ parameters = {
|
|||||||
'int': 3,
|
'int': 3,
|
||||||
'float': 2.2,
|
'float': 2.2,
|
||||||
'string': 'my string',
|
'string': 'my string',
|
||||||
|
'IntEnumParam': StringEnumClass.A,
|
||||||
|
'StringEnumParam': IntEnumClass.C
|
||||||
}
|
}
|
||||||
parameters = task.connect(parameters)
|
parameters = task.connect(parameters)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user