Add trains-server new capabilities support

This commit is contained in:
allegroai 2019-10-25 15:15:22 +03:00
parent 6861aa0113
commit 2104c3ec6b
2 changed files with 222 additions and 138 deletions

View File

@ -15,7 +15,7 @@ from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path
from pyhocon import ConfigTree, ConfigFactory
from ...utilities.pyhocon import ConfigTree, ConfigFactory
from ..base import IdObjectBase
from ..metrics import Metrics, Reporter
@ -192,8 +192,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not latest_version[1]:
sep = os.linesep
self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended! '
'{}'.format(
'TRAINS new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
latest_version[0], sep.join(latest_version[2])),
)
else:
@ -790,8 +789,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(script=script)
@classmethod
def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
tags=None, parent=None, project=None, log=None, session=None):
def _clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
tags=None, parent=None, project=None, log=None, session=None):
"""
Clone a task
@ -847,7 +846,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
script=task.script
)
res = cls._send(session=session, log=log, req=req)
return res.response.id
cloned_task_id = res.response.id
if task.script and task.script.requirements:
cls._send(session=session, log=log, req=tasks.SetRequirementsRequest(
task=cloned_task_id, requirements=task.script.requirements))
return cloned_task_id
@classmethod
def get_all(cls, session=None, log=None, **kwargs):

View File

@ -12,7 +12,7 @@ import psutil
import six
from .binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects
from .backend_api.services import tasks, projects, queues
from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task
@ -338,137 +338,6 @@ class Task(_Task):
raise
return task
@classmethod
def _reset_current_task_obj(cls):
    """
    Detach the current main task from the class and unregister its development worker.

    Does nothing when no main task is currently set.
    """
    main_task = cls.__main_task
    if not main_task:
        return

    # forget the main task first, then release its development worker (if any)
    cls.__main_task = None
    dev_worker = main_task._dev_worker
    if dev_worker:
        dev_worker.unregister()
        main_task._dev_worker = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
"""
Create the main task for this run, or reuse a previously used task when that is allowed.

:param default_project_name: Project name for the task. When empty, it is derived from the
    repository reported by ScriptInfo, falling back to 'Untitled'.
:param default_task_name: Task name. When empty, it is derived from the base name of the script
    entry point (left unchanged if that lookup fails).
:param default_task_type: Task type; its ``.value`` is used in the last-used-task lookup/update.
:param reuse_last_task_id: False-like value to always create a new task. When not running remotely,
    a non-empty ``str`` is used as the ID of the task to reuse, and any other true value
    reuses the last used task if it is considered relevant.
