Support more than 500 results in Task.get_tasks() using the fetch_only_first_page argument (#612)

This commit is contained in:
allegroai 2022-04-15 19:22:50 +03:00
parent a1709d5d41
commit 11242d4029

View File

@ -3628,25 +3628,32 @@ class Task(_Task):
@classmethod @classmethod
def __get_tasks( def __get_tasks(
cls, cls,
task_ids=None, # type: Optional[Sequence[str]] task_ids=None, # type: Optional[Sequence[str]]
project_name=None, # type: Optional[Union[Sequence[str],str]] project_name=None, # type: Optional[Union[Sequence[str],str]]
task_name=None, # type: Optional[str] task_name=None, # type: Optional[str]
**kwargs # type: Any **kwargs # type: Any
): ):
# type: (...) -> List[Task] # type: (...) -> List[Task]
if task_ids: if task_ids:
if isinstance(task_ids, six.string_types): if isinstance(task_ids, six.string_types):
task_ids = [task_ids] task_ids = [task_ids]
return [cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) return [cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) for task_id in task_ids]
for task_id in task_ids]
return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) queried_tasks = cls._query_tasks(
for task in cls._query_tasks(project_name=project_name, task_name=task_name, **kwargs)] project_name=project_name, task_name=task_name, fetch_only_first_page=True, **kwargs
)
if len(queried_tasks) == 500:
LoggerRoot.get_base_logger().warning(
"Too many requests when calling Task.get_tasks()."
" Returning only the first 500 results."
" Use Task.query_tasks() to fetch all task IDs"
)
return [cls(private=cls.__create_protection, task_id=task.id, log_to_backend=False) for task in queried_tasks]
@classmethod @classmethod
def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, **kwargs): def _query_tasks(cls, task_ids=None, project_name=None, task_name=None, fetch_only_first_page=False, **kwargs):
if not task_ids: if not task_ids:
task_ids = None task_ids = None
elif isinstance(task_ids, six.string_types): elif isinstance(task_ids, six.string_types):
@ -3677,18 +3684,25 @@ class Task(_Task):
if kwargs and kwargs.get('only_fields'): if kwargs and kwargs.get('only_fields'):
only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields)) only_fields = list(set(kwargs.pop('only_fields')) | set(only_fields))
res = cls._send( ret_tasks = []
session, page = -1
tasks.GetAllRequest( page_size = 500
id=task_ids, while page == -1 or (len(res.response.tasks) == page_size and not fetch_only_first_page):
project=project_ids if project_ids else kwargs.pop('project', None), page += 1
name=task_name if task_name else kwargs.pop('name', None), res = cls._send(
only_fields=only_fields, session,
**kwargs tasks.GetAllRequest(
id=task_ids,
project=project_ids if project_ids else kwargs.pop("project", None),
name=task_name if task_name else kwargs.pop("name", None),
only_fields=only_fields,
page=page,
page_size=page_size,
**kwargs
),
) )
) ret_tasks.extend(res.response.tasks)
return ret_tasks
return res.response.tasks
@classmethod @classmethod
def __get_hash_key(cls, *args): def __get_hash_key(cls, *args):