Support Task.get_task() without project name (i.e. all projects)

This commit is contained in:
allegroai 2020-05-24 08:10:47 +03:00
parent 7ad4ec2314
commit 70cdb13252

View File

@ -274,8 +274,9 @@ class Task(_Task):
'xgboost': True, 'scikit': True}
:type auto_connect_frameworks: bool or dict
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots? These plots appear in
in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab, with a title of **:resource monitor:**.
:param bool auto_resource_monitoring: Automatically create machine resource monitoring plots?
These plots appear in in the **Trains Web-App (UI)**, **RESULTS** tab, **SCALARS** sub-tab,
with a title of **:resource monitor:**.
The values are:
@ -433,7 +434,8 @@ class Task(_Task):
PatchXGBoostModelIO.update_current_task(task)
if auto_resource_monitoring and not is_sub_process_task_id:
task._resource_monitor = ResourceMonitor(
task, report_mem_used_per_process=not config.get('development.worker.report_global_mem_used', False))
task, report_mem_used_per_process=not config.get(
'development.worker.report_global_mem_used', False))
task._resource_monitor.start()
# make sure all random generators are initialized with new seed
@ -567,7 +569,8 @@ class Task(_Task):
:param dict task_filter: filter and order Tasks. See service.tasks.GetAllRequest for details
:return: The Tasks specified by the parameter combinations (see the parameters).
"""
return cls.__get_tasks(task_ids=task_ids, project_name=project_name, task_name=task_name, **(task_filter or {}))
return cls.__get_tasks(task_ids=task_ids, project_name=project_name,
task_name=task_name, **(task_filter or {}))
@property
def output_uri(self):
@ -640,7 +643,8 @@ class Task(_Task):
"""
assert isinstance(source_task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
raise ValueError("Trains-server does not support DevOps features, "
"upgrade trains-server to 0.12.0 or above")
task_id = source_task if isinstance(source_task, six.string_types) else source_task.id
if not parent:
@ -702,7 +706,8 @@ class Task(_Task):
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
raise ValueError("Trains-server does not support DevOps features, "
"upgrade trains-server to 0.12.0 or above")
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
@ -754,14 +759,16 @@ class Task(_Task):
- ``status_reason`` - The reason for the last status change.
- ``status_message`` - Information about the status.
- ``status_changed`` - The last status change date and time in ISO 8601 format.
- ``last_update`` - The last time the Task was created, updated, changed or events for this task were reported.
- ``last_update`` - The last time the Task was created, updated,
changed or events for this task were reported.
- ``execution.queue`` - The Id of the queue where the Task is enqueued. ``null`` indicates not enqueued.
- ``updated`` - The number of Tasks updated (an integer or ``null``).
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
raise ValueError("Trains-server does not support DevOps features, "
"upgrade trains-server to 0.12.0 or above")
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
@ -1366,8 +1373,8 @@ class Task(_Task):
"""
Set Task model configuration text/dict
:param config_text: model configuration (unconstrained text string). usually the content of a configuration file.
If `config_text` is not None, `config_dict` must not be provided.
:param config_text: model configuration (unconstrained text string). usually the content
of a configuration file. If `config_text` is not None, `config_dict` must not be provided.
:param config_dict: model configuration parameters dictionary.
If `config_dict` is not None, `config_text` must not be provided.
"""
@ -1532,7 +1539,7 @@ class Task(_Task):
thread = threading.Thread(target=LoggerRoot.flush)
thread.daemon = True
thread.start()
return task
def _get_logger(self, flush_period=NotSet):
@ -1660,7 +1667,8 @@ class Task(_Task):
if running_remotely() and self.is_main_task():
self._arguments.copy_to_parser(parser, parsed_args)
else:
self._arguments.copy_defaults_from_argparse(parser, args=args, namespace=namespace, parsed_args=parsed_args)
self._arguments.copy_defaults_from_argparse(
parser, args=args, namespace=namespace, parsed_args=parsed_args)
return parser
def _connect_dictionary(self, dictionary):
@ -2096,19 +2104,22 @@ class Task(_Task):
if task_id:
return cls(private=cls.__create_protection, task_id=task_id, log_to_backend=False)
res = cls._send(
cls._get_default_session(),
projects.GetAllRequest(
name=exact_match_regex(project_name)
if project_name:
res = cls._send(
cls._get_default_session(),
projects.GetAllRequest(
name=exact_match_regex(project_name)
)
)
)
project = get_single_result(entity='project', query=project_name, results=res.response.projects)
project = get_single_result(entity='project', query=project_name, results=res.response.projects)
else:
project = None
system_tags = 'system_tags' if hasattr(tasks.Task, 'system_tags') else 'tags'
res = cls._send(
cls._get_default_session(),
tasks.GetAllRequest(
project=[project.id],
project=[project.id] if project else None,
name=exact_match_regex(task_name) if task_name else None,
only_fields=['id', 'name', 'last_update', system_tags]
)
@ -2188,7 +2199,8 @@ class Task(_Task):
@classmethod
def __get_last_used_task_id(cls, default_project_name, default_task_name, default_task_type):
hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
hash_key = cls.__get_hash_key(
cls._get_api_server(), default_project_name, default_task_name, default_task_type)
# check if we have a cached task_id we can reuse
# it must be from within the last 24h and with the same project/name/type
@ -2213,7 +2225,8 @@ class Task(_Task):
@classmethod
def __update_last_used_task_id(cls, default_project_name, default_task_name, default_task_type, task_id):
hash_key = cls.__get_hash_key(cls._get_api_server(), default_project_name, default_task_name, default_task_type)
hash_key = cls.__get_hash_key(
cls._get_api_server(), default_project_name, default_task_name, default_task_type)
task_id = str(task_id)
# update task session cache