From 8972c1f00590a3b7d32aeb381c7118b79594e932 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Thu, 20 Feb 2020 18:32:12 +0200 Subject: [PATCH] Add Task.[get/set]_parameters_as_dict() to allow interaction with non-main task parameters (no need to connect()) --- trains/task.py | 17 ++++++++++++++++- trains/utilities/proxy_object.py | 27 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/trains/task.py b/trains/task.py index 5c0411b2..9ca000f5 100644 --- a/trains/task.py +++ b/trains/task.py @@ -48,7 +48,7 @@ from .utilities.args import argparser_parseargs_called, get_argparser_last_args, argparser_update_currenttask from .utilities.dicts import ReadOnlyDict from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \ - nested_from_flat_dictionary + nested_from_flat_dictionary, naive_nested_from_flat_dictionary from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic @@ -878,6 +878,21 @@ class Task(_Task): j['variant'], {'last': j['value'], 'min': j['min_value'], 'max': j['max_value']}) return scalar_metrics + def get_parameters_as_dict(self): + """ + Get task parameters as a raw nested dict + Note that values are not parsed and returned as is (i.e. string) + """ + return naive_nested_from_flat_dictionary(self.get_parameters()) + + def set_parameters_as_dict(self, dictionary): + """ + Set task parameters from a (possibly nested) dict. + While parameters are set just as they would be in connect(dict), this does not link the dict to the task, + but rather does a one-time update. + """ + self._arguments.copy_from_dict(flatten_dictionary(dictionary)) + @classmethod def set_credentials(cls, api_host=None, web_host=None, files_host=None, key=None, secret=None, host=None): """ diff --git a/trains/utilities/proxy_object.py b/trains/utilities/proxy_object.py index d1cadd01..11f069d8 100644 --- a/trains/utilities/proxy_object.py +++ b/trains/utilities/proxy_object.py @@ -1,3 +1,5 @@ +import itertools + import six @@ -110,3 +112,28 @@ def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''): # leave it as is, we have nothing to do with it. a_dict[k] = flat_dict.get(prefix+k, v) return a_dict + + +def naive_nested_from_flat_dictionary(flat_dict, sep='/'): + """ A naive conversion of a flat dictionary with '/'-separated keys signifying nesting + into a nested dictionary. + """ + return { + sub_prefix: ( + bucket[0][1] if (len(bucket) == 1 and sub_prefix == bucket[0][0]) + else naive_nested_from_flat_dictionary( + { + k[len(sub_prefix)+1:]: v + for k, v in bucket + if len(k) > len(sub_prefix) + } + ) + ) + for sub_prefix, bucket in ( + (key, list(group)) + for key, group in itertools.groupby( + sorted(flat_dict.items()), + key=lambda item: item[0].partition(sep)[0] + ) + ) + }