Add trains-server new capabilities support

This commit is contained in:
allegroai 2019-10-25 15:15:22 +03:00
parent 6861aa0113
commit 2104c3ec6b
2 changed files with 222 additions and 138 deletions

View File

@ -15,7 +15,7 @@ from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path from pathlib2 import Path
from pyhocon import ConfigTree, ConfigFactory from ...utilities.pyhocon import ConfigTree, ConfigFactory
from ..base import IdObjectBase from ..base import IdObjectBase
from ..metrics import Metrics, Reporter from ..metrics import Metrics, Reporter
@ -192,8 +192,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not latest_version[1]: if not latest_version[1]:
sep = os.linesep sep = os.linesep
self.get_logger().report_text( self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended! ' 'TRAINS new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
'{}'.format(
latest_version[0], sep.join(latest_version[2])), latest_version[0], sep.join(latest_version[2])),
) )
else: else:
@ -790,8 +789,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(script=script) self._edit(script=script)
@classmethod @classmethod
def clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None, def _clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None,
tags=None, parent=None, project=None, log=None, session=None): tags=None, parent=None, project=None, log=None, session=None):
""" """
Clone a task Clone a task
@ -847,7 +846,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
script=task.script script=task.script
) )
res = cls._send(session=session, log=log, req=req) res = cls._send(session=session, log=log, req=req)
return res.response.id cloned_task_id = res.response.id
if task.script and task.script.requirements:
cls._send(session=session, log=log, req=tasks.SetRequirementsRequest(
task=cloned_task_id, requirements=task.script.requirements))
return cloned_task_id
@classmethod @classmethod
def get_all(cls, session=None, log=None, **kwargs): def get_all(cls, session=None, log=None, **kwargs):

View File

