From 70cdb132528bebe47e5508e684e5bf40112c268c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 24 May 2020 08:10:47 +0300 Subject: [PATCH] Support Task.get_task() without project name (i.e. all projects) --- trains/task.py | 55 +++++++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/trains/task.py b/trains/task.py index 7bd21a36..b7ff040d 100644 --- a/trains/task.py +++ b/trains/task.py @@ -274,8 +274,9 @@ class Task(_Task): 'xgboost': True, 'scikit': True} :type auto_connect_frameworks: bool or dict - :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots? These plots appear in - in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, with a title of **:resource monitor:**. + :param bool auto_resource_monitoring: Automatically create machine resource monitoring plots? + These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, + with a title of **:resource monitor:**. The values are: @@ -433,7 +434,8 @@ class Task(_Task): PatchXGBoostModelIO.update_current_task(task) if auto_resource_monitoring and not is_sub_process_task_id: task._resource_monitor = ResourceMonitor( - task, report_mem_used_per_process=not config.get('development.worker.report_global_mem_used', False)) + task, report_mem_used_per_process=not config.get( + 'development.worker.report_global_mem_used', False)) task._resource_monitor.start() # make sure all random generators are initialized with new seed @@ -567,7 +569,8 @@ class Task(_Task): :param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details :return: The Tasks specified by the parameter combinations (see the parameters). """ - return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {})) + return cls.__get_tasks(task_ids=task_ids, project_name=project_name, + task_name=task_name, **(task_filter or {})) @property def output_uri(self): @@ -640,7 +643,8 @@ class Task(_Task): """ assert isinstance(source_task, (six.string_types, Task)) if not Session.check_min_api_version('2.4'): - raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above") + raise ValueError("Trains-server does not support DevOps features, " + "upgrade trains-server to 0.12.0 or above") task_id = source_task if isinstance(source_task, six.string_types) else source_task.id if not parent: @@ -702,7 +706,8 @@ class Task(_Task): """ assert isinstance(task, (six.string_types, Task)) if not Session.check_min_api_version('2.4'): - raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above") + raise ValueError("Trains-server does not support DevOps features, " + "upgrade trains-server to 0.12.0 or above") task_id = task if isinstance(task, six.string_types) else task.id session = cls._get_default_session() @@ -754,14 +759,16 @@ class Task(_Task): - ``status_reason`` - The reason for the last status change. - ``status_message`` - Information about the status. - ``status_changed`` - The last status change date and time in ISO 8601 format. - - ``last_update`` - The last time the Task was created, updated, changed or events for this task were reported. + - ``last_update`` - The last time the Task was created, updated, + changed or events for this task were reported. - ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued. - ``updated`` - The number of Tasks updated (an integer or ``null``). """ assert isinstance(task, (six.string_types, Task)) if not Session.check_min_api_version('2.4'): - raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above") + raise ValueError("Trains-server does not support DevOps features, " + "upgrade trains-server to 0.12.0 or above") task_id = task if isinstance(task, six.string_types) else task.id session = cls._get_default_session() @@ -1366,8 +1373,8 @@ class Task(_Task): """ Set Task model configuration text/dict - :param config_text: model configuration (unconstrained text string). usually the content of a configuration file. - If `config_text` is not None, `config_dict` must not be provided. + :param config_text: model configuration (unconstrained text string). usually the content + of a configuration file. If `config_text` is not None, `config_dict` must not be provided. :param config_dict: model configuration parameters dictionary. If `config_dict` is not None, `config_text` must not be provided. """ @@ -1532,7 +1539,7 @@ class Task(_Task): thread = threading.Thread(target=LoggerRoot.flush) thread.daemon = True thread.start() - + return task def _get_logger(self, flush_period=NotSet): @@ -1660,7 +1667,8 @@ class Task(_Task): if running_remotely() and self.is_main_task(): self._arguments.copy_to_parser(parser, parsed_args) else: - self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args) + self._arguments.copy_defaults_from_argparse( + parser, args=args, namespace=namespace, parsed_args=parsed_args) return parser def _connect_dictionary(self, dictionary): @@ -2096,19 +2104,22 @@ class Task(_Task): if task_id: return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False) - res = cls._send( - cls._get_default_session(), - projects.GetAllRequest( - name=exact_match_regex(project_name) + if project_name: + res = cls._send( + cls._get_default_session(), + projects.GetAllRequest( + name=exact_match_regex(project_name) + ) ) - ) - project = get_single_result(entity='project', query=project_name, results=res.response.projects) + project = get_single_result(entity='project', query=project_name, results=res.response.projects) + else: + project = None system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags' res = cls._send( cls._get_default_session(), tasks.GetAllRequest( - project=[project.id], + project=[project.id] if project else None, name=exact_match_regex(task_name) if task_name else None, only_fields=['id', 'name', 'last_update', system_tags] ) @@ -2188,7 +2199,8 @@ class Task(_Task): @classmethod def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type): - hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) + hash_key = cls.__get_hash_key( + cls._get_api_server(), default_project_name, default_task_name, default_task_type) # check if we have a cached task_id we can reuse # it must be from within the last 24h and with the same project/name/type @@ -2213,7 +2225,8 @@ class Task(_Task): @classmethod def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id): - hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type) + hash_key = cls.__get_hash_key( + cls._get_api_server(), default_project_name, default_task_name, default_task_type) task_id = str(task_id) # update task session cache