Support Task.get_tasks passing multiple project names

This commit is contained in:
allegroai 2021-07-23 15:56:50 +03:00
parent e0096e6978
commit 08da22296a

View File

@ -35,7 +35,7 @@ from .backend_interface.task.log import TaskHandler
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.repo import ScriptInfo
from .backend_interface.task.models import TaskModels
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive
from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive, get_queue_id
from .binding.absl_bind import PatchAbsl
from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork
@ -760,8 +760,14 @@ class Task(_Task):
)
@classmethod
def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None):
# type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task]
def get_tasks(
cls,
task_ids=None, # type: Optional[Sequence[str]]
project_name=None, # type: Optional[Union[Sequence[str],str]]
task_name=None, # type: Optional[str]
task_filter=None # type: Optional[Dict]
):
# type: (...) -> Sequence[Task]
"""
Get a list of Tasks by one of the following:
@ -774,6 +780,7 @@ class Task(_Task):
:param str project_name: The project name of the Tasks to get. To get the experiment
in all projects, use the default value of ``None``. (Optional)
Use a list of string for multiple optional project names.
:param str task_name: The full name or partial name of the Tasks to match within the specified
``project_name`` (or all projects if ``project_name`` is ``None``).
This method supports regular expressions for name matching. (Optional)
@ -982,14 +989,9 @@ class Task(_Task):
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
if not queue_id:
req = queues.GetAllRequest(name=exact_match_regex(queue_name), only_fields=["id"])
res = cls._send(session=session, req=req)
if not res.response.queues:
queue_id = get_queue_id(session, queue_name)
if not queue_id:
raise ValueError('Could not find queue named "{}"'.format(queue_name))
queue_id = res.response.queues[0].id
if len(res.response.queues) > 1:
LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format(
queue_name, queue_id))
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
res = cls._send(session=session, req=req)
@ -3304,16 +3306,23 @@ class Task(_Task):
elif isinstance(task_ids, six.string_types):
task_ids = [task_ids]
if project_name:
res = cls._send(
cls._get_default_session(),
projects.GetAllRequest(
name=exact_match_regex(project_name)
)
)
project = get_single_result(entity='project', query=project_name, results=res.response.projects)
if project_name and isinstance(project_name, str):
project_names = [project_name]
else:
project = None
project_names = project_name
project_ids = []
if project_names:
for name in project_names:
res = cls._send(
cls._get_default_session(),
projects.GetAllRequest(
name=exact_match_regex(name)
)
)
project = get_single_result(entity='project', query=name, results=res.response.projects)
if project:
project_ids.append(project.id)
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
only_fields = ['id', 'name', 'last_update', system_tags]
@ -3325,7 +3334,7 @@ class Task(_Task):
cls._get_default_session(),
tasks.GetAllRequest(
id=task_ids,
project=[project.id] if project else kwargs.pop('project', None),
project=project_ids if project_ids else kwargs.pop('project', None),
name=task_name if task_name else None,
only_fields=only_fields,
**kwargs