diff --git a/examples/manual_model_config.py b/examples/manual_model_config.py index 007dd130..6406e3ec 100644 --- a/examples/manual_model_config.py +++ b/examples/manual_model_config.py @@ -12,22 +12,32 @@ task = Task.init(project_name='examples', task_name='Manual model configuration' # create a model model = torch.nn.Module -# store dictionary of definition for a specific network design +# Connect a local configuration file +config_file = 'samples/sample.json' +config_file = task.connect_configuration(config_file) +# then read configuration as usual, the backend will contain a copy of it. +# later when executing remotely, the returned `config_file` will be a temporary file +# containing a new copy of the configuration retrieved form the backend +# # model_config_dict = json.load(open(config_file, 'rt')) + +# Or Store dictionary of definition for a specific network design model_config_dict = { 'value': 13.37, - 'dict': {'sub_value': 'string'}, + 'dict': {'sub_value': 'string', 'sub_integer': 11}, 'list_of_ints': [1, 2, 3, 4], } -task.set_model_config(config_dict=model_config_dict) +model_config_dict = task.connect_configuration(model_config_dict) -# or read form a config file (this will override the previous configuration dictionary) -# task.set_model_config(config_text='this is just a blob\nof text from a configuration file') +# We now update the dictionary after connecting it, and the changes will be tracked as well. +model_config_dict['new value'] = 10 +model_config_dict['value'] *= model_config_dict['new value'] -# store the label enumeration the model is training for -task.set_model_label_enumeration({'background': 0, 'cat': 1, 'dog': 2}) -print('Any model stored from this point onwards, will contain both model_config and label_enumeration') +# store the label enumeration of the training model +labels = {'background': 0, 'cat': 1, 'dog': 2} +task.connect_label_enumeration(labels) # storing the model, it will have the task network configuration and label enumeration +print('Any model stored from this point onwards, will contain both model_config and label_enumeration') torch.save(model, os.path.join(gettempdir(), "model")) print('Model saved') diff --git a/examples/samples/sample.json b/examples/samples/sample.json new file mode 100644 index 00000000..21896433 --- /dev/null +++ b/examples/samples/sample.json @@ -0,0 +1,8 @@ +{ + "list_of_ints": [1,2,3,4], + "dict": { + "sub_value": "string", + "sub_integer": 11 + }, + "value": 13.37 +} diff --git a/trains/task.py b/trains/task.py index 2c4ca443..6c8bd569 100644 --- a/trains/task.py +++ b/trains/task.py @@ -5,6 +5,9 @@ import sys import threading import time from argparse import ArgumentParser +from tempfile import mkstemp + +from pathlib2 import Path from collections import OrderedDict, Callable from typing import Optional @@ -40,6 +43,8 @@ from .binding.matplotlib_bind import PatchedMatplotlib from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic from .utilities.dicts import ReadOnlyDict +from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \ + nested_from_flat_dictionary class Task(_Task): @@ -354,13 +359,27 @@ class Task(_Task): """ Returns Task object based on either, task_id (system uuid) or task name - :param task_id: unique task id string (if exists other parameters are ignored) - :param project_name: project name (str) the task belongs to - :param task_name: task name (str) in within the selected project - :return: Task() object + :param str task_id: unique task id string (if exists other parameters are ignored) + :param str project_name: project name (str) the task belongs to + :param str task_name: task name (str) in within the selected project + :return: Task object """ return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name) + @classmethod + def get_tasks(cls, task_ids=None, project_name=None, task_name=None): + """ + Returns a list of Task objects, matching requested task name (or partially matching) + + :param list(str) task_ids: list of unique task id string (if exists other parameters are ignored) + :param str project_name: project name (str) the task belongs to (use None for all projects) + :param str task_name: task name (str) in within the selected project + Return any partial match of task_name, regular expressions matching is also supported + If None is passed, returns all tasks within the project + :return: list of Task object + """ + return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name) + @property def output_uri(self): return self.storage_uri @@ -385,9 +404,12 @@ class Task(_Task): """ if not Session.check_min_api_version('2.3'): return ReadOnlyDict() - if not self.data.execution or not self.data.execution.artifacts: - return ReadOnlyDict() - return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts]) + artifacts_pairs = [] + if self.data.execution and self.data.execution.artifacts: + artifacts_pairs = [(a.key, Artifact(a)) for a in self.data.execution.artifacts] + if self._artifacts_manager: + artifacts_pairs += list(self._artifacts_manager.registered_artifacts.items()) + return ReadOnlyDict(artifacts_pairs) @classmethod def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None): @@ -469,19 +491,6 @@ class Task(_Task): resp = res.response return resp - def set_comment(self, comment): - """ - Set a comment text to the task. - - In remote, this is a no-op. - - :param comment: The comment of the task - :type comment: str - """ - if not running_remotely() or not self.is_main_task(): - self._edit(comment=comment) - self.reload() - def add_tags(self, tags): """ Add tags to this task. Old tags are not deleted @@ -526,6 +535,93 @@ class Task(_Task): raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__) + def connect_configuration(self, configuration): + """ + Connect a configuration dict / file (pathlib.Path / str) with the Task + Connecting configuration file should be called before reading the configuration file. + When an output model will be created it will include the content of the configuration dict/file + + Example local file: + config_file = task.connect_configuration(config_file) + my_params = json.load(open(config_file,'rt')) + + Example parameter dictionary: + my_params = task.connect_configuration(my_params) + + :param (dict, pathlib.Path/str) configuration: usually configuration file used in the model training process + configuration can be either dict or path to local file. + If dict is provided, it will be stored in json alike format (hocon) editable in the UI + If pathlib2.Path / string is provided the content of the file will be stored + Notice: local path must be relative path + (and in remote execution, the content of the file will be overwritten with the content brought from the UI) + :return: configuration object + If dict was provided, a dictionary will be returned + If pathlib2.Path / string was provided, a path to a local configuration file is returned + """ + if not isinstance(configuration, (dict, Path, six.string_types)): + raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, " + "{} is not supported".format(type(configuration))) + + # parameter dictionary + if isinstance(configuration, dict): + def _update_config_dict(task, config_dict): + task.set_model_config(config_dict=config_dict) + + if not running_remotely(): + self.set_model_config(config_dict=configuration) + configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration) + else: + configuration.clear() + configuration.update(self.get_model_config_dict()) + configuration = ProxyDictPreWrite(False, False, **configuration) + return configuration + + # it is a path to a local file + if not running_remotely(): + # check if not absolute path + configuration_path = Path(configuration) + if not configuration_path.is_file(): + ValueError("Configuration file does not exist") + try: + with open(configuration_path.as_posix(), 'rt') as f: + configuration_text = f.read() + except Exception: + raise ValueError("Could not connect configuration file {}, file could not be read".format( + configuration_path.as_posix())) + self.set_model_config(config_text=configuration_text) + return configuration + else: + configuration_text = self.get_model_config_text() + configuration_path = Path(configuration) + fd, local_filename = mkstemp(prefix='trains_task_config_', + suffix=configuration_path.suffixes[-1] if + configuration_path.suffixes else '.txt') + os.write(fd, configuration_text.encode('utf-8')) + os.close(fd) + return Path(local_filename) if isinstance(configuration, Path) else local_filename + + def connect_label_enumeration(self, enumeration): + """ + Connect a label enumeration dictionary with the Task + + When an output model is created it will store the model label enumeration dictionary + + :param dict enumeration: dictionary of string to integer, enumerating the model output integer to labels + example: {'background': 0 , 'person': 1} + :return: enumeration dict + """ + if not isinstance(enumeration, dict): + raise ValueError("connect_label_enumeration supports only `dict` type, " + "{} is not supported".format(type(enumeration))) + + if not running_remotely(): + self.set_model_label_enumeration(enumeration) + else: + # pop everything + enumeration.clear() + enumeration.update(self.get_labels_enumeration()) + return enumeration + def get_logger(self): # type: () -> Logger """ @@ -1029,12 +1125,26 @@ class Task(_Task): self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args) def _connect_dictionary(self, dictionary): + def _update_args_dict(task, config_dict): + task._arguments.copy_from_dict(flatten_dictionary(config_dict)) + + def _refresh_args_dict(task, config_dict): + # reread from task including newly added keys + flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict)) + nested_dict = config_dict._to_dict() + config_dict.clear() + config_dict.update(nested_from_flat_dictionary(nested_dict, flat_dict)) + self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary) - if running_remotely(): - dictionary = self._arguments.copy_to_dict(dictionary) + if not running_remotely(): + self._arguments.copy_from_dict(flatten_dictionary(dictionary)) + dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary) else: - dictionary = self._arguments.copy_from_dict(dictionary) + flat_dict = flatten_dictionary(dictionary) + flat_dict = self._arguments.copy_to_dict(flat_dict) + dictionary = nested_from_flat_dictionary(dictionary, flat_dict) + dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary) return dictionary @@ -1408,7 +1518,7 @@ class Task(_Task): cls._get_default_session(), tasks.GetAllRequest( project=[project.id], - name=exact_match_regex(task_name), + name=exact_match_regex(task_name) if task_name else None, only_fields=['id', 'name', 'last_update', system_tags] ) ) @@ -1430,6 +1540,37 @@ class Task(_Task): log_to_backend=False, ) + @classmethod + def __get_tasks(cls, task_ids=None, project_name=None, task_name=None): + if task_ids: + if isinstance(task_ids, six.string_types): + task_ids = [task_ids] + return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids] + + if project_name: + res = cls._send( + cls._get_default_session(), + projects.GetAllRequest( + name=exact_match_regex(project_name) + ) + ) + project = get_single_result(entity='project', query=project_name, results=res.response.projects) + else: + project = None + + system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' + res = cls._send( + cls._get_default_session(), + tasks.GetAllRequest( + project=[project.id] if project else None, + name=task_name if task_name else None, + only_fields=['id', 'name', 'last_update', system_tags] + ) + ) + res_tasks = res.response.tasks + + return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks] + @classmethod def __get_hash_key(cls, *args): def normalize(x): diff --git a/trains/utilities/proxy_object.py b/trains/utilities/proxy_object.py index 2f910f3d..928c7809 100644 --- a/trains/utilities/proxy_object.py +++ b/trains/utilities/proxy_object.py @@ -1,3 +1,4 @@ +import six class ProxyDictPostWrite(dict): @@ -5,11 +6,11 @@ class ProxyDictPostWrite(dict): def __init__(self, update_obj, update_func, *args, **kwargs): super(ProxyDictPostWrite, self).__init__(*args, **kwargs) + self._update_obj = update_obj self._update_func = None for k, i in self.items(): if isinstance(i, dict): - self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)}) - self._update_obj = update_obj + super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)}) self._update_func = update_func def __setitem__(self, key, value): @@ -20,6 +21,20 @@ class ProxyDictPostWrite(dict): if self._update_func: self._update_func(self._update_obj, self) + def _to_dict(self): + a_dict = {} + for k, i in self.items(): + if isinstance(i, ProxyDictPostWrite): + a_dict[k] = i._to_dict() + else: + a_dict[k] = i + return a_dict + + def update(self, E=None, **F): + return super(ProxyDictPostWrite, self).update( + ProxyDictPostWrite(self._update_obj, self._set_callback, **E) if E is not None else + ProxyDictPostWrite(self._update_obj, self._set_callback, **F)) + class ProxyDictPreWrite(dict): """ Dictionary wrapper that prevents modifications to the dictionary """ @@ -39,8 +54,12 @@ class ProxyDictPreWrite(dict): super(ProxyDictPreWrite, self).__setitem__(*key_value) def _set_callback(self, key_value, *_): - if self._update_func: - res = self._update_func(self._update_obj, key_value) + if self._update_func is not None: + if callable(self._update_func): + res = self._update_func(self._update_obj, key_value) + else: + res = self._update_func + if not res: return None return res @@ -48,3 +67,40 @@ class ProxyDictPreWrite(dict): def _nested_callback(self, prefix, key_value): return self._set_callback((prefix+'.'+key_value[0], key_value[1],)) + + +def flatten_dictionary(a_dict, prefix=''): + flat_dict = {} + sep = '/' + basic_types = (float, int, bool, six.string_types, ) + for k, v in a_dict.items(): + k = str(k) + if isinstance(v, (float, int, bool, six.string_types)): + flat_dict[prefix+k] = v + elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]): + flat_dict[prefix+k] = v + elif isinstance(v, dict): + flat_dict.update(flatten_dictionary(v, prefix=prefix+k+sep)) + else: + # this is a mixture of list and dict, or any other object, + # leave it as is, we have nothing to do with it. + flat_dict[prefix+k] = v + return flat_dict + + +def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''): + basic_types = (float, int, bool, six.string_types, ) + sep = '/' + for k, v in a_dict.items(): + k = str(k) + if isinstance(v, (float, int, bool, six.string_types)): + a_dict[k] = flat_dict.get(prefix+k, v) + elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]): + a_dict[k] = flat_dict.get(prefix+k, v) + elif isinstance(v, dict): + a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix+k+sep) or v + else: + # this is a mixture of list and dict, or any other object, + # leave it as is, we have nothing to do with it. + a_dict[k] = flat_dict.get(prefix+k, v) + return a_dict