Add Task.get_tasks() filtering support

This commit is contained in:
allegroai 2020-04-01 18:54:16 +03:00
parent 581edf1098
commit 172ed62d41

View File

@ -8,11 +8,11 @@ from argparse import ArgumentParser
from tempfile import mkstemp from tempfile import mkstemp
try: try:
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence as CollectionsSequence
except ImportError: except ImportError:
from collections import Callable, Sequence from collections import Callable, Sequence as CollectionsSequence
from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List from typing import Optional, Union, Mapping, Sequence, Any, Dict, List
import psutil import psutil
import six import six
@ -418,8 +418,8 @@ class Task(_Task):
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name) return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
@classmethod @classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None): def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None):
# type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task # type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task]
""" """
Returns a list of Task objects, matching requested task name (or partially matching) Returns a list of Task objects, matching requested task name (or partially matching)
@ -428,9 +428,10 @@ class Task(_Task):
:param str task_name: task name (str) within the selected project :param str task_name: task name (str) within the selected project
Return any partial match of task_name, regular expressions matching is also supported Return any partial match of task_name, regular expressions matching is also supported
If None is passed, returns all tasks within the project If None is passed, returns all tasks within the project
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
:return: list of Task objects :return: list of Task objects
""" """
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name) return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {}))
@property @property
def output_uri(self): def output_uri(self):
@ -777,7 +778,7 @@ class Task(_Task):
self.__register_at_exit(None) self.__register_at_exit(None)
def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True):
# type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None # type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None
""" """
Add artifact for the current Task, used mostly for Data Auditing. Add artifact for the current Task, used mostly for Data Auditing.
Currently supported artifacts object types: pandas.DataFrame Currently supported artifacts object types: pandas.DataFrame
@ -788,11 +789,12 @@ class Task(_Task):
:param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria. :param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria.
The default value is True, which equals to all the columns (same as artifact.columns). The default value is True, which equals to all the columns (same as artifact.columns).
""" """
if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True: if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True:
raise ValueError('uniqueness_columns should be a sequence or True') raise ValueError('uniqueness_columns should be a List (sequence) or True')
if isinstance(uniqueness_columns, str): if isinstance(uniqueness_columns, str):
uniqueness_columns = [uniqueness_columns] uniqueness_columns = [uniqueness_columns]
self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) self._artifacts_manager.register_artifact(
name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns)
def unregister_artifact(self, name): def unregister_artifact(self, name):
# type: (str) -> None # type: (str) -> None
@ -1752,11 +1754,27 @@ class Task(_Task):
) )
@classmethod @classmethod
def __get_tasks(cls, task_ids=None, project_name=None, task_name=None): def __get_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs):
if task_ids: if task_ids:
if isinstance(task_ids, six.string_types): if isinstance(task_ids, six.string_types):
task_ids = [task_ids] task_ids = [task_ids]
return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids] return [
cls(private=cls.__create_protection, task_id=task, log_to_backend=False)
for task in task_ids
]
return cls._query_tasks(
project_name=project_name,
task_name=task_name,
**kwargs
)
@classmethod
def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs):
if not task_ids:
task_ids = None
elif isinstance(task_ids, six.string_types):
task_ids = [task_ids]
if project_name: if project_name:
res = cls._send( res = cls._send(
@ -1770,17 +1788,23 @@ class Task(_Task):
project = None project = None
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
only_fields = ['id', 'name', 'last_update', system_tags]
if kwargs and kwargs.get('only_fields'):
only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields))
res = cls._send( res = cls._send(
cls._get_default_session(), cls._get_default_session(),
tasks.GetAllRequest( tasks.GetAllRequest(
id=task_ids,
project=[project.id] if project else None, project=[project.id] if project else None,
name=task_name if task_name else None, name=task_name if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags] only_fields=only_fields,
**kwargs
) )
) )
res_tasks = res.response.tasks
return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks] return res.response.tasks
@classmethod @classmethod
def __get_hash_key(cls, *args): def __get_hash_key(cls, *args):