mirror of
https://github.com/clearml/clearml
synced 2025-04-07 22:24:30 +00:00
Add Task connect_configuration, connect_label_enumeration, and support for nested dictionary with connect
Add Task.get_tasks
This commit is contained in:
parent
15683b5b43
commit
0a45d2094f
@ -12,22 +12,32 @@ task = Task.init(project_name='examples', task_name='Manual model configuration'
|
||||
# create a model
model = torch.nn.Module

# Connect a local configuration file
config_file = 'samples/sample.json'
config_file = task.connect_configuration(config_file)
# then read configuration as usual, the backend will contain a copy of it.
# later when executing remotely, the returned `config_file` will be a temporary file
# containing a new copy of the configuration retrieved from the backend
# # model_config_dict = json.load(open(config_file, 'rt'))

# Or store a dictionary of definitions for a specific network design
model_config_dict = {
    'value': 13.37,
    'dict': {'sub_value': 'string', 'sub_integer': 11},
    'list_of_ints': [1, 2, 3, 4],
}
model_config_dict = task.connect_configuration(model_config_dict)

# We now update the dictionary after connecting it, and the changes will be tracked as well.
model_config_dict['new value'] = 10
model_config_dict['value'] *= model_config_dict['new value']

# store the label enumeration of the training model
labels = {'background': 0, 'cat': 1, 'dog': 2}
task.connect_label_enumeration(labels)

# storing the model, it will have the task network configuration and label enumeration
print('Any model stored from this point onwards, will contain both model_config and label_enumeration')

torch.save(model, os.path.join(gettempdir(), "model"))
print('Model saved')
|
8
examples/samples/sample.json
Normal file
8
examples/samples/sample.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
  "list_of_ints": [1,2,3,4],
  "dict": {
    "sub_value": "string",
    "sub_integer": 11
  },
  "value": 13.37
}
|
189
trains/task.py
189
trains/task.py
@ -5,6 +5,9 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from tempfile import mkstemp
|
||||
|
||||
from pathlib2 import Path
|
||||
from collections import OrderedDict, Callable
|
||||
from typing import Optional
|
||||
|
||||
@ -40,6 +43,8 @@ from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .utilities.resource_monitor import ResourceMonitor
|
||||
from .utilities.seed import make_deterministic
|
||||
from .utilities.dicts import ReadOnlyDict
|
||||
from .utilities.proxy_object import ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, \
|
||||
nested_from_flat_dictionary
|
||||
|
||||
|
||||
class Task(_Task):
|
||||
@ -354,13 +359,27 @@ class Task(_Task):
|
||||
"""
|
||||
Returns Task object based on either, task_id (system uuid) or task name
|
||||
|
||||
:param task_id: unique task id string (if exists other parameters are ignored)
|
||||
:param project_name: project name (str) the task belongs to
|
||||
:param task_name: task name (str) in within the selected project
|
||||
:return: Task() object
|
||||
:param str task_id: unique task id string (if exists other parameters are ignored)
|
||||
:param str project_name: project name (str) the task belongs to
|
||||
:param str task_name: task name (str) within the selected project
|
||||
:return: Task object
|
||||
"""
|
||||
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
|
||||
|
||||
@classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None):
    """
    Return a list of Task objects matching the requested task name (partial matches included).

    :param list(str) task_ids: list of unique task id strings (when given, other parameters are ignored)
    :param str project_name: name of the project the tasks belong to (None searches all projects)
    :param str task_name: task name within the selected project.
        Any partial match of task_name is returned; regular-expression matching is also supported.
        When None, all tasks within the project are returned.
    :return: list of Task objects
    """
    # delegate the actual lookup to the private backend query helper
    return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name)
|
||||
|
||||
@property
def output_uri(self):
    """Default output destination (storage URI) used for uploaded models/artifacts."""
    return self.storage_uri