:return: The created (or reused) task, after it was marked as started and reloaded.
"""
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
if not default_project_name:
# noinspection PyBroadException
try:
# last '/'-separated part of the repository string (or the one before it when the
# string ends with '/'), with the '.git' text removed
parts = result.script['repository'].split('/')
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
except Exception:
default_project_name = 'Untitled'
if not default_task_name:
# noinspection PyBroadException
try:
# entry point file name without its extension
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
except Exception:
pass
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
default_task = None
else:
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
# NOTE(review): closed_old_task is never set to True in this method, so the
# "Closing old development task" report further down is never emitted - confirm it can be removed
closed_old_task = False
default_task_id = None
# development mode means this process is not running remotely
in_dev_mode = not running_remotely()
# reusing an existing task is only considered in development mode
if in_dev_mode:
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
# an explicit task ID was passed by the caller
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
else:
default_task_id = default_task.get('id') if default_task else None
if default_task_id:
try:
task = cls(
private=cls.__create_protection,
task_id=default_task_id,
log_to_backend=True,
)
# prefer system_tags when the task data has them, otherwise fall back to tags
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or (cls._development_tag not in task_tags)):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
# (a task without the development tag is rejected as well)
default_task_id = None
task = None
else:
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_system_tags([cls._development_tag])
# clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
# clear the input model (and task model design/labels)
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
task.set_model_config(config_text='')
task.set_model_label_enumeration({})
task.set_artifacts([])
task._set_storage_uri(None)
# NOTE(review): ValueError is a subclass of Exception, so listing it here is redundant
except (Exception, ValueError):
# we failed reusing task, create a new one
default_task_id = None
# create a new task
# (default_task_id is empty when reuse was not requested, not possible, or failed)
if not default_task_id:
task = cls(
private=cls.__create_protection,
project_name=default_project_name,
task_name=default_task_name,
task_type=default_task_type,
log_to_backend=True,
)
if in_dev_mode:
# update this session, for later use
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
# mark the task as started
task.started()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task
if default_task_id:
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
else:
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
# update current repository and put warning into logs
# (in a daemon thread only in development mode with cls.__detect_repo_async set, otherwise inline)
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start()
else:
task._update_repository()
# make sure everything is in sync
task.reload()
# make sure we see something in the UI
# (LoggerRoot.flush is run in a daemon thread)
thread = threading.Thread(target=LoggerRoot.flush)
thread.daemon = True
thread.start()
return task
@classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None):
"""
@ -509,6 +378,86 @@ class Task(_Task):
return ReadOnlyDict()
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
@classmethod
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
    """
    Create a copy of an existing task and return the new Task object.

    :param source_task: Source Task object (or Task ID) to be cloned
    :type source_task: Task/str
    :param name: Optional, name for the new task
    :type name: str
    :param comment: Optional, comment for the new task
    :type comment: str
    :param parent: Optional, parent Task ID of the new task.
    :type parent: str
    :param project: Optional, project ID of the new task.
        If None, the new task will inherit the cloned task's project.
    :type project: str
    :return: a new cloned Task object
    :raise ValueError: if the server does not support API version 2.4
    """
    assert isinstance(source_task, (six.string_types, Task))
    if not Session.check_min_api_version('2.4'):
        raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")

    # accept either a Task object or a plain task ID string
    if isinstance(source_task, six.string_types):
        source_id = source_task
    else:
        source_id = source_task.id

    new_task_id = cls._clone_task(
        cloned_task_id=source_id,
        name=name,
        comment=comment,
        parent=parent,
        project=project,
    )
    return cls.get_task(task_id=new_task_id)
@classmethod
def enqueue(cls, task, queue_name=None, queue_id=None):
    """
    Enqueue (send) a task for execution, by adding it to an execution queue.

    At least one of queue_name / queue_id must be given. When both are given, queue_id is used.

    :param task: Task object (or Task ID) to be enqueued
    :type task: Task / str
    :param str queue_name: Name of the queue in which to enqueue the task.
        Used to look up the queue ID when queue_id is not provided.
    :param str queue_id: ID of the queue in which to enqueue the task. If not provided use queue_name.
    :return: enqueue response
    :raise ValueError: if neither queue_name nor queue_id is provided, if the server does not
        support API version 2.4, or if no queue named queue_name was found
    """
    # Without this guard the lookup below would be sent with name=None and the task would be
    # enqueued into the first queue of that (unfiltered) response, i.e. an arbitrary queue.
    if not queue_id and not queue_name:
        raise ValueError('Either queue_name or queue_id must be provided')

    assert isinstance(task, (six.string_types, Task))
    if not Session.check_min_api_version('2.4'):
        raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")

    # accept either a Task object or a plain task ID string
    task_id = task if isinstance(task, six.string_types) else task.id
    session = cls._get_default_session()

    if not queue_id:
        # resolve the queue ID from its name
        req = queues.GetAllRequest(name=queue_name, only_fields=["id"])
        res = cls._send(session=session, req=req)
        if not res.response.queues:
            raise ValueError('Could not find queue named "{}"'.format(queue_name))
        queue_id = res.response.queues[0].id
        if len(res.response.queues) > 1:
            # several queues matched the name, report which one was picked
            LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format(
                queue_name, queue_id))

    req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
    res = cls._send(session=session, req=req)
    return res.response
