Add support for connecting Enum values as parameters

This commit is contained in:
allegroai 2022-11-11 15:49:24 +02:00
parent d0db6ea919
commit e55d113258
3 changed files with 57 additions and 20 deletions

View File

@ -1,5 +1,6 @@
import yaml import yaml
from enum import Enum
from inspect import isfunction from inspect import isfunction
from six import PY2 from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa
@ -521,6 +522,11 @@ class _Arguments(object):
# this will be type(None), we deal with it later # this will be type(None), we deal with it later
v_type = type(v) v_type = type(v)
def warn_failed_parsing():
self._task.log.warning(
"Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v)
)
# assume more general purpose type int -> float # assume more general purpose type int -> float
if v_type == int: if v_type == int:
if v is not None and int(v) != float(v): if v is not None and int(v) != float(v):
@ -533,8 +539,7 @@ class _Arguments(object):
try: try:
param = str(param).lower().strip() == 'true' param = str(param).lower().strip() == 'true'
except ValueError: except ValueError:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % warn_failed_parsing()
(str(k), str(param), str(k), str(v)))
continue continue
elif v_type == list: elif v_type == list:
# noinspection PyBroadException # noinspection PyBroadException
@ -542,8 +547,7 @@ class _Arguments(object):
p = str(param).strip() p = str(param).strip()
param = yaml.load(p, Loader=FloatSafeLoader) param = yaml.load(p, Loader=FloatSafeLoader)
except Exception: except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % warn_failed_parsing()
(str(k), str(param), str(k), str(v)))
continue continue
elif v_type == tuple: elif v_type == tuple:
# noinspection PyBroadException # noinspection PyBroadException
@ -551,8 +555,7 @@ class _Arguments(object):
p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1] p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1]
param = tuple(yaml.load(p, Loader=FloatSafeLoader)) param = tuple(yaml.load(p, Loader=FloatSafeLoader))
except Exception: except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % warn_failed_parsing()
(str(k), str(param), str(k), str(v)))
continue continue
elif v_type == dict: elif v_type == dict:
# noinspection PyBroadException # noinspection PyBroadException
@ -560,8 +563,14 @@ class _Arguments(object):
p = str(param).strip() p = str(param).strip()
param = yaml.load(p, Loader=FloatSafeLoader) param = yaml.load(p, Loader=FloatSafeLoader)
except Exception: except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % warn_failed_parsing()
(str(k), str(param), str(k), str(v))) elif issubclass(v_type, Enum):
# noinspection PyBroadException
try:
param = getattr(v_type, param).value
except Exception:
warn_failed_parsing()
continue
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -573,8 +582,7 @@ class _Arguments(object):
else: else:
dictionary[k] = None if param == '' else v_type(param) dictionary[k] = None if param == '' else v_type(param)
except Exception: except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % warn_failed_parsing()
(str(k), str(param), str(k), str(v)))
continue continue
# add missing parameters to dictionary # add missing parameters to dictionary
# for k, v in parameters.items(): # for k, v in parameters.items():
@ -593,7 +601,7 @@ class _Arguments(object):
:param as_str: if True return string cast of the types :param as_str: if True return string cast of the types
:return: List of type objects supported for auto casting (serializing to string) :return: List of type objects supported for auto casting (serializing to string)
""" """
supported_types = (int, float, bool, str, list, tuple) supported_types = (int, float, bool, str, list, tuple, Enum)
if as_str: if as_str:
return tuple([str(t) for t in supported_types]) return tuple([str(t) for t in supported_types])

View File

@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._curr_label_stats = {} self._curr_label_stats = {}
self._raise_on_validation_errors = raise_on_validation_errors self._raise_on_validation_errors = raise_on_validation_errors
self._parameters_allowed_types = tuple(set( self._parameters_allowed_types = tuple(set(
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None)) six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None), Enum)
)) ))
self._app_server = None self._app_server = None
self._files_server = None self._files_server = None
@ -1056,6 +1056,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
except TypeError: except TypeError:
pass pass
if isinstance(value, Enum):
# remove the class name
return str_value.partition(".")[2]
return str_value return str_value
if not all(isinstance(x, (dict, Iterable)) for x in args): if not all(isinstance(x, (dict, Iterable)) for x in args):
@ -1081,11 +1085,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
} }
if not_allowed: if not_allowed:
self.log.warning( self.log.warning(
"Skipping parameter: {}, only builtin types are supported ({})".format( "Parameters must be of builtin type ({})".format(
', '.join('%s[%s]' % p for p in not_allowed.items()), ", ".join("%s[%s]" % p for p in not_allowed.items()),
', '.join(t.__name__ for t in self._parameters_allowed_types)) )
) )
new_parameters = {k: v for k, v in new_parameters.items() if k not in not_allowed}
use_hyperparams = Session.check_min_api_version('2.9') use_hyperparams = Session.check_min_api_version('2.9')
@ -1135,12 +1138,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if param_type and not isinstance(param_type, str): if param_type and not isinstance(param_type, str):
param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type) param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type)
def create_description():
if org_param:
return org_param.description
created_description = ""
if org_k in descriptions:
created_description = descriptions[org_k]
if isinstance(v, Enum):
# append enum values to description
created_description += "Values:\n" + ",\n".join(
[enum_key for enum_key in type(v).__dict__.keys() if not enum_key.startswith("_")]
)
return created_description
section[key] = tasks.ParamsItem( section[key] = tasks.ParamsItem(
section=section_name, name=key, section=section_name,
name=key,
value=stringify(v), value=stringify(v),
description=descriptions[org_k] if org_k in descriptions else ( description=create_description(),
org_param.description if org_param is not None else None
),
type=param_type, type=param_type,
) )
hyperparams[section_name] = section hyperparams[section_name] = section

View File

@ -6,10 +6,22 @@ from __future__ import print_function
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from enum import Enum
from clearml import Task from clearml import Task
class StringEnumClass(Enum):
A = 'a'
B = 'b'
class IntEnumClass(Enum):
C = 1
D = 2
# Connecting ClearML with the current process, # Connecting ClearML with the current process,
# from here on everything is logged automatically # from here on everything is logged automatically
task = Task.init(project_name='examples', task_name='Hyper-parameters example') task = Task.init(project_name='examples', task_name='Hyper-parameters example')
@ -21,6 +33,8 @@ parameters = {
'int': 3, 'int': 3,
'float': 2.2, 'float': 2.2,
'string': 'my string', 'string': 'my string',
'IntEnumParam': StringEnumClass.A,
'StringEnumParam': IntEnumClass.C
} }
parameters = task.connect(parameters) parameters = task.connect(parameters)