Add new API backward compatibility support

This commit is contained in:
allegroai 2019-07-06 22:57:21 +03:00
parent d38f0ec14d
commit b5219d98de
4 changed files with 46 additions and 24 deletions

View File

@ -1,2 +1,2 @@
""" High-level abstractions for backend API """ """ High-level abstractions for backend API """
from .task import Task, TaskStatusEnum, TaskEntry from .task import Task

View File

@ -28,7 +28,8 @@ class _StorageUriMixin(object):
self._upload_storage_uri = value.rstrip('/') if value else None self._upload_storage_uri = value.rstrip('/') if value else None
class DummyModel(models.Model, _StorageUriMixin): def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
class DummyModel(models.Model, _StorageUriMixin):
def __init__(self, upload_storage_uri=None, *args, **kwargs): def __init__(self, upload_storage_uri=None, *args, **kwargs):
super(DummyModel, self).__init__(*args, **kwargs) super(DummyModel, self).__init__(*args, **kwargs)
self.upload_storage_uri = upload_storage_uri self.upload_storage_uri = upload_storage_uri
@ -36,6 +37,7 @@ class DummyModel(models.Model, _StorageUriMixin):
def update(self, **kwargs): def update(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)
return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs)
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

View File

@ -1 +1 @@
from .task import Task, TaskEntry, TaskStatusEnum from .task import Task

View File

@ -2,6 +2,7 @@
import collections import collections
import itertools import itertools
import logging import logging
from enum import Enum
from threading import RLock, Thread from threading import RLock, Thread
from copy import copy from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse from six.moves.urllib.parse import urlparse, urlunparse
@ -32,12 +33,6 @@ from .log import TaskHandler
from .repo import ScriptInfo from .repo import ScriptInfo
from ...config import config from ...config import config
TaskStatusEnum = tasks.TaskStatusEnum
class TaskEntry(tasks.CreateRequest):
pass
class Task(IdObjectBase, AccessMixin, SetupUploadMixin): class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Task manager providing task object access and management. Includes read/write access to task-associated """ Task manager providing task object access and management. Includes read/write access to task-associated
@ -46,8 +41,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__' _anonymous_dataview_id = '__anonymous__'
class TaskTypes(Enum):
def __str__(self):
return str(self.value)
training = 'training'
testing = 'testing'
def __init__(self, session=None, task_id=None, log=None, project_name=None, def __init__(self, session=None, task_id=None, log=None, project_name=None,
task_name=None, task_type=tasks.TaskTypeEnum.training, log_to_backend=True, task_name=None, task_type=TaskTypes.training, log_to_backend=True,
raise_on_validation_errors=True, force_create=False): raise_on_validation_errors=True, force_create=False):
""" """
Create a new task instance. Create a new task instance.
@ -65,7 +66,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:type project_name: str :type project_name: str
:param task_name: Optional task name, used only if a new task is created. :param task_name: Optional task name, used only if a new task is created.
:type project_name: str :type project_name: str
:param task_type: Optional task type, used only if a new task is created. Default is custom task. :param task_type: Optional task type, used only if a new task is created. Default is training task.
:type project_name: str (see tasks.TaskTypeEnum) :type project_name: str (see tasks.TaskTypeEnum)
:param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API. :param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API.
This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND. This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND.
@ -203,9 +204,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# overwrite it before we have a chance to call edit) # overwrite it before we have a chance to call edit)
self._edit(script=result.script) self._edit(script=result.script)
self.reload() self.reload()
if result.script.get('requirements'):
self._update_requirements(result.script.get('requirements'))
check_package_update_thread.join() check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training): def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s') created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
project_id = None project_id = None
@ -216,7 +219,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
req = tasks.CreateRequest( req = tasks.CreateRequest(
name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'), name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
type=task_type, type=tasks.TaskTypeEnum(task_type.value),
comment=created_msg, comment=created_msg,
project=project_id, project=project_id,
input={'view': {}}, input={'view': {}},
@ -378,6 +381,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Signal that this task has stopped """ """ Signal that this task has stopped """
return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors) return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors)
def completed(self, ignore_errors=True):
""" Signal that this task has been completed """
if hasattr(tasks, 'CompletedRequest'):
return self.send(tasks.CompletedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
return self.send(tasks.StoppedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None): def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None):
""" Signal that this task has stopped """ """ Signal that this task has stopped """
return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message), return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message),
@ -453,7 +462,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return uri return uri
def _conditionally_start_task(self): def _conditionally_start_task(self):
if self.status == TaskStatusEnum.created: if self.status == tasks.TaskStatusEnum.created:
self.started() self.started()
@property @property
@ -689,12 +698,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs): def _edit(self, **kwargs):
with self._edit_lock: with self._edit_lock:
# Since we ae using forced update, make sure he task status is valid # Since we ae using forced update, make sure he task status is valid
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)): if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created,
tasks.TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress') raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res return res
def _update_requirements(self, requirements):
if not isinstance(requirements, dict):
requirements = {'pip': requirements}
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
def _update_script(self, script):
self.data.script = script
self._edit(script=script)
@classmethod @classmethod
def create_new_task(cls, session, task_entry, log=None): def create_new_task(cls, session, task_entry, log=None):
""" """
@ -702,15 +722,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param session: Session object used for sending requests to the API :param session: Session object used for sending requests to the API
:type session: Session :type session: Session
:param task_entry: A task entry instance :param task_entry: A task entry instance
:type task_entry: TaskEntry :type task_entry: tasks.CreateRequest
:param log: Optional log :param log: Optional log
:type log: logging.Logger :type log: logging.Logger
:return: A new Task instance :return: A new Task instance
""" """
if isinstance(task_entry, dict): if isinstance(task_entry, dict):
task_entry = TaskEntry(**task_entry) task_entry = tasks.CreateRequest(**task_entry)
assert isinstance(task_entry, TaskEntry) assert isinstance(task_entry, tasks.CreateRequest)
res = cls._send(session=session, req=task_entry, log=log) res = cls._send(session=session, req=task_entry, log=log)
return cls(session, task_id=res.response.id) return cls(session, task_id=res.response.id)