From e55d113258315e691babefd5895d06aa45e67fa0 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 11 Nov 2022 15:49:24 +0200 Subject: [PATCH] Add support for connecting Enum values as parameters --- clearml/backend_interface/task/args.py | 30 ++++++++++++++--------- clearml/backend_interface/task/task.py | 33 +++++++++++++++++++------- examples/reporting/hyper_parameters.py | 14 +++++++++++ 3 files changed, 57 insertions(+), 20 deletions(-) diff --git a/clearml/backend_interface/task/args.py b/clearml/backend_interface/task/args.py index e5a27d30..59aec84e 100644 --- a/clearml/backend_interface/task/args.py +++ b/clearml/backend_interface/task/args.py @@ -1,5 +1,6 @@ import yaml +from enum import Enum from inspect import isfunction from six import PY2 from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa @@ -521,6 +522,11 @@ class _Arguments(object): # this will be type(None), we deal with it later v_type = type(v) + def warn_failed_parsing(): + self._task.log.warning( + "Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v) + ) + # assume more general purpose type int -> float if v_type == int: if v is not None and int(v) != float(v): @@ -533,8 +539,7 @@ class _Arguments(object): try: param = str(param).lower().strip() == 'true' except ValueError: - self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % - (str(k), str(param), str(k), str(v))) + warn_failed_parsing() continue elif v_type == list: # noinspection PyBroadException @@ -542,8 +547,7 @@ class _Arguments(object): p = str(param).strip() param = yaml.load(p, Loader=FloatSafeLoader) except Exception: - self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % - (str(k), str(param), str(k), str(v))) + warn_failed_parsing() continue elif v_type == tuple: # noinspection PyBroadException @@ -551,8 +555,7 @@ class _Arguments(object): p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1] param = tuple(yaml.load(p, Loader=FloatSafeLoader)) except Exception: - self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % - (str(k), str(param), str(k), str(v))) + warn_failed_parsing() continue elif v_type == dict: # noinspection PyBroadException @@ -560,8 +563,14 @@ class _Arguments(object): p = str(param).strip() param = yaml.load(p, Loader=FloatSafeLoader) except Exception: - self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % - (str(k), str(param), str(k), str(v))) + warn_failed_parsing() + elif issubclass(v_type, Enum): + # noinspection PyBroadException + try: + param = getattr(v_type, param).value + except Exception: + warn_failed_parsing() + continue # noinspection PyBroadException try: @@ -573,8 +582,7 @@ class _Arguments(object): else: dictionary[k] = None if param == '' else v_type(param) except Exception: - self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % - (str(k), str(param), str(k), str(v))) + warn_failed_parsing() continue # add missing parameters to dictionary # for k, v in parameters.items(): @@ -593,7 +601,7 @@ class _Arguments(object): :param as_str: if True return string cast of the types :return: List of type objects supported for auto casting (serializing to string) """ - supported_types = (int, float, bool, str, list, tuple) + supported_types = (int, float, bool, str, list, tuple, Enum) if as_str: return tuple([str(t) for t in supported_types]) diff --git a/clearml/backend_interface/task/task.py b/clearml/backend_interface/task/task.py index 070887c4..bca1e3fd 100644 --- a/clearml/backend_interface/task/task.py +++ b/clearml/backend_interface/task/task.py @@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._curr_label_stats = {} self._raise_on_validation_errors = raise_on_validation_errors self._parameters_allowed_types = tuple(set( - six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None)) + six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None), Enum) )) self._app_server = None self._files_server = None @@ -1056,6 +1056,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): except TypeError: pass + if isinstance(value, Enum): + # remove the class name + return str_value.partition(".")[2] + return str_value if not all(isinstance(x, (dict, Iterable)) for x in args): @@ -1081,11 +1085,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): } if not_allowed: self.log.warning( - "Skipping parameter: {}, only builtin types are supported ({})".format( - ', '.join('%s[%s]' % p for p in not_allowed.items()), - ', '.join(t.__name__ for t in self._parameters_allowed_types)) + "Parameters must be of builtin type ({})".format( + ", ".join("%s[%s]" % p for p in not_allowed.items()), + ) ) - new_parameters = {k: v for k, v in new_parameters.items() if k not in not_allowed} use_hyperparams = Session.check_min_api_version('2.9') @@ -1135,12 +1138,24 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if param_type and not isinstance(param_type, str): param_type = param_type.__name__ if hasattr(param_type, '__name__') else str(param_type) + def create_description(): + if org_param: + return org_param.description + created_description = "" + if org_k in descriptions: + created_description = descriptions[org_k] + if isinstance(v, Enum): + # append enum values to description + created_description += "Values:\n" + ",\n".join( + [enum_key for enum_key in type(v).__dict__.keys() if not enum_key.startswith("_")] + ) + return created_description + section[key] = tasks.ParamsItem( - section=section_name, name=key, + section=section_name, + name=key, value=stringify(v), - description=descriptions[org_k] if org_k in descriptions else ( - org_param.description if org_param is not None else None - ), + description=create_description(), type=param_type, ) hyperparams[section_name] = section diff --git a/examples/reporting/hyper_parameters.py b/examples/reporting/hyper_parameters.py index 06b4143d..8be4b7cf 100644 --- a/examples/reporting/hyper_parameters.py +++ b/examples/reporting/hyper_parameters.py @@ -6,10 +6,22 @@ from __future__ import print_function import sys from argparse import ArgumentParser +from enum import Enum from clearml import Task + +class StringEnumClass(Enum): + A = 'a' + B = 'b' + + +class IntEnumClass(Enum): + C = 1 + D = 2 + + # Connecting ClearML with the current process, # from here on everything is logged automatically task = Task.init(project_name='examples', task_name='Hyper-parameters example') @@ -21,6 +33,8 @@ parameters = { 'int': 3, 'float': 2.2, 'string': 'my string', + 'IntEnumParam': StringEnumClass.A, + 'StringEnumParam': IntEnumClass.C } parameters = task.connect(parameters)