From 2c214e98486881e460aa08acde6c11ab3eb693e2 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 10 Jan 2022 00:00:05 +0200 Subject: [PATCH] Add populate now stores function arg types as part of the hyper paremeters --- clearml/backend_interface/task/args.py | 18 ++++++++++++++++- clearml/backend_interface/task/populate.py | 23 ++++++++++++++++++---- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/clearml/backend_interface/task/args.py b/clearml/backend_interface/task/args.py index c07f51bb..1f2cc78b 100644 --- a/clearml/backend_interface/task/args.py +++ b/clearml/backend_interface/task/args.py @@ -5,6 +5,8 @@ from six import PY2 from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa from copy import copy +from typing import Tuple, Type, Union + from ...backend_api import Session from ...binding.args import call_original_argparser @@ -71,7 +73,7 @@ class _Arguments(object): name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t') if Session.check_min_api_version('2.9'): name = self._prefix_args + name - self._task.set_parameter(name=name, value=default, description=help) + self._task.set_parameter(name=name, value=default, description=help, value_type=type) def connect(self, parser): self._task.connect_argparse(parser) @@ -562,6 +564,20 @@ class _Arguments(object): return self._ProxyDictReadOnly(self, prefix, **dictionary) return dictionary + @classmethod + def get_supported_types(cls, as_str=False): + # type: (bool) -> Union[Type, Tuple[str]] + """ + Return the basic types supported by Argument casting + :param as_str: if True return string cast of the types + :return: List of type objects supported for auto casting (serializing to string) + """ + supported_types = (int, float, bool, str, list, tuple) + if as_str: + return tuple([str(t) for t in supported_types]) + + return supported_types + @classmethod def __cast_arg(cls, arg, dtype=None): if arg is None or callable(arg): diff --git a/clearml/backend_interface/task/populate.py b/clearml/backend_interface/task/populate.py index 55cb3a69..490bf41b 100644 --- a/clearml/backend_interface/task/populate.py +++ b/clearml/backend_interface/task/populate.py @@ -11,6 +11,7 @@ from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any from pathlib2 import Path from six.moves.urllib.parse import urlparse +from .args import _Arguments from .repo import ScriptInfo from ...task import Task @@ -593,7 +594,8 @@ if __name__ == '__main__': 'function_input_artifacts={}, it must in the format: ' '{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts) ) - + inspect_args = None + function_kwargs_types = dict() if function_kwargs is None: function_kwargs = dict() inspect_args = inspect.getfullargspec(a_function) @@ -614,6 +616,16 @@ if __name__ == '__main__': function_kwargs = {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)} \ if inspect_defaults_vals else {str(k): None for k in inspect_defaults_args} + if function_kwargs: + if not inspect_args: + inspect_args = inspect.getfullargspec(a_function) + # inspect_func.annotations[k] + if inspect_args.annotations: + supported_types = _Arguments.get_supported_types() + function_kwargs_types = { + str(k): str(inspect_args.annotations[k].__name__) for k in inspect_args.annotations + if inspect_args.annotations[k] in supported_types} + task_template = cls.task_template.format( kwargs_section=cls.kwargs_section, input_artifact_section=cls.input_artifact_section, @@ -658,11 +670,12 @@ if __name__ == '__main__': task['script']['working_dir'] = '.' task['hyperparams'] = { cls.kwargs_section: { - k: dict(section=cls.kwargs_section, name=k, value=str(v)) + k: dict(section=cls.kwargs_section, name=k, + value=str(v) if v is not None else '', type=function_kwargs_types.get(k, None)) for k, v in (function_kwargs or {}).items() }, cls.input_artifact_section: { - k: dict(section=cls.input_artifact_section, name=k, value=str(v)) + k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else '') for k, v in (function_input_artifacts or {}).items() } } @@ -676,7 +689,9 @@ if __name__ == '__main__': {'{}/{}'.format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts} if function_input_artifacts else {} ) - task.set_parameters(hyper_parameters) + __function_kwargs_types = {'{}/{}'.format(cls.kwargs_section, k): v for k, v in function_kwargs_types} \ + if function_kwargs_types else None + task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types) return task