mirror of
				https://github.com/clearml/clearml
				synced 2025-06-26 18:16:07 +00:00 
			
		
		
		
	Add Task.get_tasks() filtering support
This commit is contained in:
		
							parent
							
								
									581edf1098
								
							
						
					
					
						commit
						172ed62d41
					
				| @ -8,11 +8,11 @@ from argparse import ArgumentParser | ||||
| from tempfile import mkstemp | ||||
| 
 | ||||
| try: | ||||
|     from collections.abc import Callable, Sequence | ||||
|     from collections.abc import Callable, Sequence as CollectionsSequence | ||||
| except ImportError: | ||||
|     from collections import Callable, Sequence | ||||
|     from collections import Callable, Sequence as CollectionsSequence | ||||
| 
 | ||||
| from typing import Optional, Union, Mapping, Sequence as TSequence, Any, Dict, List | ||||
| from typing import Optional, Union, Mapping, Sequence, Any, Dict, List | ||||
| 
 | ||||
| import psutil | ||||
| import six | ||||
| @ -418,8 +418,8 @@ class Task(_Task): | ||||
|         return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def get_tasks(cls, task_ids=None, project_name=None, task_name=None): | ||||
|         # type: (Optional[TSequence[str]], Optional[str], Optional[str]) -> Task | ||||
|     def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None): | ||||
|         # type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task] | ||||
|         """ | ||||
|         Returns a list of Task objects, matching requested task name (or partially matching) | ||||
| 
 | ||||
| @ -428,9 +428,10 @@ class Task(_Task): | ||||
|         :param str task_name: task name (str) in within the selected project | ||||
|             Return any partial match of task_name, regular expressions matching is also supported | ||||
|             If None is passed, returns all tasks within the project | ||||
|         :param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details | ||||
|         :return: list of Task object | ||||
|         """ | ||||
|         return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name) | ||||
|         return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {})) | ||||
| 
 | ||||
|     @property | ||||
|     def output_uri(self): | ||||
| @ -777,7 +778,7 @@ class Task(_Task): | ||||
|             self.__register_at_exit(None) | ||||
| 
 | ||||
|     def register_artifact(self, name, artifact, metadata=None, uniqueness_columns=True): | ||||
|         # type: (str, "pandas.DataFrame", Dict, Union[bool, TSequence[str]]) -> None | ||||
|         # type: (str, "pandas.DataFrame", Dict, Union[bool, Sequence[str]]) -> None | ||||
|         """ | ||||
|         Add artifact for the current Task, used mostly for Data Auditing. | ||||
|         Currently supported artifacts object types: pandas.DataFrame | ||||
| @ -788,11 +789,12 @@ class Task(_Task): | ||||
|         :param Sequence uniqueness_columns: Sequence of columns for artifact uniqueness comparison criteria. | ||||
|             The default value is True, which equals to all the columns (same as artifact.columns). | ||||
|         """ | ||||
|         if not isinstance(uniqueness_columns, Sequence) and uniqueness_columns is not True: | ||||
|             raise ValueError('uniqueness_columns should be a sequence or True') | ||||
|         if not isinstance(uniqueness_columns, CollectionsSequence) and uniqueness_columns is not True: | ||||
|             raise ValueError('uniqueness_columns should be a List (sequence) or True') | ||||
|         if isinstance(uniqueness_columns, str): | ||||
|             uniqueness_columns = [uniqueness_columns] | ||||
|         self._artifacts_manager.register_artifact(name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) | ||||
|         self._artifacts_manager.register_artifact( | ||||
|             name=name, artifact=artifact, metadata=metadata, uniqueness_columns=uniqueness_columns) | ||||
| 
 | ||||
|     def unregister_artifact(self, name): | ||||
|         # type: (str) -> None | ||||
| @ -1752,11 +1754,27 @@ class Task(_Task): | ||||
|         ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def __get_tasks(cls, task_ids=None, project_name=None, task_name=None): | ||||
|     def __get_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): | ||||
|         if task_ids: | ||||
|             if isinstance(task_ids, six.string_types): | ||||
|                 task_ids = [task_ids] | ||||
|             return [cls(private=cls.__create_protection, task_id=i, log_to_backend=False) for i in task_ids] | ||||
|             return [ | ||||
|                 cls(private=cls.__create_protection, task_id=task, log_to_backend=False) | ||||
|                 for task in task_ids | ||||
|             ] | ||||
| 
 | ||||
|         return cls._query_tasks( | ||||
|             project_name=project_name, | ||||
|             task_name=task_name, | ||||
|             **kwargs | ||||
|         ) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): | ||||
|         if not task_ids: | ||||
|             task_ids = None | ||||
|         elif isinstance(task_ids, six.string_types): | ||||
|             task_ids = [task_ids] | ||||
| 
 | ||||
|         if project_name: | ||||
|             res = cls._send( | ||||
| @ -1770,17 +1788,23 @@ class Task(_Task): | ||||
|             project = None | ||||
| 
 | ||||
|         system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' | ||||
|         only_fields = ['id', 'name', 'last_update', system_tags] | ||||
| 
 | ||||
|         if kwargs and kwargs.get('only_fields'): | ||||
|             only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields)) | ||||
| 
 | ||||
|         res = cls._send( | ||||
|             cls._get_default_session(), | ||||
|             tasks.GetAllRequest( | ||||
|                 id=task_ids, | ||||
|                 project=[project.id] if project else None, | ||||
|                 name=task_name if task_name else None, | ||||
|                 only_fields=['id', 'name', 'last_update', system_tags] | ||||
|                 only_fields=only_fields, | ||||
|                 **kwargs | ||||
|             ) | ||||
|         ) | ||||
|         res_tasks = res.response.tasks | ||||
| 
 | ||||
|         return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in res_tasks] | ||||
|         return res.response.tasks | ||||
| 
 | ||||
|     @classmethod | ||||
|     def __get_hash_key(cls, *args): | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 allegroai
						allegroai