Add support for tuples in hyper-parameters dict

This commit is contained in:
allegroai 2019-10-04 01:31:26 +03:00
parent a7eb8476ce
commit 4a42345561
3 changed files with 13 additions and 3 deletions

View File

@ -27,6 +27,7 @@ flags.DEFINE_string('echo5', '5', 'Text to echo.', module_name='test')
parameters = { parameters = {
'list': [1, 2, 3], 'list': [1, 2, 3],
'dict': {'a': 1, 'b': 2}, 'dict': {'a': 1, 'b': 2},
'tuple': (1, 2, 3),
'int': 3, 'int': 3,
'float': 2.2, 'float': 2.2,
'string': 'my string', 'string': 'my string',
@ -38,7 +39,7 @@ parameters['new_param'] = 'this is new'
# changing the value of a parameter (new value will be stored instead of previous one) # changing the value of a parameter (new value will be stored instead of previous one)
parameters['float'] = '9.9' parameters['float'] = '9.9'
print(parameters)
def main(_): def main(_):
print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr) print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr)

View File

@ -296,6 +296,14 @@ class _Arguments(object):
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v))) (str(k), str(param), str(k), str(v)))
continue continue
elif v_type == tuple:
try:
p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1]
param = tuple(yaml.load(p, Loader=yaml.SafeLoader))
except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
elif v_type == dict: elif v_type == dict:
try: try:
p = str(param).strip() p = str(param).strip()

View File

@ -3,7 +3,8 @@ import collections
import itertools import itertools
import logging import logging
from enum import Enum from enum import Enum
from threading import RLock, Thread from threading import Thread
from multiprocessing import RLock
import six import six
from six.moves.urllib.parse import quote from six.moves.urllib.parse import quote
@ -86,7 +87,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._curr_label_stats = {} self._curr_label_stats = {}
self._raise_on_validation_errors = raise_on_validation_errors self._raise_on_validation_errors = raise_on_validation_errors
self._parameters_allowed_types = ( self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None)) six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None))
) )
self._app_server = None self._app_server = None
self._files_server = None self._files_server = None