@classmethod
def dequeue(cls, task):
    """
    Dequeue (remove) a task from its execution queue.

    :param task: Task object (or Task ID) to be dequeued
    :type task: Task / str
    :return: Dequeue response
    :raise ValueError: if the server does not support API version 2.4
    """
    assert isinstance(task, (six.string_types, Task))
    if not Session.check_min_api_version('2.4'):
        raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")

    # accept either a Task object or a plain task ID string
    if isinstance(task, six.string_types):
        target_id = task
    else:
        target_id = task.id

    default_session = cls._get_default_session()
    result = cls._send(session=default_session, req=tasks.DequeueRequest(task=target_id))
    return result.response
def set_comment(self, comment):
"""
Set a comment text to the task.
@ -812,6 +761,137 @@ class Task(_Task):
if secret:
Session.default_secret = secret
@classmethod
def _reset_current_task_obj(cls):
    """
    Detach the current main task from the class and unregister its development worker.

    Does nothing when no main task is currently set.
    """
    main_task = cls.__main_task
    if not main_task:
        return

    # forget the main task first, then release its development worker (if any)
    cls.__main_task = None
    dev_worker = main_task._dev_worker
    if dev_worker:
        dev_worker.unregister()
        main_task._dev_worker = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
"""
Create the main task for this run, or reuse a previously used task when that is allowed.

:param default_project_name: Project name for the task. When empty, it is derived from the
    repository reported by ScriptInfo, falling back to 'Untitled'.
:param default_task_name: Task name. When empty, it is derived from the base name of the script
    entry point (left unchanged if that lookup fails).
:param default_task_type: Task type; its ``.value`` is used in the last-used-task lookup/update.
:param reuse_last_task_id: False-like value to always create a new task. When not running remotely,
    a non-empty ``str`` is used as the ID of the task to reuse, and any other true value
    reuses the last used task if it is considered relevant.
:return: The created (or reused) task, after it was marked as started and reloaded.
"""
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
if not default_project_name:
# noinspection PyBroadException
try:
# last '/'-separated part of the repository string (or the one before it when the
# string ends with '/'), with the '.git' text removed
parts = result.script['repository'].split('/')
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
except Exception:
default_project_name = 'Untitled'
if not default_task_name:
# noinspection PyBroadException
try:
# entry point file name without its extension
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
except Exception:
pass
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
default_task = None
else:
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
# NOTE(review): closed_old_task is never set to True in this method, so the
# "Closing old development task" report further down is never emitted - confirm it can be removed
closed_old_task = False
default_task_id = None
# development mode means this process is not running remotely
in_dev_mode = not running_remotely()
# reusing an existing task is only considered in development mode
if in_dev_mode:
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
# an explicit task ID was passed by the caller
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
else:
default_task_id = default_task.get('id') if default_task else None
if default_task_id:
try:
task = cls(
private=cls.__create_protection,
task_id=default_task_id,
log_to_backend=True,
)
# prefer system_tags when the task data has them, otherwise fall back to tags
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or (cls._development_tag not in task_tags)):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
# (a task without the development tag is rejected as well)
default_task_id = None
task = None
else:
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_system_tags([cls._development_tag])
# clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
# clear the input model (and task model design/labels)
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
task.set_model_config(config_text='')
task.set_model_label_enumeration({})
task.set_artifacts([])
task._set_storage_uri(None)
# NOTE(review): ValueError is a subclass of Exception, so listing it here is redundant
except (Exception, ValueError):
# we failed reusing task, create a new one
default_task_id = None
# create a new task
# (default_task_id is empty when reuse was not requested, not possible, or failed)
if not default_task_id:
task = cls(
private=cls.__create_protection,
project_name=default_project_name,
task_name=default_task_name,
task_type=default_task_type,
log_to_backend=True,
)
if in_dev_mode:
# update this session, for later use
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
# mark the task as started
task.started()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task
if default_task_id:
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
else:
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
# update current repository and put warning into logs
# (in a daemon thread only in development mode with cls.__detect_repo_async set, otherwise inline)
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start()
else:
task._update_repository()
# make sure everything is in sync
task.reload()
# make sure we see something in the UI
# (LoggerRoot.flush is run in a daemon thread)
thread = threading.Thread(target=LoggerRoot.flush)
thread.daemon = True
thread.start()
return task
def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
"""