diff --git a/trains/task.py b/trains/task.py index 9171cf85..54795015 100644 --- a/trains/task.py +++ b/trains/task.py @@ -8,11 +8,11 @@ from argparse import ArgumentParser from tempfile import mkstemp try: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Sequence as CollectionsSequence except ImportError: - from collections import Callable, Sequence + from collections import Callable, Sequence as CollectionsSequence -from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List +from typing import Optional, Union, Mapping, Sequence, Any, Dict, List import psutil import six @@ -418,8 +418,8 @@ class Task(_Task): return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name) @classmethod - def get_tasks(cls, task_ids=None, project_name=None, task_name=None): - # type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task + def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None): + # type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task] """ Returns a list of Task objects, matching requested task name (or partially matching) @@ -428,9 +428,10 @@ class Task(_Task): :param str task_name: task name (str) in within the selected project Return any partial match of task_name, regular expressions matching is also supported If None is passed, returns all tasks within the project + :param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details :return: list of Task object """ - return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name) + return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {})) @property def output_uri(self): @@ -777,7 +778,7 @@ class Task(_Task): self.__register_at_exit(None) def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): - # type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None + # type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None """ Add artifact for the current Task, used mostly for Data Auditing. Currently supported artifacts object types: pandas.DataFrame @@ -788,11 +789,12 @@ class Task(_Task): :param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria. The default value is True, which equals to all the columns (same as artifact.columns). """ - if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True: - raise ValueError('uniqueness_columns should be a sequence or True') + if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True: + raise ValueError('uniqueness_columns should be a List (sequence) or True') if isinstance(uniqueness_columns, str): uniqueness_columns = [uniqueness_columns] - self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) + self._artifacts_manager.register_artifact( + name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) def unregister_artifact(self, name): # type: (str) -> None @@ -1752,11 +1754,27 @@ class Task(_Task): ) @classmethod - def __get_tasks(cls, task_ids=None, project_name=None, task_name=None): + def __get_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): if task_ids: if isinstance(task_ids, six.string_types): task_ids = [task_ids] - return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids] + return [ + cls(private=cls.__create_protection, task_id=task, log_to_backend=False) + for task in task_ids + ] + + return cls._query_tasks( + project_name=project_name, + task_name=task_name, + **kwargs + ) + + @classmethod + def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): + if not task_ids: + task_ids = None + elif isinstance(task_ids, six.string_types): + task_ids = [task_ids] if project_name: res = cls._send( @@ -1770,17 +1788,23 @@ class Task(_Task): project = None system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' + only_fields = ['id', 'name', 'last_update', system_tags] + + if kwargs and kwargs.get('only_fields'): + only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields)) + res = cls._send( cls._get_default_session(), tasks.GetAllRequest( + id=task_ids, project=[project.id] if project else None, name=task_name if task_name else None, - only_fields=['id', 'name', 'last_update', system_tags] + only_fields=only_fields, + **kwargs ) ) - res_tasks = res.response.tasks - return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks] + return res.response.tasks @classmethod def __get_hash_key(cls, *args):