diff --git a/clearml/task.py b/clearml/task.py index 1718d9b4..0e9bdaec 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -35,7 +35,7 @@ from .backend_interface.task.log import TaskHandler from .backend_interface.task.development.worker import DevWorker from .backend_interface.task.repo import ScriptInfo from .backend_interface.task.models import TaskModels -from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive +from .backend_interface.util import get_single_result, exact_match_regex, make_message, mutually_exclusive, get_queue_id from .binding.absl_bind import PatchAbsl from .binding.artifacts import Artifacts, Artifact from .binding.environ_bind import EnvironmentBind, PatchOsFork @@ -760,8 +760,14 @@ class Task(_Task): ) @classmethod - def get_tasks(cls, task_ids=None, project_name=None, task_name=None, task_filter=None): - # type: (Optional[Sequence[str]], Optional[str], Optional[str], Optional[Dict]) -> Sequence[Task] + def get_tasks( + cls, + task_ids=None, # type: Optional[Sequence[str]] + project_name=None, # type: Optional[Union[Sequence[str],str]] + task_name=None, # type: Optional[str] + task_filter=None # type: Optional[Dict] + ): + # type: (...) -> Sequence[Task] """ Get a list of Tasks by one of the following: @@ -774,6 +780,7 @@ class Task(_Task): :param str project_name: The project name of the Tasks to get. To get the experiment in all projects, use the default value of ``None``. (Optional) + Use a list of string for multiple optional project names. :param str task_name: The full name or partial name of the Tasks to match within the specified ``project_name`` (or all projects if ``project_name`` is ``None``). This method supports regular expressions for name matching. (Optional) @@ -982,14 +989,9 @@ class Task(_Task): task_id = task if isinstance(task, six.string_types) else task.id session = cls._get_default_session() if not queue_id: - req = queues.GetAllRequest(name=exact_match_regex(queue_name), only_fields=["id"]) - res = cls._send(session=session, req=req) - if not res.response.queues: + queue_id = get_queue_id(session, queue_name) + if not queue_id: raise ValueError('Could not find queue named "{}"'.format(queue_name)) - queue_id = res.response.queues[0].id - if len(res.response.queues) > 1: - LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format( - queue_name, queue_id)) req = tasks.EnqueueRequest(task=task_id, queue=queue_id) res = cls._send(session=session, req=req) @@ -3304,16 +3306,23 @@ class Task(_Task): elif isinstance(task_ids, six.string_types): task_ids = [task_ids] - if project_name: - res = cls._send( - cls._get_default_session(), - projects.GetAllRequest( - name=exact_match_regex(project_name) - ) - ) - project = get_single_result(entity='project', query=project_name, results=res.response.projects) + if project_name and isinstance(project_name, str): + project_names = [project_name] else: - project = None + project_names = project_name + + project_ids = [] + if project_names: + for name in project_names: + res = cls._send( + cls._get_default_session(), + projects.GetAllRequest( + name=exact_match_regex(name) + ) + ) + project = get_single_result(entity='project', query=name, results=res.response.projects) + if project: + project_ids.append(project.id) system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' only_fields = ['id', 'name', 'last_update', system_tags] @@ -3325,7 +3334,7 @@ class Task(_Task): cls._get_default_session(), tasks.GetAllRequest( id=task_ids, - project=[project.id] if project else kwargs.pop('project', None), + project=project_ids if project_ids else kwargs.pop('project', None), name=task_name if task_name else None, only_fields=only_fields, **kwargs