diff --git a/examples/absl_example.py b/examples/absl_example.py index be52b01e..c7d918ef 100644 --- a/examples/absl_example.py +++ b/examples/absl_example.py @@ -27,6 +27,7 @@ flags.DEFINE_string('echo5', '5', 'Text to echo.', module_name='test') parameters = { 'list': [1, 2, 3], 'dict': {'a': 1, 'b': 2}, + 'tuple': (1, 2, 3), 'int': 3, 'float': 2.2, 'string': 'my string', @@ -38,7 +39,7 @@ parameters['new_param'] = 'this is new' # changing the value of a parameter (new value will be stored instead of previous one) parameters['float'] = '9.9' - +print(parameters) def main(_): print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr) diff --git a/trains/backend_interface/task/args.py b/trains/backend_interface/task/args.py index 0da38724..8307a2c4 100644 --- a/trains/backend_interface/task/args.py +++ b/trains/backend_interface/task/args.py @@ -296,6 +296,14 @@ class _Arguments(object): self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % (str(k), str(param), str(k), str(v))) continue + elif v_type == tuple: + try: + p = str(param).strip().replace('(', '[', 1)[::-1].replace(')', ']', 1)[::-1] + param = tuple(yaml.load(p, Loader=yaml.SafeLoader)) + except Exception: + self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' % + (str(k), str(param), str(k), str(v))) + continue elif v_type == dict: try: p = str(param).strip() diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 955be855..b8db2856 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -3,7 +3,8 @@ import collections import itertools import logging from enum import Enum -from threading import RLock, Thread +from threading import Thread +from multiprocessing import RLock import six from six.moves.urllib.parse import quote @@ -86,7 +87,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._curr_label_stats = {} self._raise_on_validation_errors = raise_on_validation_errors self._parameters_allowed_types = ( - six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None)) + six.string_types + six.integer_types + (six.text_type, float, list, tuple, dict, type(None)) ) self._app_server = None self._files_server = None