mirror of
https://github.com/clearml/clearml
synced 2025-03-03 10:42:00 +00:00
Add populate now stores function arg types as part of the hyper paremeters
This commit is contained in:
parent
3f9f3aedc7
commit
2c214e9848
@ -5,6 +5,8 @@ from six import PY2
|
||||
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, _AppendAction, SUPPRESS # noqa
|
||||
from copy import copy
|
||||
|
||||
from typing import Tuple, Type, Union
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...binding.args import call_original_argparser
|
||||
|
||||
@ -71,7 +73,7 @@ class _Arguments(object):
|
||||
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
|
||||
if Session.check_min_api_version('2.9'):
|
||||
name = self._prefix_args + name
|
||||
self._task.set_parameter(name=name, value=default, description=help)
|
||||
self._task.set_parameter(name=name, value=default, description=help, value_type=type)
|
||||
|
||||
def connect(self, parser):
|
||||
self._task.connect_argparse(parser)
|
||||
@ -562,6 +564,20 @@ class _Arguments(object):
|
||||
return self._ProxyDictReadOnly(self, prefix, **dictionary)
|
||||
return dictionary
|
||||
|
||||
@classmethod
|
||||
def get_supported_types(cls, as_str=False):
|
||||
# type: (bool) -> Union[Type, Tuple[str]]
|
||||
"""
|
||||
Return the basic types supported by Argument casting
|
||||
:param as_str: if True return string cast of the types
|
||||
:return: List of type objects supported for auto casting (serializing to string)
|
||||
"""
|
||||
supported_types = (int, float, bool, str, list, tuple)
|
||||
if as_str:
|
||||
return tuple([str(t) for t in supported_types])
|
||||
|
||||
return supported_types
|
||||
|
||||
@classmethod
|
||||
def __cast_arg(cls, arg, dtype=None):
|
||||
if arg is None or callable(arg):
|
||||
|
@ -11,6 +11,7 @@ from typing import Optional, Sequence, Union, Tuple, List, Callable, Dict, Any
|
||||
from pathlib2 import Path
|
||||
from six.moves.urllib.parse import urlparse
|
||||
|
||||
from .args import _Arguments
|
||||
from .repo import ScriptInfo
|
||||
from ...task import Task
|
||||
|
||||
@ -593,7 +594,8 @@ if __name__ == '__main__':
|
||||
'function_input_artifacts={}, it must in the format: '
|
||||
'{{"argument": "task_id.artifact_name"}}'.format(function_input_artifacts)
|
||||
)
|
||||
|
||||
inspect_args = None
|
||||
function_kwargs_types = dict()
|
||||
if function_kwargs is None:
|
||||
function_kwargs = dict()
|
||||
inspect_args = inspect.getfullargspec(a_function)
|
||||
@ -614,6 +616,16 @@ if __name__ == '__main__':
|
||||
function_kwargs = {str(k): v for k, v in zip(inspect_defaults_args, inspect_defaults_vals)} \
|
||||
if inspect_defaults_vals else {str(k): None for k in inspect_defaults_args}
|
||||
|
||||
if function_kwargs:
|
||||
if not inspect_args:
|
||||
inspect_args = inspect.getfullargspec(a_function)
|
||||
# inspect_func.annotations[k]
|
||||
if inspect_args.annotations:
|
||||
supported_types = _Arguments.get_supported_types()
|
||||
function_kwargs_types = {
|
||||
str(k): str(inspect_args.annotations[k].__name__) for k in inspect_args.annotations
|
||||
if inspect_args.annotations[k] in supported_types}
|
||||
|
||||
task_template = cls.task_template.format(
|
||||
kwargs_section=cls.kwargs_section,
|
||||
input_artifact_section=cls.input_artifact_section,
|
||||
@ -658,11 +670,12 @@ if __name__ == '__main__':
|
||||
task['script']['working_dir'] = '.'
|
||||
task['hyperparams'] = {
|
||||
cls.kwargs_section: {
|
||||
k: dict(section=cls.kwargs_section, name=k, value=str(v))
|
||||
k: dict(section=cls.kwargs_section, name=k,
|
||||
value=str(v) if v is not None else '', type=function_kwargs_types.get(k, None))
|
||||
for k, v in (function_kwargs or {}).items()
|
||||
},
|
||||
cls.input_artifact_section: {
|
||||
k: dict(section=cls.input_artifact_section, name=k, value=str(v))
|
||||
k: dict(section=cls.input_artifact_section, name=k, value=str(v) if v is not None else '')
|
||||
for k, v in (function_input_artifacts or {}).items()
|
||||
}
|
||||
}
|
||||
@ -676,7 +689,9 @@ if __name__ == '__main__':
|
||||
{'{}/{}'.format(cls.input_artifact_section, k): str(v) for k, v in function_input_artifacts}
|
||||
if function_input_artifacts else {}
|
||||
)
|
||||
task.set_parameters(hyper_parameters)
|
||||
__function_kwargs_types = {'{}/{}'.format(cls.kwargs_section, k): v for k, v in function_kwargs_types} \
|
||||
if function_kwargs_types else None
|
||||
task.set_parameters(hyper_parameters, __parameters_types=__function_kwargs_types)
|
||||
|
||||
return task
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user