Add populate now stores function arg types as part of the hyper paremeters

This commit is contained in:
allegroai 2022-01-10 00:00:05 +02:00
parent 3f9f3aedc7
commit 2c214e9848
2 changed files with 36 additions and 5 deletions

View File

@ -5,6 +5,8 @@ from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa
from copy import copy
from typing import Tuple, Type, Union
from ...backend_api import Session
from ...binding.args import call_original_argparser
@ -71,7 +73,7 @@ class _Arguments(object):
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
if Session.check_min_api_version('2.9'):
name = self._prefix_args + name
self._task.set_parameter(name=name, value=default, description=help)
self._task.set_parameter(name=name, value=default, description=help, value_type=type)
def connect(self, parser):
self._task.connect_argparse(parser)
@ -562,6 +564,20 @@ class _Arguments(object):
return self._ProxyDictReadOnly(self, prefix, **dictionary)
return dictionary
@classmethod
def get_supported_types(cls, as_str=False):
# type: (bool) -> Union[Type, Tuple[str]]
"""
Return the basic types supported by Argument casting
:param as_str: if True return string cast of the types
:return: List of type objects supported for auto casting (serializing to string)
"""
supported_types = (int, float, bool, str, list, tuple)
if as_str:
return tuple([str(t) for t in supported_types])
return supported_types
@classmethod
def __cast_arg(cls, arg, dtype=None):
if arg is None or callable(arg):

View File

@ -11,6 +11,7 @@ from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any
from pathlib2 import Path
from six.moves.urllib.parse import urlparse
from .args import _Arguments
from .repo import ScriptInfo
from ...task import Task
@ -593,7 +594,8 @@ if __name__ == '__main__':
'function_input_artifacts={}, it must in the format: '
'{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts)
)
inspect_args = None
function_kwargs_types = dict()
if function_kwargs is None:
function_kwargs = dict()
inspect_args = inspect.getfullargspec(a_function)
@ -614,6 +616,16 @@ if __name__ == '__main__':
function_kwargs = {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)} \
if inspect_defaults_vals else {str(k): None for k in inspect_defaults_args}
if function_kwargs:
if not inspect_args:
inspect_args = inspect.getfullargspec(a_function)
# inspect_func.annotations[k]
if inspect_args.annotations:
supported_types = _Arguments.get_supported_types()
function_kwargs_types = {
str(k): str(inspect_args.annotations[k].__name__) for k in inspect_args.annotations
if inspect_args.annotations[k] in supported_types}
task_template = cls.task_template.format(
kwargs_section=cls.kwargs_section,
input_artifact_section=cls.input_artifact_section,
@ -658,11 +670,12 @@ if __name__ == '__main__':
task['script']['working_dir'] = '.'
task['hyperparams'] = {
cls.kwargs_section: {
k: dict(section=cls.kwargs_section, name=k, value=str(v))
k: dict(section=cls.kwargs_section, name=k,
value=str(v) if v is not None else '', type=function_kwargs_types.get(k, None))
for k, v in (function_kwargs or {}).items()
},
cls.input_artifact_section: {
k: dict(section=cls.input_artifact_section, name=k, value=str(v))
k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else '')
for k, v in (function_input_artifacts or {}).items()
}
}
@ -676,7 +689,9 @@ if __name__ == '__main__':
{'{}/{}'.format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts}
if function_input_artifacts else {}
)
task.set_parameters(hyper_parameters)
__function_kwargs_types = {'{}/{}'.format(cls.kwargs_section, k): v for k, v in function_kwargs_types} \
if function_kwargs_types else None
task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types)
return task