Add Task.query_tasks, Task.get_task/s tags argument

Fix docstring
This commit is contained in:
allegroai 2021-10-24 17:32:27 +03:00
parent 23b0c500f1
commit e6d3860de0
2 changed files with 158 additions and 22 deletions

View File

@ -15,13 +15,19 @@ class AccessMixin(object):
log = abstractproperty()
def _get_task_property(self, prop_path, raise_on_error=True, log_on_error=True, default=None):
obj = self.data
return self._get_data_property(
prop_path=prop_path, raise_on_error=raise_on_error, log_on_error=log_on_error,
default=default, data=self.data, log=self.log)
@classmethod
def _get_data_property(cls, prop_path, raise_on_error=True, log_on_error=True, default=None, data=None, log=None):
obj = data
props = prop_path.split('.')
for i in range(len(props)):
if not hasattr(obj, props[i]) and (not isinstance(obj, dict) or props[i] not in obj):
msg = 'Task has no %s section defined' % '.'.join(props[:i + 1])
if log_on_error:
self.log.info(msg)
if log_on_error and log:
log.info(msg)
if raise_on_error:
raise ValueError(msg)
return default

View File

@ -18,7 +18,7 @@ try:
except ImportError:
from collections import Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple
from typing import Optional, Union, Mapping, Sequence, Any, Dict, Iterable, TYPE_CHECKING, Callable, Tuple, List
import psutil
import six
@ -726,8 +726,16 @@ class Task(_Task):
return task
@classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None, allow_archived=True, task_filter=None):
# type: (Optional[str], Optional[str], Optional[str], bool, Optional[dict]) -> Task
def get_task(
cls,
task_id=None, # type: Optional[str]
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
tags=None, # type: Optional[Sequence[str]]
allow_archived=True, # type: bool
task_filter=None # type: Optional[dict]
):
# type: (...) -> "Task"
"""
Get a Task by Id, or project name / task name combination.
@ -765,6 +773,8 @@ class Task(_Task):
If specified, ``project_name`` and ``task_name`` are ignored.
:param str project_name: The project name of the Task to get.
:param str task_name: The name of the Task within ``project_name`` to get.
:param list tags: Filter based on the requested list of tags (strings) (Task must have all the listed tags)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
:param bool allow_archived: Only applicable if *not* using specific ``task_id``,
If True (default) allow to return archived Tasks, if False filter out archived Tasks
:param bool task_filter: Only applicable if *not* using specific ``task_id``,
@ -773,7 +783,7 @@ class Task(_Task):
:return: The Task specified by ID, or project name / experiment name combination.
"""
return cls.__get_task(
task_id=task_id, project_name=project_name, task_name=task_name,
task_id=task_id, project_name=project_name, task_name=task_name, tags=tags,
include_archived=allow_archived, task_filter=task_filter,
)
@ -783,18 +793,73 @@ class Task(_Task):
task_ids=None, # type: Optional[Sequence[str]]
project_name=None, # type: Optional[Union[Sequence[str],str]]
task_name=None, # type: Optional[str]
tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[Dict]
):
# type: (...) -> Sequence[Task]
# type: (...) -> List["Task"]
"""
Get a list of Tasks by one of the following:
Get a list of Tasks objects matching the queries/filters
- A list of specific Task IDs.
- All Tasks in a project matching a full or partial Task name.
- All Tasks in any project matching a full or partial Task name.
- A list of specific Task IDs.
- Filter Tasks based on specific fields:
project name (including partial match), task name (including partial match), tags
Apply Additional advanced filtering with `task_filter`
:param list(str) task_ids: The Ids (system UUID) of experiments to get.
If ``task_ids`` specified, then ``project_name`` and ``task_name`` are ignored.
:param str project_name: The project name of the Tasks to get. To get the experiment
in all projects, use the default value of ``None``. (Optional)
Use a list of string for multiple optional project names.
:param str task_name: The full name or partial name of the Tasks to match within the specified
``project_name`` (or all projects if ``project_name`` is ``None``).
This method supports regular expressions for name matching. (Optional)
:param list(str) task_ids: list of unique task id string (if exists other parameters are ignored)
:param str project_name: project name (str) the task belongs to (use None for all projects)
:param str task_name: task name (str) in within the selected project
Return any partial match of task_name, regular expressions matching is also supported
If None is passed, returns all tasks within the project
:param list tags: Filter based on the requested list of tags (strings) (Task must have all the listed tags)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
`parent`: (str) filter by parent task-id matching
`search_text`: (str) free text search (in task fields comment/name/id)
`status`: List[str] List of valid statuses
(options are: "created", "queued", "in_progress", "stopped", "published", "completed")
`type`: List[str] List of valid task type
(options are: 'training', 'testing', 'inference', 'data_processing', 'application', 'monitor',
'controller', 'optimizer', 'service', 'qc'. 'custom')
`user`: List[str] Filter based on Task's user owner, provide list of valid user Ids.
`order_by`: List[str] List of field names to order by. When search_text is used,
Use '-' prefix to specify descending order. Optional, recommended when using page
Example: order_by=['-last_update']
`_all_`: dict(fields=[], pattern='') Match string `pattern` (regular expression)
appearing in All `fields`
dict(fields=['script.repository'], pattern='github.com/user')
`_any_`: dict(fields=[], pattern='') Match string `pattern` (regular expression)
appearing in Any of the `fields`
dict(fields=['comment', 'name'], pattern='my comment')
Examples:
{'status': ['stopped'], 'order_by': ["-last_update"]}
{'order_by'=['-last_update'], '_all_'=dict(fields=['script.repository'], pattern='github.com/user'))
:return: The Tasks specified by the parameter combinations (see the parameters).
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, tags=tags,
task_name=task_name, **(task_filter or {}))
@classmethod
def query_tasks(
cls,
project_name=None, # type: Optional[Union[Sequence[str],str]]
task_name=None, # type: Optional[str]
tags=None, # type: Optional[Sequence[str]]
additional_return_fields=None, # type: Optional[Sequence[str]]
task_filter=None, # type: Optional[Dict]
):
# type: (...) -> Union[List[str], List[Dict[str, str]]]
"""
Get a list of Tasks ID matching the specific query/filter.
Notice, if `return_fields` is specified, returns a list of dictionaries with requested fields (dict per Task)
:param str project_name: The project name of the Tasks to get. To get the experiment
in all projects, use the default value of ``None``. (Optional)
@ -802,18 +867,55 @@ class Task(_Task):
:param str task_name: The full name or partial name of the Tasks to match within the specified
``project_name`` (or all projects if ``project_name`` is ``None``).
This method supports regular expressions for name matching. (Optional)
:param list(str) task_ids: list of unique task id string (if exists other parameters are ignored)
:param str project_name: project name (str) the task belongs to (use None for all projects)
:param str task_name: task name (str) in within the selected project
Return any partial match of task_name, regular expressions matching is also supported
If None is passed, returns all tasks within the project
:param list tags: Filter based on the requested list of tags (strings) (Task must have all the listed tags)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
:param list additional_return_fields: Optional, if not provided return a list of Task IDs.
If provided return dict per Task with the additional requested fields.
Example: returned_fields=['last_updated', 'user', 'script.repository'] will return a list of dict:
[{'id': 'task_id', 'last_update': datetime.datetime(),
'user': 'user_id', 'script.repository': 'https://github.com/user/'}, ]
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
`parent`: (str) filter by parent task-id matching
`search_text`: (str) free text search (in task fields comment/name/id)
`status`: List[str] List of valid statuses
(options are: "created", "queued", "in_progress", "stopped", "published", "completed")
`type`: List[str] List of valid task type
(options are: 'training', 'testing', 'inference', 'data_processing', 'application', 'monitor',
'controller', 'optimizer', 'service', 'qc'. 'custom')
`user`: List[str] Filter based on Task's user owner, provide list of valid user Ids.
`order_by`: List[str] List of field names to order by. When search_text is used,
Use '-' prefix to specify descending order. Optional, recommended when using page
Example: order_by=['-last_update']
`_all_`: dict(fields=[], pattern='') Match string `pattern` (regular expression)
appearing in All `fields`
dict(fields=['script.repository'], pattern='github.com/user')
`_any_`: dict(fields=[], pattern='') Match string `pattern` (regular expression)
appearing in Any of the `fields`
dict(fields=['comment', 'name'], pattern='my comment')
Examples:
{'status': ['stopped'], 'order_by': ["-last_update"]}
{'order_by'=['-last_update'], '_all_'=dict(fields=['script.repository'], pattern='github.com/user'))
:return: The Tasks specified by the parameter combinations (see the parameters).
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name,
task_name=task_name, **(task_filter or {}))
if tags:
task_filter = task_filter or {}
task_filter['tags'] = (task_filter.get('tags') or []) + list(tags)
return_fields = {}
if additional_return_fields:
task_filter = task_filter or {}
return_fields = set(list(additional_return_fields) + ['id'])
task_filter['only_fields'] = (task_filter.get('only_fields') or []) + list(return_fields)
results = cls._query_tasks(project_name=project_name, task_name=task_name, **(task_filter or {}))
return [t.id for t in results] if not additional_return_fields else \
[{k: cls._get_data_property(prop_path=k, data=r, raise_on_error=False, log_on_error=False)
for k in return_fields}
for r in results]
@property
def output_uri(self):
@ -3350,7 +3452,17 @@ class Task(_Task):
cls.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
@classmethod
def __get_task(cls, task_id=None, project_name=None, task_name=None, include_archived=True, task_filter=None):
def __get_task(
cls,
task_id=None, # type: Optional[str]
project_name=None, # type: Optional[str]
task_name=None, # type: Optional[str]
include_archived=True, # type: bool
tags=None, # type: Optional[Sequence[str]]
task_filter=None # type: Optional[dict]
):
# type: (...) -> Task
if task_id:
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
@ -3365,12 +3477,16 @@ class Task(_Task):
else:
project = None
# get default session, before trying to access tasks.Task so that we do not create two sessions.
session = cls._get_default_session()
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
task_filter = task_filter or {}
if not include_archived:
task_filter['system_tags'] = ['-{}'.format(cls.archived_tag)]
task_filter['system_tags'] = (task_filter.get('system_tags') or []) + ['-{}'.format(cls.archived_tag)]
if tags:
task_filter['tags'] = (task_filter.get('tags') or []) + list(tags)
res = cls._send(
cls._get_default_session(),
session,
tasks.GetAllRequest(
project=[project.id] if project else None,
name=exact_match_regex(task_name) if task_name else None,
@ -3388,7 +3504,12 @@ class Task(_Task):
if filtered_tasks:
res_tasks = filtered_tasks
task = get_single_result(entity='task', query=task_name, results=res_tasks, raise_on_error=False)
task = get_single_result(
entity='task',
query={k: v for k, v in dict(
project_name=project_name, task_name=task_name, tags=tags,
include_archived=include_archived, task_filter=task_filter).items() if v},
results=res_tasks, raise_on_error=False)
if not task:
return None
@ -3399,7 +3520,15 @@ class Task(_Task):
)
@classmethod
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs):
def __get_tasks(
cls,
task_ids=None, # type: Optional[Sequence[str]]
project_name=None, # type: Optional[Union[Sequence[str],str]]
task_name=None, # type: Optional[str]
**kwargs # type: Any
):
# type: (...) -> List[Task]
if task_ids:
if isinstance(task_ids, six.string_types):
task_ids = [task_ids]
@ -3434,6 +3563,7 @@ class Task(_Task):
if project:
project_ids.append(project.id)
session = cls._get_default_session()
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
only_fields = ['id', 'name', 'last_update', system_tags]
@ -3441,7 +3571,7 @@ class Task(_Task):
only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields))
res = cls._send(
cls._get_default_session(),
session,
tasks.GetAllRequest(
id=task_ids,
project=project_ids if project_ids else kwargs.pop('project', None),