|
||||
@ -385,9 +404,12 @@ class Task(_Task):
|
||||
"""
|
||||
if not Session.check_min_api_version('2.3'):
|
||||
return ReadOnlyDict()
|
||||
if not self.data.execution or not self.data.execution.artifacts:
|
||||
return ReadOnlyDict()
|
||||
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
|
||||
artifacts_pairs = []
|
||||
if self.data.execution and self.data.execution.artifacts:
|
||||
artifacts_pairs = [(a.key, Artifact(a)) for a in self.data.execution.artifacts]
|
||||
if self._artifacts_manager:
|
||||
artifacts_pairs += list(self._artifacts_manager.registered_artifacts.items())
|
||||
return ReadOnlyDict(artifacts_pairs)
|
||||
|
||||
@classmethod
|
||||
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
|
||||
@ -469,19 +491,6 @@ class Task(_Task):
|
||||
resp = res.response
|
||||
return resp
|
||||
|
||||
def set_comment(self, comment):
    """
    Set a comment text to the task.

    In remote, this is a no-op.

    :param comment: The comment of the task
    :type comment: str
    """
    # guard clause: remotely executed main task keeps the backend-managed comment
    if running_remotely() and self.is_main_task():
        return
    self._edit(comment=comment)
    self.reload()
|
||||
|
||||
def add_tags(self, tags):
|
||||
"""
|
||||
Add tags to this task. Old tags are not deleted
|
||||
@ -526,6 +535,93 @@ class Task(_Task):
|
||||
|
||||
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
||||
|
||||
def connect_configuration(self, configuration):
    """
    Connect a configuration dict / file (pathlib.Path / str) with the Task.
    Connecting a configuration file should be called before reading the configuration file.
    When an output model is created it will include the content of the configuration dict/file.

    Example local file:
        config_file = task.connect_configuration(config_file)
        my_params = json.load(open(config_file,'rt'))

    Example parameter dictionary:
        my_params = task.connect_configuration(my_params)

    :param (dict, pathlib.Path/str) configuration: usually configuration file used in the model training process
        configuration can be either dict or path to local file.
        If dict is provided, it will be stored in json alike format (hocon) editable in the UI
        If pathlib2.Path / string is provided the content of the file will be stored
        Notice: local path must be relative path
        (and in remote execution, the content of the file will be overwritten with the content brought from the UI)
    :return: configuration object
        If dict was provided, a dictionary will be returned
        If pathlib2.Path / string was provided, a path to a local configuration file is returned
    :raises ValueError: if the argument type is unsupported, or the local file is missing/unreadable
    """
    if not isinstance(configuration, (dict, Path, six.string_types)):
        raise ValueError("connect_configuration supports `dict`, `str` and 'Path' types, "
                         "{} is not supported".format(type(configuration)))

    # parameter dictionary
    if isinstance(configuration, dict):
        def _update_config_dict(task, config_dict):
            # write-back hook: any change to the proxy dict is re-stored on the task
            task.set_model_config(config_dict=config_dict)

        if not running_remotely():
            # local run: our dict is the source of truth; track subsequent edits
            self.set_model_config(config_dict=configuration)
            configuration = ProxyDictPostWrite(self, _update_config_dict, **configuration)
        else:
            # remote run: replace local content with the backend copy, then freeze it
            configuration.clear()
            configuration.update(self.get_model_config_dict())
            configuration = ProxyDictPreWrite(False, False, **configuration)
        return configuration

    # it is a path to a local file
    if not running_remotely():
        # check if not absolute path
        configuration_path = Path(configuration)
        if not configuration_path.is_file():
            # BUG FIX: the exception was previously constructed but never raised
            raise ValueError("Configuration file does not exist")
        try:
            with open(configuration_path.as_posix(), 'rt') as f:
                configuration_text = f.read()
        except Exception:
            raise ValueError("Could not connect configuration file {}, file could not be read".format(
                configuration_path.as_posix()))
        self.set_model_config(config_text=configuration_text)
        return configuration
    else:
        # remote run: dump the backend copy of the configuration into a temporary file
        configuration_text = self.get_model_config_text()
        configuration_path = Path(configuration)
        fd, local_filename = mkstemp(prefix='trains_task_config_',
                                     suffix=configuration_path.suffixes[-1] if
                                     configuration_path.suffixes else '.txt')
        os.write(fd, configuration_text.encode('utf-8'))
        os.close(fd)
        # preserve the caller's type: Path in -> Path out, str in -> str out
        return Path(local_filename) if isinstance(configuration, Path) else local_filename
||||
|
||||
def connect_label_enumeration(self, enumeration):
    """
    Connect a label enumeration dictionary with the Task.

    When an output model is created it will store the model label enumeration dictionary.

    :param dict enumeration: dictionary of string to integer, enumerating the model output integer to labels
        example: {'background': 0 , 'person': 1}
    :return: enumeration dict
    :raises ValueError: if enumeration is not a dict
    """
    if not isinstance(enumeration, dict):
        raise ValueError("connect_label_enumeration supports only `dict` type, "
                         "{} is not supported".format(type(enumeration)))

    if running_remotely():
        # pop everything and take the enumeration stored on the backend
        enumeration.clear()
        enumeration.update(self.get_labels_enumeration())
    else:
        # local run: our enumeration is the source of truth
        self.set_model_label_enumeration(enumeration)
    return enumeration