@ -12,7 +12,7 @@ import psutil
import six import six
from .binding.joblib_bind import PatchedJoblib from .binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects from .backend_api.services import tasks, projects, queues
from .backend_api.session.session import Session from .backend_api.session.session import Session
from .backend_interface.model import Model as BackendModel from .backend_interface.model import Model as BackendModel
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
@ -338,137 +338,6 @@ class Task(_Task):
raise raise
return task return task
@classmethod
def _reset_current_task_obj(cls):
if not cls.__main_task:
return
task = cls.__main_task
cls.__main_task = None
if task._dev_worker:
task._dev_worker.unregister()
task._dev_worker = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
if not default_project_name:
# noinspection PyBroadException
try:
parts = result.script['repository'].split('/')
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
except Exception:
default_project_name = 'Untitled'
if not default_task_name:
# noinspection PyBroadException
try:
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
except Exception:
pass
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
default_task = None
else:
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
closed_old_task = False
default_task_id = None
in_dev_mode = not running_remotely()
if in_dev_mode:
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
else:
default_task_id = default_task.get('id') if default_task else None
if default_task_id:
try:
task = cls(
private=cls.__create_protection,
task_id=default_task_id,
log_to_backend=True,
)
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or (cls._development_tag not in task_tags)):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
default_task_id = None
task = None
else:
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_system_tags([cls._development_tag])
# clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
# clear the input model (and task model design/labels)
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
task.set_model_config(config_text='')
task.set_model_label_enumeration({})
task.set_artifacts([])
task._set_storage_uri(None)
except (Exception, ValueError):
# we failed reusing task, create a new one
default_task_id = None
# create a new task
if not default_task_id:
task = cls(
private=cls.__create_protection,
project_name=default_project_name,
task_name=default_task_name,
task_type=default_task_type,
log_to_backend=True,
)
if in_dev_mode:
# update this session, for later use
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
# mark the task as started
task.started()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task
if default_task_id:
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
else:
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
# update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start()
else:
task._update_repository()
# make sure everything is in sync
task.reload()
# make sure we see something in the UI
thread = threading.Thread(target=LoggerRoot.flush)
thread.daemon = True
thread.start()
return task
@classmethod @classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None): def get_task(cls, task_id=None, project_name=None, task_name=None):
""" """
@ -509,6 +378,86 @@ class Task(_Task):
return ReadOnlyDict() return ReadOnlyDict()
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts]) return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
@classmethod
def clone(cls, source_task=None, name=None, comment=None, parent=None, project=None):
"""
Clone a task object, create a copy a task.
:param source_task: Source Task object (or ID) to be cloned
:type source_task: Task/str
:param name: Optional, New for the new task
:type name: str
:param comment: Optional, comment for the new task
:type comment: str
:param parent: Optional parent Task ID of the new task.
:type parent: str
:param project: Optional project ID of the new task.
If None, the new task will inherit the cloned task's project.
:type project: str
:return: a new cloned Task object
"""
assert isinstance(source_task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
task_id = source_task if isinstance(source_task, six.string_types) else source_task.id
cloned_task_id = cls._clone_task(cloned_task_id=task_id, name=name, comment=comment,
parent=parent, project=project)
cloned_task = cls.get_task(task_id=cloned_task_id)
return cloned_task
@classmethod
def enqueue(cls, task, queue_name=None, queue_id=None):
"""
Enqueue (send) a task for execution, by adding it to an execution queue
:param task: Task object (or Task ID) to be enqueued, None if using Task object
:type task: Task / str
:param str queue_name: Name of the queue in which to enqueue the task.
:param str queue_id: ID of the queue in which to enqueue the task. If not provided use queue_name.
:return: enqueue response
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
if not queue_id:
req = queues.GetAllRequest(name=queue_name, only_fields=["id"])
res = cls._send(session=session, req=req)
if not res.response.queues:
raise ValueError('Could not find queue named "{}"'.format(queue_name))
queue_id = res.response.queues[0].id
if len(res.response.queues) > 1:
LoggerRoot.get_base_logger().info("Multiple queues with name={}, selecting queue id={}".format(
queue_name, queue_id))
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
res = cls._send(session=session, req=req)
resp = res.response
return resp
@classmethod
def dequeue(cls, task):
"""
Dequeue (remove) task from execution queue.
:param task: Task object (or Task ID) to be enqueued, None if using Task object
:type task: Task / str
:return: Dequeue response
"""
assert isinstance(task, (six.string_types, Task))
if not Session.check_min_api_version('2.4'):
raise ValueError("Trains-server does not support DevOps features, upgrade trains-server to 0.12.0 or above")
task_id = task if isinstance(task, six.string_types) else task.id
session = cls._get_default_session()
req = tasks.DequeueRequest(task=task_id)
res = cls._send(session=session, req=req)
resp = res.response
return resp
def set_comment(self, comment): def set_comment(self, comment):
""" """
Set a comment text to the task. Set a comment text to the task.
@ -812,6 +761,137 @@ class Task(_Task):
if secret: if secret:
Session.default_secret = secret Session.default_secret = secret
@classmethod
def _reset_current_task_obj(cls):
if not cls.__main_task:
return
task = cls.__main_task
cls.__main_task = None
if task._dev_worker:
task._dev_worker.unregister()
task._dev_worker = None
@classmethod
def _create_dev_task(cls, default_project_name, default_task_name, default_task_type, reuse_last_task_id):
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
result, _ = ScriptInfo.get(create_requirements=False, check_uncommitted=False)
if not default_project_name:
# noinspection PyBroadException
try:
parts = result.script['repository'].split('/')
default_project_name = (parts[-1] or parts[-2]).replace('.git', '') or 'Untitled'
except Exception:
default_project_name = 'Untitled'
if not default_task_name:
# noinspection PyBroadException
try:
default_task_name = os.path.splitext(os.path.basename(result.script['entry_point']))[0]
except Exception:
pass
# if we force no task reuse from os environment
if DEV_TASK_NO_REUSE.get() or not reuse_last_task_id:
default_task = None
else:
# if we have a previous session to use, get the task id from it
default_task = cls.__get_last_used_task_id(
default_project_name,
default_task_name,
default_task_type.value,
)
closed_old_task = False
default_task_id = None
in_dev_mode = not running_remotely()
if in_dev_mode:
if isinstance(reuse_last_task_id, str) and reuse_last_task_id:
default_task_id = reuse_last_task_id
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
default_task_id = None
else:
default_task_id = default_task.get('id') if default_task else None
if default_task_id:
try:
task = cls(
private=cls.__create_protection,
task_id=default_task_id,
log_to_backend=True,
)
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or (cls._development_tag not in task_tags)):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
default_task_id = None
task = None
else:
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_system_tags([cls._development_tag])
# clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset
task.set_comment(make_message('Auto-generated at %(time)s by %(user)s@%(host)s'))
# clear the input model (and task model design/labels)
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
task.set_model_config(config_text='')
task.set_model_label_enumeration({})
task.set_artifacts([])
task._set_storage_uri(None)
except (Exception, ValueError):
# we failed reusing task, create a new one
default_task_id = None
# create a new task
if not default_task_id:
task = cls(
private=cls.__create_protection,
project_name=default_project_name,
task_name=default_task_name,
task_type=default_task_type,
log_to_backend=True,
)
if in_dev_mode:
# update this session, for later use
cls.__update_last_used_task_id(default_project_name, default_task_name, default_task_type.value, task.id)
# mark the task as started
task.started()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task
if default_task_id:
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
else:
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
# update current repository and put warning into logs
if in_dev_mode and cls.__detect_repo_async:
task._detect_repo_async_thread = threading.Thread(target=task._update_repository)
task._detect_repo_async_thread.daemon = True
task._detect_repo_async_thread.start()
else:
task._update_repository()
# make sure everything is in sync
task.reload()
# make sure we see something in the UI
thread = threading.Thread(target=LoggerRoot.flush)
thread.daemon = True
thread.start()
return task
def _get_logger(self, flush_period=NotSet): def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger # type: (Optional[float]) -> Logger
""" """