Add Task.get_configuration_objects() add Task.get_task() filtering support

This commit is contained in:
allegroai 2021-04-25 10:41:29 +03:00
parent 07fca61572
commit 696034ac75
3 changed files with 61 additions and 11 deletions

View File

@ -1085,8 +1085,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: True if the parameter was deleted successfully :return: True if the parameter was deleted successfully
""" """
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
raise ValueError("Delete hyper parameter is not supported by your clearml-server, " raise ValueError(
"upgrade to the latest version") "Delete hyper-parameter is not supported by your clearml-server, "
"upgrade to the latest version")
with self._edit_lock: with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1]) paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest( res = self.send(tasks.DeleteHyperParamsRequest(
@ -1580,6 +1582,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" """
return self._get_configuration_text(name) return self._get_configuration_text(name)
def get_configuration_objects(self):
# type: () -> Optional[Mapping[str, str]]
"""
Get the Task's configuration object section as a blob of text
Use only for automation (externally), otherwise use `Task.connect_configuration`.
:return: The Task's configurations as a
dict (config name as key) and text blob as value (unconstrained text string)
"""
if not Session.check_min_api_version('2.9'):
raise ValueError(
"Multiple configurations are not supported with the current 'clearml-server', "
"please upgrade to the latest version")
configuration = self.data.configuration or {}
return {k: v.value for k, v in configuration.items()}
def set_configuration_object(self, name, config_text=None, description=None, config_type=None): def set_configuration_object(self, name, config_text=None, description=None, config_type=None):
# type: (str, Optional[str], Optional[str], Optional[str]) -> None # type: (str, Optional[str], Optional[str], Optional[str]) -> None
""" """
@ -1851,7 +1870,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
mutually_exclusive(config_dict=config_dict, config_text=config_text, _check_none=True) mutually_exclusive(config_dict=config_dict, config_text=config_text, _check_none=True)
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
raise ValueError("Multiple configurations is not supported with the current 'clearml-server', " raise ValueError("Multiple configurations are not supported with the current 'clearml-server', "
"please upgrade to the latest version") "please upgrade to the latest version")
if description: if description:
@ -1875,7 +1894,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return None if configuration name is not valid. return None if configuration name is not valid.
""" """
if not Session.check_min_api_version('2.9'): if not Session.check_min_api_version('2.9'):
raise ValueError("Multiple configurations is not supported with the current 'clearml-server', " raise ValueError("Multiple configurations are not supported with the current 'clearml-server', "
"please upgrade to the latest version") "please upgrade to the latest version")
configuration = self.data.configuration or {} configuration = self.data.configuration or {}

View File

@ -1,8 +1,9 @@
import hashlib import hashlib
import json
import re import re
import sys import sys
from zlib import crc32 from zlib import crc32
from typing import Optional, Union, Sequence from typing import Optional, Union, Sequence, Dict
from pathlib2 import Path from pathlib2 import Path
from six.moves.urllib.parse import quote, urlparse, urlunparse from six.moves.urllib.parse import quote, urlparse, urlunparse
@ -117,6 +118,25 @@ def hash_text(text, seed=1337, hash_func='md5'):
return h.hexdigest() return h.hexdigest()
def hash_dict(a_dict, seed=1337, hash_func='md5'):
# type: (Dict, Union[int, str], str) -> str
"""
Return hash_func (crc32/md5/sha1/sha256/sha384/sha512) hash of the dict values
(dict must be JSON serializable)
:param a_dict: a dictionary to hash
:param seed: use prefix seed for hashing
:param hash_func: hashing function. currently supported md5 sha256
:return: hashed string
"""
assert hash_func in ('crc32', 'md5', 'sha256', 'sha256', 'sha384', 'sha512')
repr_string = json.dumps(a_dict, sort_keys=True)
if hash_func == 'crc32':
return crc32text(repr_string, seed=seed)
else:
return hash_text(repr_string, seed=seed, hash_func=hash_func)
def is_windows(): def is_windows():
""" """
:return: True if currently running on windows OS :return: True if currently running on windows OS

View File

@ -696,8 +696,8 @@ class Task(_Task):
return task return task
@classmethod @classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None): def get_task(cls, task_id=None, project_name=None, task_name=None, allow_archived=True, task_filter=None):
# type: (Optional[str], Optional[str], Optional[str]) -> Task # type: (Optional[str], Optional[str], Optional[str], bool, Optional[dict]) -> Task
""" """
Get a Task by Id, or project name / task name combination. Get a Task by Id, or project name / task name combination.
@ -732,13 +732,20 @@ class Task(_Task):
train_task.get_logger().report_scalar('title', 'series', value=x * 2, iteration=x) train_task.get_logger().report_scalar('title', 'series', value=x * 2, iteration=x)
:param str task_id: The Id (system UUID) of the experiment to get. :param str task_id: The Id (system UUID) of the experiment to get.
If specified, ``project_name`` and ``task_name`` are ignored. If specified, ``project_name`` and ``task_name`` are ignored.
:param str project_name: The project name of the Task to get. :param str project_name: The project name of the Task to get.
:param str task_name: The name of the Task within ``project_name`` to get. :param str task_name: The name of the Task within ``project_name`` to get.
:param bool allow_archived: Only applicable if *not* using specific ``task_id``,
If True (default) allow to return archived Tasks, if False filter out archived Tasks
:param bool task_filter: Only applicable if *not* using specific ``task_id``,
Pass additional query filters, on top of project/name. See details in Task.get_tasks.
:return: The Task specified by ID, or project name / experiment name combination. :return: The Task specified by ID, or project name / experiment name combination.
""" """
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name) return cls.__get_task(
task_id=task_id, project_name=project_name, task_name=task_name,
include_archived=allow_archived, task_filter=task_filter,
)
@classmethod @classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None): def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None):
@ -3199,7 +3206,7 @@ class Task(_Task):
cls.__register_at_exit(None, only_remove_signal_and_exception_hooks=True) cls.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
@classmethod @classmethod
def __get_task(cls, task_id=None, project_name=None, task_name=None): def __get_task(cls, task_id=None, project_name=None, task_name=None, include_archived=True, task_filter=None):
if task_id: if task_id:
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
@ -3215,12 +3222,16 @@ class Task(_Task):
project = None project = None
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
task_filter = task_filter or {}
if not include_archived:
task_filter['system_tags'] = ['-{}'.format(cls.archived_tag)]
res = cls._send( res = cls._send(
cls._get_default_session(), cls._get_default_session(),
tasks.GetAllRequest( tasks.GetAllRequest(
project=[project.id] if project else None, project=[project.id] if project else None,
name=exact_match_regex(task_name) if task_name else None, name=exact_match_regex(task_name) if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags] only_fields=['id', 'name', 'last_update', system_tags],
**task_filter
) )
) )
res_tasks = res.response.tasks res_tasks = res.response.tasks