|
||||
|
||||
def get_logger(self):
|
||||
# type: () -> Logger
|
||||
"""
|
||||
@ -1029,12 +1125,26 @@ class Task(_Task):
|
||||
self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args)
|
||||
|
||||
def _connect_dictionary(self, dictionary):
    """
    Connect a (possibly nested) parameter dictionary with the Task and return a tracked proxy.

    NOTE(review): this span interleaved pre-change and post-change diff lines
    (the stale `if running_remotely(): dictionary = self._arguments.copy_to_dict(dictionary)`
    and `dictionary = self._arguments.copy_from_dict(dictionary)` branches); this is the
    coherent post-change implementation.

    :param dict dictionary: nested parameter dictionary
    :return: ProxyDictPostWrite wrapping the (possibly backend-updated) dictionary
    """
    def _update_args_dict(task, config_dict):
        # write-back hook: push local edits to the backend as a flat parameter dict
        task._arguments.copy_from_dict(flatten_dictionary(config_dict))

    def _refresh_args_dict(task, config_dict):
        # reread from task including newly added keys
        flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict))
        nested_dict = config_dict._to_dict()
        config_dict.clear()
        config_dict.update(nested_from_flat_dictionary(nested_dict, flat_dict))

    self._try_set_connected_parameter_type(self._ConnectedParametersType.dictionary)

    if not running_remotely():
        # local run: our values are the source of truth; track future edits
        self._arguments.copy_from_dict(flatten_dictionary(dictionary))
        dictionary = ProxyDictPostWrite(self, _update_args_dict, **dictionary)
    else:
        # remote run: override local values with the parameters stored on the backend
        flat_dict = flatten_dictionary(dictionary)
        flat_dict = self._arguments.copy_to_dict(flat_dict)
        dictionary = nested_from_flat_dictionary(dictionary, flat_dict)
        dictionary = ProxyDictPostWrite(self, _refresh_args_dict, **dictionary)

    return dictionary
|
||||
|
||||
@ -1408,7 +1518,7 @@ class Task(_Task):
|
||||
cls._get_default_session(),
|
||||
tasks.GetAllRequest(
|
||||
project=[project.id],
|
||||
name=exact_match_regex(task_name),
|
||||
name=exact_match_regex(task_name) if task_name else None,
|
||||
only_fields=['id', 'name', 'last_update', system_tags]
|
||||
)
|
||||
)
|
||||
@ -1430,6 +1540,37 @@ class Task(_Task):
|
||||
log_to_backend=False,
|
||||
)
|
||||
|
||||
@classmethod
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None):
    """Resolve Task objects either directly from ids or via a project/name backend query."""
    # explicit ids take precedence: no server-side search needed
    if task_ids:
        if isinstance(task_ids, six.string_types):
            task_ids = [task_ids]
        return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids]

    project = None
    if project_name:
        res = cls._send(
            cls._get_default_session(),
            projects.GetAllRequest(name=exact_match_regex(project_name))
        )
        project = get_single_result(entity='project', query=project_name, results=res.response.projects)

    # older servers expose 'tags' instead of 'system_tags'
    system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
    res = cls._send(
        cls._get_default_session(),
        tasks.GetAllRequest(
            project=[project.id] if project else None,
            name=task_name if task_name else None,
            only_fields=['id', 'name', 'last_update', system_tags]
        )
    )

    return [cls(private=cls.__create_protection, task_id=t.id, log_to_backend=False)
            for t in res.response.tasks]
|
||||
|
||||
@classmethod
|
||||
def __get_hash_key(cls, *args):
|
||||
def normalize(x):
|
||||
|
@ -1,3 +1,4 @@
|
||||
import six
|
||||
|
||||
|
||||
class ProxyDictPostWrite(dict):
|
||||
@ -5,11 +6,11 @@ class ProxyDictPostWrite(dict):
|
||||
|
||||
def __init__(self, update_obj, update_func, *args, **kwargs):
    """
    Build the proxy dict; nested dict values are wrapped as child proxies that
    report their writes through this instance's callback.

    :param update_obj: object passed back to update_func on every write
    :param update_func: callable(update_obj, dict) invoked after any item set
    """
    super(ProxyDictPostWrite, self).__init__(*args, **kwargs)
    # bind the target object unconditionally, so _set_callback can never see it unset
    # (the mangled diff only assigned it inside the nested-dict branch)
    self._update_obj = update_obj
    # keep the callback disabled while wrapping nested dicts below
    self._update_func = None
    for k, i in self.items():
        if isinstance(i, dict):
            # bypass our own update() override to avoid firing callbacks during construction
            super(ProxyDictPostWrite, self).update(
                {k: ProxyDictPostWrite(update_obj, self._set_callback, **i)})
    self._update_func = update_func
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
@ -20,6 +21,20 @@ class ProxyDictPostWrite(dict):
|
||||
if self._update_func:
|
||||
self._update_func(self._update_obj, self)
|
||||
|
||||
def _to_dict(self):
    """Return a plain ``dict`` copy of this proxy, recursively unwrapping nested proxies."""
    return {
        key: (value._to_dict() if isinstance(value, ProxyDictPostWrite) else value)
        for key, value in self.items()
    }
|
||||
|
||||
def update(self, E=None, **F):
    """
    dict.update override that wraps incoming values so further changes keep firing the callback.

    BUG FIX: previously the keyword items ``**F`` were silently dropped whenever a
    positional mapping ``E`` was given; standard ``dict.update(E, **F)`` applies both.
    """
    merged = dict(E) if E is not None else {}
    merged.update(F)
    return super(ProxyDictPostWrite, self).update(
        ProxyDictPostWrite(self._update_obj, self._set_callback, **merged))
|
||||
|
||||
|
||||
class ProxyDictPreWrite(dict):
|
||||
""" Dictionary wrapper that prevents modifications to the dictionary """
|
||||
@ -39,8 +54,12 @@ class ProxyDictPreWrite(dict):
|
||||
super(ProxyDictPreWrite, self).__setitem__(*key_value)
|
||||
|
||||
def _set_callback(self, key_value, *_):
|
||||
if self._update_func:
|
||||
res = self._update_func(self._update_obj, key_value)
|
||||
if self._update_func is not None:
|
||||
if callable(self._update_func):
|
||||
res = self._update_func(self._update_obj, key_value)
|
||||
else:
|
||||
res = self._update_func
|
||||
|
||||
if not res:
|
||||
return None
|
||||
return res
|
||||
@ -48,3 +67,40 @@ class ProxyDictPreWrite(dict):
|
||||
|
||||
def _nested_callback(self, prefix, key_value):
    # delegate to the flat callback, qualifying the key with its parent prefix
    key, value = key_value[0], key_value[1]
    return self._set_callback((prefix + '.' + key, value))
|
||||
|
||||
|
||||
def flatten_dictionary(a_dict, prefix=''):
    """
    Flatten a nested dictionary into a single-level dict with '/'-separated keys.

    :param dict a_dict: nested dictionary (keys are stringified)
    :param str prefix: key prefix accumulated during recursion
    :return: new flat dict; scalar values, homogeneous basic-type lists/tuples and
        unsupported objects are kept as-is under their flattened key
    """
    flat_dict = {}
    sep = '/'
    basic_types = (float, int, bool, six.string_types, )
    for k, v in a_dict.items():
        k = str(k)
        # consistency fix: reuse basic_types instead of repeating the type tuple
        if isinstance(v, basic_types):
            flat_dict[prefix + k] = v
        elif isinstance(v, (list, tuple)) and all(isinstance(i, basic_types) for i in v):
            flat_dict[prefix + k] = v
        elif isinstance(v, dict):
            flat_dict.update(flatten_dictionary(v, prefix=prefix + k + sep))
        else:
            # this is a mixture of list and dict, or any other object,
            # leave it as is, we have nothing to do with it.
            flat_dict[prefix + k] = v
    return flat_dict
|
||||
|
||||
|
||||
def nested_from_flat_dictionary(a_dict, flat_dict, prefix=''):
    """
    Write values from a '/'-separated flat dict back into the nested dict, in place.

    :param dict a_dict: nested dictionary to update (also returned)
    :param dict flat_dict: flat dictionary produced by :func:`flatten_dictionary`
    :param str prefix: key prefix accumulated during recursion
    :return: the updated a_dict
    """
    basic_types = (float, int, bool, six.string_types, )
    sep = '/'
    # iterate over a snapshot: assigning a_dict[str(k)] for a non-str key k would
    # otherwise add a new key while iterating and raise RuntimeError
    for k, v in list(a_dict.items()):
        k = str(k)
        # consistency fix: reuse basic_types instead of repeating the type tuple
        if isinstance(v, basic_types):
            a_dict[k] = flat_dict.get(prefix + k, v)
        elif isinstance(v, (list, tuple)) and all(isinstance(i, basic_types) for i in v):
            a_dict[k] = flat_dict.get(prefix + k, v)
        elif isinstance(v, dict):
            a_dict[k] = nested_from_flat_dictionary(v, flat_dict, prefix=prefix + k + sep) or v
        else:
            # this is a mixture of list and dict, or any other object,
            # leave it as is, we have nothing to do with it.
            a_dict[k] = flat_dict.get(prefix + k, v)
    return a_dict
|
||||
|
Loading…
Reference in New Issue
Block a user