Add Task.connect_configuration, Task.connect_label_enumeration, and support for nested dictionaries with connect

Add Task.get_tasks
allegroai 2019-11-15 22:00:10 +02:00
parent 15683b5b43
commit 0a45d2094f
4 changed files with 251 additions and 36 deletions

View File

@ -12,22 +12,32 @@ task = Task.init(project_name='examples', task_name='Manual model configuration'
# create a model
model = torch.nn.Module
# store dictionary of definition for a specific network design
# Connect a local configuration file
config_file = 'samples/sample.json'
config_file = task.connect_configuration(config_file)
# Then read the configuration as usual; the backend will contain a copy of it.
# Later, when executing remotely, the returned `config_file` will be a temporary file
# containing a new copy of the configuration retrieved from the backend
# model_config_dict = json.load(open(config_file, 'rt'))
# Or store a dictionary defining a specific network design
model_config_dict = {
'value': 13.37,
'dict': {'sub_value': 'string'},
'dict': {'sub_value': 'string', 'sub_integer': 11},
'list_of_ints': [1, 2, 3, 4],
}
task.set_model_config(config_dict=model_config_dict)
model_config_dict = task.connect_configuration(model_config_dict)
# or read from a config file (this will override the previous configuration dictionary)
# task.set_model_config(config_text='this is just a blob\nof text from a configuration file')
# We now update the dictionary after connecting it, and the changes will be tracked as well.
model_config_dict['new value'] = 10
model_config_dict['value'] *= model_config_dict['new value']
# store the label enumeration the model is training for
task.set_model_label_enumeration({'background': 0, 'cat': 1, 'dog': 2})
print('Any model stored from this point onwards, will contain both model_config and label_enumeration')
# store the label enumeration of the training model
labels = {'background': 0, 'cat': 1, 'dog': 2}
task.connect_label_enumeration(labels)
# when storing the model, it will include the task's network configuration and label enumeration
print('Any model stored from this point onwards, will contain both model_config and label_enumeration')
torch.save(model, os.path.join(gettempdir(), "model"))
print('Model saved')
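For reference, a minimal sketch of the state the connected dictionary should reach after the tracked updates above (assuming local execution; the values simply follow from the arithmetic in the example, they are not produced by this commit):
# expected content of model_config_dict after the updates above (local run):
# {
#     'value': 133.7,  # 13.37 * 10
#     'dict': {'sub_value': 'string', 'sub_integer': 11},
#     'list_of_ints': [1, 2, 3, 4],
#     'new value': 10,
# }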

View File

@ -0,0 +1,8 @@
{
"list_of_ints": [1,2,3,4],
"dict": {
"sub_value": "string",
"sub_integer": 11
},
"value": 13.37
}

View File

@ -5,6 +5,9 @@ import sys
import threading
import time
from argparse import ArgumentParser
from tempfile import mkstemp
from pathlib2 import Path
from collections import OrderedDict, Callable
from typing import Optional
@ -40,6 +43,8 @@ from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \
nested_from_flat_dictionary
class Task(_Task):
@ -354,13 +359,27 @@ class Task(_Task):
"""
Returns Task object based on either task_id (system uuid) or task name
:param task_id: unique task id string (if provided, other parameters are ignored)
:param project_name: project name (str) the task belongs to
:param task_name: task name (str) within the selected project
:return: Task() object
:param str task_id: unique task id string (if provided, other parameters are ignored)
:param str project_name: project name (str) the task belongs to
:param str task_name: task name (str) within the selected project
:return: Task object
"""
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
@classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None):
"""
Returns a list of Task objects, matching the requested task name (or partially matching)
:param list(str) task_ids: list of unique task id strings (if provided, other parameters are ignored)
:param str project_name: project name (str) the task belongs to (use None for all projects)
:param str task_name: task name (str) within the selected project
Returns any partial match of task_name; regular expression matching is also supported
If None is passed, returns all tasks within the project
:return: list of Task objects
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name)
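A brief usage sketch of the new Task.get_tasks (illustrative only; the project name and the partial task name below are assumptions, not taken from this commit):
from trains import Task

matching = Task.get_tasks(project_name='examples', task_name='model configuration')
for t in matching:
    # task_name is matched partially / as a regular expression
    print(t.id, t.name)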
@property
def output_uri(self):
return self.storage_uri
@ -385,9 +404,12 @@ class Task(_Task):
"""
if not Session.check_min_api_version('2.3'):
return ReadOnlyDict()
if not self.data.execution or not self.data.execution.artifacts:
return ReadOnlyDict()
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
artifacts_pairs = []
if self.data.execution and self.data.execution.artifacts:
artifacts_pairs = [(a.key, Artifact(a)) for a in self.data.execution.artifacts]
if self._artifacts_manager:
artifacts_pairs += list(self._artifacts_manager.registered_artifacts.items())
return ReadOnlyDict(artifacts_pairs)
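An illustrative sketch of what the change above enables, assuming the pre-existing Task.register_artifact API (a name plus a pandas DataFrame): artifacts registered at runtime now appear in task.artifacts alongside artifacts already reported to the backend:
import pandas as pd
from trains import Task

task = Task.init(project_name='examples', task_name='artifacts sketch')
df = pd.DataFrame({'epoch': [1, 2], 'loss': [0.5, 0.3]})
task.register_artifact('training_stats', df)
# the registered (not yet uploaded) artifact is listed together with backend artifacts
print(list(task.artifacts.keys()))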
@classmethod
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
@ -469,19 +491,6 @@ class Task(_Task):
resp = res.response
return resp
def set_comment(self, comment):
"""
Set a comment text to the task.
In remote, this is a no-op.
:param comment: The comment of the task
:type comment: str
"""
if not running_remotely() or not self.is_main_task():
self._edit(comment=comment)
self.reload()
def add_tags(self, tags):
"""
Add tags to this task. Old tags are not deleted
@ -526,6 +535,93 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def connect_configuration(self, configuration):
"""
Connect a configuration dict / file (pathlib.Path / str) with the Task
connect_configuration should be called before reading the configuration file.
When an output model is created, it will include the content of the configuration dict/file
Example local file:
config_file = task.connect_configuration(config_file)
my_params = json.load(open(config_file,'rt'))
Example parameter dictionary:
my_params = task.connect_configuration(my_params)
:param (dict, pathlib.Path/str) configuration: usually the configuration file used in the model training process
configuration can be either a dict or a path to a local file.
If a dict is provided, it will be stored in a JSON-like format (HOCON), editable in the UI
If a pathlib2.Path / string is provided, the content of the file will be stored
Notice: the local path must be a relative path
(and in remote execution, the content of the file will be overwritten with the content retrieved from the UI)
:return: configuration object
If a dict was provided, a dictionary will be returned
If a pathlib2.Path / string was provided, a path to a local configuration file is returned
"""
if not isinstance(configuration, (dict, Path, six.string_types)):
raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
"{} is not supported".format(type(configuration)))
# parameter dictionary
if isinstance(configuration, dict):
def _update_config_dict(task, config_dict):
task.set_model_config(config_dict=config_dict)
if not running_remotely():
self.set_model_config(config_dict=configuration)
configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
else:
configuration.clear()
configuration.update(self.get_model_config_dict())
configuration = ProxyDictPreWrite(False, False, **configuration)
return configuration
# it is a path to a local file
if not running_remotely():
# check if not absolute path
configuration_path = Path(configuration)
if not configuration_path.is_file():
raise ValueError("Configuration file does not exist")
try:
with open(configuration_path.as_posix(), 'rt') as f:
configuration_text = f.read()
except Exception:
raise ValueError("Could not connect configuration file {}, file could not be read".format(
configuration_path.as_posix()))
self.set_model_config(config_text=configuration_text)
return configuration
else:
configuration_text = self.get_model_config_text()
configuration_path = Path(configuration)
fd, local_filename = mkstemp(prefix='trains_task_config_',
suffix=configuration_path.suffixes[-1] if
configuration_path.suffixes else '.txt')
os.write(fd, configuration_text.encode('utf-8'))
os.close(fd)
return Path(local_filename) if isinstance(configuration, Path) else local_filename
def connect_label_enumeration(self, enumeration):
"""
Connect a label enumeration dictionary with the Task
When an output model is created, it will store the model label enumeration dictionary
:param dict enumeration: dictionary of string to integer, mapping the model output labels to integer values
example: {'background': 0, 'person': 1}
:return: enumeration dict
"""
if not isinstance(enumeration, dict):
raise ValueError("connect_label_enumeration supports only `dict` type, "
"{} is not supported".format(type(enumeration)))
if not running_remotely():
self.set_model_label_enumeration(enumeration)
else:
# discard the local values and load the enumeration stored on the task
enumeration.clear()
enumeration.update(self.get_labels_enumeration())
return enumeration
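A short usage sketch (not part of the diff; the labels are placeholders). Locally the dictionary is stored on the task; in remote execution it is cleared and refilled from the backend, so keep using the returned dict:
labels = {'background': 0, 'person': 1}
labels = task.connect_label_enumeration(labels)
# from this point on, any stored output model carries this label enumeration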
def get_logger(self):
# type: () -> Logger
"""
@ -1029,12 +1125,26 @@ class Task(_Task):
self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args)
def _connect_dictionary(self, dictionary):
def _update_args_dict(task, config_dict):
task._arguments.copy_from_dict(flatten_dictionary(config_dict))
def _refresh_args_dict(task, config_dict):
# re-read from the task, including newly added keys
flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
nested_dict = config_dict._to_dict()
config_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, flat_dict))
self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)
if running_remotely():
dictionary = self._arguments.copy_to_dict(dictionary)
if not running_remotely():
self._arguments.copy_from_dict(flatten_dictionary(dictionary))
dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
else:
dictionary = self._arguments.copy_from_dict(dictionary)
flat_dict = flatten_dictionary(dictionary)
flat_dict = self._arguments.copy_to_dict(flat_dict)
dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)
return dictionary
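To illustrate the nested-dictionary support added here (a sketch with assumed parameter names): task.connect() now flattens nested keys with a '/' separator for the backend and rebuilds the nested layout on the way back:
params = {
    'optimizer': {'name': 'sgd', 'lr': 0.01},
    'batch_size': 64,
}
params = task.connect(params)
# stored (and editable in the UI) as flat keys such as
# 'optimizer/name', 'optimizer/lr' and 'batch_size';
# the returned object keeps the original nested structure and tracks updates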
@ -1408,7 +1518,7 @@ class Task(_Task):
cls._get_default_session(),
tasks.GetAllRequest(
project=[project.id],
name=exact_match_regex(task_name),
name=exact_match_regex(task_name) if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags]
)
)
@ -1430,6 +1540,37 @@ class Task(_Task):
log_to_backend=False,
)
@classmethod
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None):
if task_ids:
if isinstance(task_ids, six.string_types):
task_ids = [task_ids]
return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids]
if project_name:
res = cls._send(
cls._get_default_session(),
projects.GetAllRequest(
name=exact_match_regex(project_name)
)
)
project = get_single_result(entity='project', query=project_name, results=res.response.projects)
else:
project = None
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
res = cls._send(
cls._get_default_session(),
tasks.GetAllRequest(
project=[project.id] if project else None,
name=task_name if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags]
)
)
res_tasks = res.response.tasks
return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks]
@classmethod
def __get_hash_key(cls, *args):
def normalize(x):

View File

@ -1,3 +1,4 @@
import six
class ProxyDictPostWrite(dict):
@ -5,11 +6,11 @@ class ProxyDictPostWrite(dict):
def __init__(self, update_obj, update_func, *args, **kwargs):
super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
self._update_obj = update_obj
self._update_func = None
for k, i in self.items():
if isinstance(i, dict):
self.update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)})
self._update_obj = update_obj
super(ProxyDictPostWrite, self).update({k: ProxyDictPostWrite(update_obj, self._set_callback, **i)})
self._update_func = update_func
def __setitem__(self, key, value):
@ -20,6 +21,20 @@ class ProxyDictPostWrite(dict):
if self._update_func:
self._update_func(self._update_obj, self)
def _to_dict(self):
a_dict = {}
for k, i in self.items():
if isinstance(i, ProxyDictPostWrite):
a_dict[k] = i._to_dict()
else:
a_dict[k] = i
return a_dict
def update(self, E=None, **F):
return super(ProxyDictPostWrite, self).update(
ProxyDictPostWrite(self._update_obj, self._set_callback, **E) if E is not None else
ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
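A minimal sketch of how the post-write proxy propagates changes (the dummy callback stands in for the Task update functions used in task.py; the import path is assumed from task.py's import above):
from trains.utilities.proxy_object import ProxyDictPostWrite

def on_change(obj, proxy):
    # called after every write, including writes into nested sub-dicts
    print('updated:', proxy._to_dict())

tracked = ProxyDictPostWrite(None, on_change, value=1, nested={'a': 2})
tracked['nested']['a'] = 3  # nested writes bubble up via _set_callback
tracked['new_key'] = 'x'    # top-level writes trigger the callback as well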
class ProxyDictPreWrite(dict):
""" Dictionary wrapper that prevents modifications to the dictionary """
@ -39,8 +54,12 @@ class ProxyDictPreWrite(dict):
super(ProxyDictPreWrite, self).__setitem__(*key_value)
def _set_callback(self, key_value, *_):
if self._update_func:
res = self._update_func(self._update_obj, key_value)
if self._update_func is not None:
if callable(self._update_func):
res = self._update_func(self._update_obj, key_value)
else:
res = self._update_func
if not res:
return None
return res
@ -48,3 +67,40 @@ class ProxyDictPreWrite(dict):
def _nested_callback(self, prefix, key_value):
return self._set_callback((prefix+'.'+key_value[0], key_value[1],))
def flatten_dictionary(a_dict, prefix=''):
flat_dict = {}
sep = '/'
basic_types = (float, int, bool, six.string_types, )
for k, v in a_dict.items():
k = str(k)
if isinstance(v, (float, int, bool, six.string_types)):
flat_dict[prefix+k] = v
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
flat_dict[prefix+k] = v
elif isinstance(v, dict):
flat_dict.update(flatten_dictionary(v, prefix=prefix+k+sep))
else:
# this is a mixture of list and dict, or any other object,
# leave it as is, we have nothing to do with it.
flat_dict[prefix+k] = v
return flat_dict
def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
basic_types = (float, int, bool, six.string_types, )
sep = '/'
for k, v in a_dict.items():
k = str(k)
if isinstance(v, (float, int, bool, six.string_types)):
a_dict[k] = flat_dict.get(prefix+k, v)
elif isinstance(v, (list, tuple)) and all([isinstance(i, basic_types) for i in v]):
a_dict[k] = flat_dict.get(prefix+k, v)
elif isinstance(v, dict):
a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix+k+sep) or v
else:
# this is a mixture of list and dict, or any other object,
# leave it as is, we have nothing to do with it.
a_dict[k] = flat_dict.get(prefix+k, v)
return a_dict
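Finally, a small round-trip sketch of the two helpers above (the import path is assumed from task.py's import):
from trains.utilities.proxy_object import flatten_dictionary, nested_from_flat_dictionary

params = {'lr': 0.1, 'model': {'layers': 4, 'activation': 'relu'}}
flat = flatten_dictionary(params)
# flat == {'lr': 0.1, 'model/layers': 4, 'model/activation': 'relu'}
flat['model/layers'] = 8  # e.g. a value edited in the UI and sent back
restored = nested_from_flat_dictionary(params, flat)
# restored == {'lr': 0.1, 'model': {'layers': 8, 'activation': 'relu'}}
# note: nested_from_flat_dictionary updates `params` in place and returns it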