mirror of
https://github.com/clearml/clearml
synced 2025-05-31 18:48:16 +00:00
Add trains-server new capabilities support
This commit is contained in:
parent
6861aa0113
commit
2104c3ec6b
@ -15,7 +15,7 @@ from ...backend_interface.task.development.worker import DevWorker
|
||||
from ...backend_api import Session
|
||||
from ...backend_api.services import tasks, models, events, projects
|
||||
from pathlib2 import Path
|
||||
from pyhocon import ConfigTree, ConfigFactory
|
||||
from ...utilities.pyhocon import ConfigTree, ConfigFactory
|
||||
|
||||
from ..base import IdObjectBase
|
||||
from ..metrics import Metrics, Reporter
|
||||
@ -192,8 +192,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
if not latest_version[1]:
|
||||
sep = os.linesep
|
||||
self.get_logger().report_text(
|
||||
'TRAINS new package available: UPGRADE to v{} is recommended! '
|
||||
'{}'.format(
|
||||
'TRAINS new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
|
||||
latest_version[0], sep.join(latest_version[2])),
|
||||
)
|
||||
else:
|
||||
@ -790,8 +789,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
self._edit(script=script)
|
||||
|
||||
@classmethod
|
||||
def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
|
||||
tags=None, parent=None, project=None, log=None, session=None):
|
||||
def _clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
|
||||
tags=None, parent=None, project=None, log=None, session=None):
|
||||
"""
|
||||
Clone a task
|
||||
|
||||
@ -847,7 +846,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
script=task.script
|
||||
)
|
||||
res = cls._send(session=session, log=log, req=req)
|
||||
return res.response.id
|
||||
cloned_task_id = res.response.id
|
||||
|
||||
if task.script and task.script.requirements:
|
||||
cls._send(session=session, log=log, req=tasks.SetRequirementsRequest(
|
||||
task=cloned_task_id, requirements=task.script.requirements))
|
||||
return cloned_task_id
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, session=None, log=None, **kwargs):
|
||||
|
344
trains/task.py
344
trains/task.py
@ -12,7 +12,7 @@ import psutil
|
||||
import six
|
||||
|
||||
from .binding.joblib_bind import PatchedJoblib
|
||||
from .backend_api.services import tasks, projects
|
||||
from .backend_api.services import tasks, projects, queues
|
||||
from .backend_api.session.session import Session
|
||||
from .backend_interface.model import Model as BackendModel
|
||||
from .backend_interface.task import Task as _Task
|
||||
@ -338,137 +338,6 @@ class Task(_Task):
|
||||
raise
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def _reset_current_task_obj(cls):
|
||||
if not cls.__main_task:
|
||||
return
|
||||
task = cls.__main_task
|
||||
cls.__main_task = None
|
||||
if task._dev_worker:
|
||||
task._dev_worker.unregister()
|
||||
task._dev_worker = None
|
||||
|
||||
@classmethod
|
||||
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
|
||||
if not default_project_name or not default_task_name:
|
||||
# get project name and task name from repository name and entry_point
|
||||
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
|
||||
if not default_project_name:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
parts = result.script['repository'].split('/')
|
||||
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
|
||||
except Exception:
|
||||
default_project_name = 'Untitled'
|
||||
if not default_task_name:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if we force no task reuse from os environment
|
||||
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
|
||||
default_task = None
|
||||
else:
|
||||
# if we have a previous session to use, get the task id from it
|
||||
default_task = cls.__get_last_used_task_id(
|
||||
default_project_name,
|
||||
default_task_name,
|
||||
default_task_type.value,
|
||||
)
|
||||
|
||||
closed_old_task = False
|
||||
default_task_id = None
|
||||
in_dev_mode = not running_remotely()
|
||||
|
||||
if in_dev_mode:
|
||||
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
|
||||
default_task_id = reuse_last_task_id
|
||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
default_task_id = None
|
||||
else:
|
||||
default_task_id = default_task.get('id') if default_task else None
|
||||
|
||||
if default_task_id:
|
||||
try:
|
||||
task = cls(
|
||||
private=cls.__create_protection,
|
||||
task_id=default_task_id,
|
||||
log_to_backend=True,
|
||||
)
|
||||
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
|
||||
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
|
||||
or task.output_model_id or (ARCHIVED_TAG in task_tags)
|
||||
or (cls._development_tag not in task_tags)):
|
||||
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
|
||||
# If the task is archived, or already has an output model,
|
||||
# we shouldn't use it in development mode either
|
||||
default_task_id = None
|
||||
task = None
|
||||
else:
|
||||
# reset the task, so we can update it
|
||||
task.reset(set_started_on_success=False, force=False)
|
||||
# set development tags
|
||||
task.set_system_tags([cls._development_tag])
|
||||
# clear task parameters, they are not cleared by the Task reset
|
||||
task.set_parameters({}, __update=False)
|
||||
# clear the comment, it is not cleared on reset
|
||||
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
|
||||
# clear the input model (and task model design/labels)
|
||||
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
|
||||
task.set_model_config(config_text='')
|
||||
task.set_model_label_enumeration({})
|
||||
task.set_artifacts([])
|
||||
task._set_storage_uri(None)
|
||||
|
||||
except (Exception, ValueError):
|
||||
# we failed reusing task, create a new one
|
||||
default_task_id = None
|
||||
|
||||
# create a new task
|
||||
if not default_task_id:
|
||||
task = cls(
|
||||
private=cls.__create_protection,
|
||||
project_name=default_project_name,
|
||||
task_name=default_task_name,
|
||||
task_type=default_task_type,
|
||||
log_to_backend=True,
|
||||
)
|
||||
|
||||
if in_dev_mode:
|
||||
# update this session, for later use
|
||||
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
|
||||
|
||||
# mark the task as started
|
||||
task.started()
|
||||
# force update of base logger to this current task (this is the main logger task)
|
||||
task._setup_log(replace_existing=True)
|
||||
logger = task.get_logger()
|
||||
if closed_old_task:
|
||||
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
|
||||
# print warning, reusing/creating a task
|
||||
if default_task_id:
|
||||
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
|
||||
else:
|
||||
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
|
||||
|
||||
# update current repository and put warning into logs
|
||||
if in_dev_mode and cls.__detect_repo_async:
|
||||
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
|
||||
task._detect_repo_async_thread.daemon = True
|
||||
task._detect_repo_async_thread.start()
|
||||
else:
|
||||
task._update_repository()
|
||||
|
||||
# make sure everything is in sync
|
||||
task.reload()
|
||||
# make sure we see something in the UI
|
||||
thread = threading.Thread(target=LoggerRoot.flush)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def get_task(cls, task_id=None, project_name=None, task_name=None):
|
||||
"""
|
||||
@ -509,6 +378,86 @@ class Task(_Task):
|
||||
return ReadOnlyDict()
|
||||
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
|
||||
|
||||
@classmethod
|
||||
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
|
||||
"""
|
||||
Clone a task object, create a copy a task.
|
||||
|
||||
:param source_task: Source Task object (or ID) to be cloned
|
||||
:type source_task: Task/str
|
||||
:param name: Optional, New for the new task
|
||||
:type name: str
|
||||
:param comment: Optional, comment for the new task
|
||||
:type comment: str
|
||||
:param parent: Optional parent Task ID of the new task.
|
||||
:type parent: str
|
||||
:param project: Optional project ID of the new task.
|
||||
If None, the new task will inherit the cloned task's project.
|
||||
:type project: str
|
||||
:return: a new cloned Task object
|
||||
"""
|
||||
assert isinstance(source_task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
|
||||
|
||||
task_id = source_task if isinstance(source_task, six.string_types) else source_task.id
|
||||
cloned_task_id = cls._clone_task(cloned_task_id=task_id, name=name, comment=comment,
|
||||
parent=parent, project=project)
|
||||
cloned_task = cls.get_task(task_id=cloned_task_id)
|
||||
return cloned_task
|
||||
|
||||
@classmethod
|
||||
def enqueue(cls, task, queue_name=None, queue_id=None):
|
||||
"""
|
||||
Enqueue (send) a task for execution, by adding it to an execution queue
|
||||
|
||||
:param task: Task object (or Task ID) to be enqueued, None if using Task object
|
||||
:type task: Task / str
|
||||
:param str queue_name: Name of the queue in which to enqueue the task.
|
||||
:param str queue_id: ID of the queue in which to enqueue the task. If not provided use queue_name.
|
||||
:return: enqueue response
|
||||
"""
|
||||
assert isinstance(task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
|
||||
|
||||
task_id = task if isinstance(task, six.string_types) else task.id
|
||||
session = cls._get_default_session()
|
||||
if not queue_id:
|
||||
req = queues.GetAllRequest(name=queue_name, only_fields=["id"])
|
||||
res = cls._send(session=session, req=req)
|
||||
if not res.response.queues:
|
||||
raise ValueError('Could not find queue named "{}"'.format(queue_name))
|
||||
queue_id = res.response.queues[0].id
|
||||
if len(res.response.queues) > 1:
|
||||
LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format(
|
||||
queue_name, queue_id))
|
||||
|
||||
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
|
||||
res = cls._send(session=session, req=req)
|
||||
resp = res.response
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
def dequeue(cls, task):
|
||||
"""
|
||||
Dequeue (remove) task from execution queue.
|
||||
|
||||
:param task: Task object (or Task ID) to be enqueued, None if using Task object
|
||||
:type task: Task / str
|
||||
:return: Dequeue response
|
||||
"""
|
||||
assert isinstance(task, (six.string_types, Task))
|
||||
if not Session.check_min_api_version('2.4'):
|
||||
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
|
||||
|
||||
task_id = task if isinstance(task, six.string_types) else task.id
|
||||
session = cls._get_default_session()
|
||||
req = tasks.DequeueRequest(task=task_id)
|
||||
res = cls._send(session=session, req=req)
|
||||
resp = res.response
|
||||
return resp
|
||||
|
||||
def set_comment(self, comment):
|
||||
"""
|
||||
Set a comment text to the task.
|
||||
@ -812,6 +761,137 @@ class Task(_Task):
|
||||
if secret:
|
||||
Session.default_secret = secret
|
||||
|
||||
@classmethod
|
||||
def _reset_current_task_obj(cls):
|
||||
if not cls.__main_task:
|
||||
return
|
||||
task = cls.__main_task
|
||||
cls.__main_task = None
|
||||
if task._dev_worker:
|
||||
task._dev_worker.unregister()
|
||||
task._dev_worker = None
|
||||
|
||||
@classmethod
|
||||
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
|
||||
if not default_project_name or not default_task_name:
|
||||
# get project name and task name from repository name and entry_point
|
||||
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
|
||||
if not default_project_name:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
parts = result.script['repository'].split('/')
|
||||
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
|
||||
except Exception:
|
||||
default_project_name = 'Untitled'
|
||||
if not default_task_name:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if we force no task reuse from os environment
|
||||
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
|
||||
default_task = None
|
||||
else:
|
||||
# if we have a previous session to use, get the task id from it
|
||||
default_task = cls.__get_last_used_task_id(
|
||||
default_project_name,
|
||||
default_task_name,
|
||||
default_task_type.value,
|
||||
)
|
||||
|
||||
closed_old_task = False
|
||||
default_task_id = None
|
||||
in_dev_mode = not running_remotely()
|
||||
|
||||
if in_dev_mode:
|
||||
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
|
||||
default_task_id = reuse_last_task_id
|
||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
default_task_id = None
|
||||
else:
|
||||
default_task_id = default_task.get('id') if default_task else None
|
||||
|
||||
if default_task_id:
|
||||
try:
|
||||
task = cls(
|
||||
private=cls.__create_protection,
|
||||
task_id=default_task_id,
|
||||
log_to_backend=True,
|
||||
)
|
||||
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
|
||||
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
|
||||
or task.output_model_id or (ARCHIVED_TAG in task_tags)
|
||||
or (cls._development_tag not in task_tags)):
|
||||
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
|
||||
# If the task is archived, or already has an output model,
|
||||
# we shouldn't use it in development mode either
|
||||
default_task_id = None
|
||||
task = None
|
||||
else:
|
||||
# reset the task, so we can update it
|
||||
task.reset(set_started_on_success=False, force=False)
|
||||
# set development tags
|
||||
task.set_system_tags([cls._development_tag])
|
||||
# clear task parameters, they are not cleared by the Task reset
|
||||
task.set_parameters({}, __update=False)
|
||||
# clear the comment, it is not cleared on reset
|
||||
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
|
||||
# clear the input model (and task model design/labels)
|
||||
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
|
||||
task.set_model_config(config_text='')
|
||||
task.set_model_label_enumeration({})
|
||||
task.set_artifacts([])
|
||||
task._set_storage_uri(None)
|
||||
|
||||
except (Exception, ValueError):
|
||||
# we failed reusing task, create a new one
|
||||
default_task_id = None
|
||||
|
||||
# create a new task
|
||||
if not default_task_id:
|
||||
task = cls(
|
||||
private=cls.__create_protection,
|
||||
project_name=default_project_name,
|
||||
task_name=default_task_name,
|
||||
task_type=default_task_type,
|
||||
log_to_backend=True,
|
||||
)
|
||||
|
||||
if in_dev_mode:
|
||||
# update this session, for later use
|
||||
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
|
||||
|
||||
# mark the task as started
|
||||
task.started()
|
||||
# force update of base logger to this current task (this is the main logger task)
|
||||
task._setup_log(replace_existing=True)
|
||||
logger = task.get_logger()
|
||||
if closed_old_task:
|
||||
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
|
||||
# print warning, reusing/creating a task
|
||||
if default_task_id:
|
||||
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
|
||||
else:
|
||||
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
|
||||
|
||||
# update current repository and put warning into logs
|
||||
if in_dev_mode and cls.__detect_repo_async:
|
||||
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
|
||||
task._detect_repo_async_thread.daemon = True
|
||||
task._detect_repo_async_thread.start()
|
||||
else:
|
||||
task._update_repository()
|
||||
|
||||
# make sure everything is in sync
|
||||
task.reload()
|
||||
# make sure we see something in the UI
|
||||
thread = threading.Thread(target=LoggerRoot.flush)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
return task
|
||||
|
||||
def _get_logger(self, flush_period=NotSet):
|
||||
# type: (Optional[float]) -> Logger
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user