Add support for tuples in hyper-parameters dict

This commit is contained in:
allegroai 2019-10-04 01:31:26 +03:00
parent a7eb8476ce
commit 4a42345561
3 changed files with 13 additions and 3 deletions

View File

@ -27,6 +27,7 @@ flags.DEFINE_string('echo5', '5', 'Text to echo.', module_name='test')
parameters = {
'list': [1, 2, 3],
'dict': {'a': 1, 'b': 2},
'tuple': (1, 2, 3),
'int': 3,
'float': 2.2,
'string': 'my string',
@ -38,7 +39,7 @@ parameters['new_param'] = 'this is new'
# changing the value of a parameter (new value will be stored instead of previous one)
parameters['float'] = '9.9'
print(parameters)
def main(_):
print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr)

View File

@ -296,6 +296,14 @@ class _Arguments(object):
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
elif v_type == tuple:
try:
p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1]
param = tuple(yaml.load(p, Loader=yaml.SafeLoader))
except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
elif v_type == dict:
try:
p = str(param).strip()

View File

@ -3,7 +3,8 @@ import collections
import itertools
import logging
from enum import Enum
from threading import RLock, Thread
from threading import Thread
from multiprocessing import RLock
import six
from six.moves.urllib.parse import quote
@ -86,7 +87,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._curr_label_stats = {}
self._raise_on_validation_errors = raise_on_validation_errors
self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
)
self._app_server = None
self._files_server = None