From cc894dc1d64b474a97ad99de1968b2780cdb5b8a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 27 Nov 2020 23:24:45 +0200 Subject: [PATCH] Add Task get_configuration_object/set_configuration_object for easier automation --- trains/backend_interface/task/task.py | 47 +++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index ca517214..4d13fa0e 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -43,13 +43,11 @@ from ..util import ( make_message, get_or_create_project, get_single_result, exact_match_regex, mutually_exclusive, ) from ...config import ( - get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, + get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, running_remotely, get_cache_dir, DOCKER_IMAGE_ENV_VAR, get_offline_dir) from ...debugging import get_logger -from ...debugging.log import LoggerRoot from ...storage.helper import StorageHelper, StorageError from .access import AccessMixin -from .log import TaskHandler from .repo import ScriptInfo, pip_freeze from .hyperparams import HyperParams from ...config import config, PROC_MASTER_ID_ENV_VAR, SUPPRESS_UPDATE_MESSAGE_ENV_VAR @@ -747,9 +745,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): """ Get the parameters for a Task. This method returns a complete group of key-value parameter pairs, but does not support parameter descriptions (the result is a dictionary of key-value pairs). + Notice the returned parameter dict is flat: + i.e. {'Args/param': 'value'} is the argument "param" from section "Args" + :param backwards_compatibility: If True (default) parameters without section name (API version < 2.9, trains-server < 0.16) will be at dict root level. If False, parameters without section name, will be nested under "Args/" key. + :return: dict of the task parameters, all flattened to key/value. Different sections with key prefix "section/" """ @@ -778,6 +780,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): """ Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not support parameter descriptions (the input is a dictionary of key-value pairs). + Notice the parameter dict is flat: + i.e. {'Args/param': 'value'} will set the argument "param" in section "Args" to "value" :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are merged into a single key-value pair dictionary. @@ -969,6 +973,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): """ Update the parameters for a Task. This method updates a complete group of key-value parameter pairs, but does not support parameter descriptions (the input is a dictionary of key-value pairs). + Notice the parameter dict is flat: + i.e. {'Args/param': 'value'} will set the argument "param" in section "Args" to "value" :param args: Positional arguments, which are one or more dictionary or (key, value) iterable. They are merged into a single key-value pair dictionary. @@ -1356,6 +1362,33 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): lines = [r.get('msg', '') for r in response.response_data['events']] return lines + def get_configuration_object(self, name): + # type: (str) -> Optional[str] + """ + Get the Task's configuration object section as a blob of text + Use only for automation (externally), otherwise use `Task.connect_configuration`. + + :param str name: Configuration section name + :return: The Task's configuration as a text blob (unconstrained text string) + return None if configuration name is not valid + """ + return self._get_configuration_text(name) + + def set_configuration_object(self, name, config_text=None, description=None, config_type=None): + # type: (str, Optional[str], Optional[str], Optional[str]) -> None + """ + Set the Task's configuration object as a blob of text. + Use only for automation (externally), otherwise use `Task.connect_configuration`. + + :param str name: Configuration section name + :param config_text: configuration as a blob of text (unconstrained text string) + usually the content of a configuration file of a sort + :param str description: Configuration section description + :param str config_type: Optional configuration format type + """ + return self._set_configuration( + name=name, description=description, config_type=config_type, config_text=config_text) + @classmethod def get_projects(cls): # type: () -> (List['projects.Project']) @@ -1574,13 +1607,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): :param str name: Configuration name. :param str description: Configuration section description. + :param str config_type: Optional configuration format type (str). :param config_text: model configuration (unconstrained text string). usually the content of a configuration file. If `config_text` is not None, `config_dict` must not be provided. :param config_dict: model configuration parameters dictionary. If `config_dict` is not None, `config_text` must not be provided. """ # make sure we have wither dict or text - mutually_exclusive(config_dict=config_dict, config_text=config_text) + mutually_exclusive(config_dict=config_dict, config_text=config_text, _check_none=True) if not Session.check_min_api_version('2.9'): raise ValueError("Multiple configurations is not supported with the current 'trains-server', " @@ -1588,12 +1622,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): if description: description = str(description) - config = config_dict_to_text(config_text or config_dict) + # support empty string + a_config = config_dict_to_text(config_dict if config_text is None else config_text) with self._edit_lock: self.reload() configuration = self.data.configuration or {} configuration[name] = tasks.ConfigurationItem( - name=name, value=config, description=description or None, type=config_type or None) + name=name, value=a_config, description=description or None, type=config_type or None) self._edit(configuration=configuration) def _get_configuration_text(self, name):