From 11242d40290ebd67d6205c0db562fddb0c603729 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 15 Apr 2022 19:22:50 +0300 Subject: [PATCH] Support more than 500 results in `Task.get_tasks()` using the `fetch_only_first_page` argument (#612) --- clearml/task.py | 56 ++++++++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/clearml/task.py b/clearml/task.py index 26efacdd..7046aab4 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -3628,25 +3628,32 @@ class Task(_Task): @classmethod def __get_tasks( - cls, - task_ids=None, # type: Optional[Sequence[str]] - project_name=None, # type: Optional[Union[Sequence[str],str]] - task_name=None, # type: Optional[str] - **kwargs # type: Any + cls, + task_ids=None, # type: Optional[Sequence[str]] + project_name=None, # type: Optional[Union[Sequence[str],str]] + task_name=None, # type: Optional[str] + **kwargs # type: Any ): # type: (...) -> List[Task] if task_ids: if isinstance(task_ids, six.string_types): task_ids = [task_ids] - return [cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) - for task_id in task_ids] + return [cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) for task_id in task_ids] - return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) - for task in cls._query_tasks(project_name=project_name, task_name=task_name, **kwargs)] + queried_tasks = cls._query_tasks( + project_name=project_name, task_name=task_name, fetch_only_first_page=True, **kwargs + ) + if len(queried_tasks) == 500: + LoggerRoot.get_base_logger().warning( + "Too many requests when calling Task.get_tasks()." + " Returning only the first 500 results." + " Use Task.query_tasks() to fetch all task IDs" + ) + return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in queried_tasks] @classmethod - def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): + def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, fetch_only_first_page=False, **kwargs): if not task_ids: task_ids = None elif isinstance(task_ids, six.string_types): @@ -3677,18 +3684,25 @@ class Task(_Task): if kwargs and kwargs.get('only_fields'): only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields)) - res = cls._send( - session, - tasks.GetAllRequest( - id=task_ids, - project=project_ids if project_ids else kwargs.pop('project', None), - name=task_name if task_name else kwargs.pop('name', None), - only_fields=only_fields, - **kwargs + ret_tasks = [] + page = -1 + page_size = 500 + while page == -1 or (len(res.response.tasks) == page_size and not fetch_only_first_page): + page += 1 + res = cls._send( + session, + tasks.GetAllRequest( + id=task_ids, + project=project_ids if project_ids else kwargs.pop("project", None), + name=task_name if task_name else kwargs.pop("name", None), + only_fields=only_fields, + page=page, + page_size=page_size, + **kwargs + ), ) - ) - - return res.response.tasks + ret_tasks.extend(res.response.tasks) + return ret_tasks @classmethod def __get_hash_key(cls, *args):