Add Task.get_tasks() filtering support

This commit is contained in:
allegroai 2020-04-01 18:54:16 +03:00
parent 581edf1098
commit 172ed62d41

View File

@ -8,11 +8,11 @@ from argparse import ArgumentParser
from tempfile import mkstemp
try:
from collections.abc import Callable, Sequence
from collections.abc import Callable, Sequence as CollectionsSequence
except ImportError:
from collections import Callable, Sequence
from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List
from typing import Optional, Union, Mapping, Sequence, Any, Dict, List
import psutil
import six
@ -418,8 +418,8 @@ class Task(_Task):
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
@classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None):
# type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task
def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None):
# type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task]
"""
Returns a list of Task objects, matching requested task name (or partially matching)
@ -428,9 +428,10 @@ class Task(_Task):
:param str task_name: task name (str) in within the selected project
Return any partial match of task_name, regular expressions matching is also supported
If None is passed, returns all tasks within the project
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
:return: list of Task object
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name)
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {}))
@property
def output_uri(self):
@ -777,7 +778,7 @@ class Task(_Task):
self.__register_at_exit(None)
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None
# type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None
"""
Add artifact for the current Task, used mostly for Data Auditing.
Currently supported artifacts object types: pandas.DataFrame
@ -788,11 +789,12 @@ class Task(_Task):
:param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria.
The default value is True, which equals to all the columns (same as artifact.columns).
"""
if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True:
raise ValueError('uniqueness_columns should be a sequence or True')
if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True:
raise ValueError('uniqueness_columns should be a List (sequence) or True')
if isinstance(uniqueness_columns, str):
uniqueness_columns = [uniqueness_columns]
self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
self._artifacts_manager.register_artifact(
name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
def unregister_artifact(self, name):
# type: (str) -> None
@ -1752,11 +1754,27 @@ class Task(_Task):
)
@classmethod
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None):
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs):
if task_ids:
if isinstance(task_ids, six.string_types):
task_ids = [task_ids]
return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids]
return [
cls(private=cls.__create_protection, task_id=task, log_to_backend=False)
for task in task_ids
]
return cls._query_tasks(
project_name=project_name,
task_name=task_name,
**kwargs
)
@classmethod
def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs):
if not task_ids:
task_ids = None
elif isinstance(task_ids, six.string_types):
task_ids = [task_ids]
if project_name:
res = cls._send(
@ -1770,17 +1788,23 @@ class Task(_Task):
project = None
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
only_fields = ['id', 'name', 'last_update', system_tags]
if kwargs and kwargs.get('only_fields'):
only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields))
res = cls._send(
cls._get_default_session(),
tasks.GetAllRequest(
id=task_ids,
project=[project.id] if project else None,
name=task_name if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags]
only_fields=only_fields,
**kwargs
)
)
res_tasks = res.response.tasks
return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks]
return res.response.tasks
@classmethod
def __get_hash_key(cls, *args):