Add new API backward compatibility support

This commit is contained in:
allegroai 2019-07-06 22:57:21 +03:00
parent d38f0ec14d
commit b5219d98de
4 changed files with 46 additions and 24 deletions

View File

@ -1,2 +1,2 @@
""" High-level abstractions for backend API """
from .task import Task, TaskStatusEnum, TaskEntry
from .task import Task

View File

@ -28,14 +28,16 @@ class _StorageUriMixin(object):
self._upload_storage_uri = value.rstrip('/') if value else None
class DummyModel(models.Model, _StorageUriMixin):
def __init__(self, upload_storage_uri=None, *args, **kwargs):
super(DummyModel, self).__init__(*args, **kwargs)
self.upload_storage_uri = upload_storage_uri
def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
class DummyModel(models.Model, _StorageUriMixin):
def __init__(self, upload_storage_uri=None, *args, **kwargs):
super(DummyModel, self).__init__(*args, **kwargs)
self.upload_storage_uri = upload_storage_uri
def update(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def update(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs)
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

View File

@ -1 +1 @@
from .task import Task, TaskEntry, TaskStatusEnum
from .task import Task

View File

@ -2,6 +2,7 @@
import collections
import itertools
import logging
from enum import Enum
from threading import RLock, Thread
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse
@ -32,12 +33,6 @@ from .log import TaskHandler
from .repo import ScriptInfo
from ...config import config
TaskStatusEnum = tasks.TaskStatusEnum
class TaskEntry(tasks.CreateRequest):
pass
class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Task manager providing task object access and management. Includes read/write access to task-associated
@ -46,8 +41,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
_anonymous_dataview_id = '__anonymous__'
class TaskTypes(Enum):
def __str__(self):
return str(self.value)
training = 'training'
testing = 'testing'
def __init__(self, session=None, task_id=None, log=None, project_name=None,
task_name=None, task_type=tasks.TaskTypeEnum.training, log_to_backend=True,
task_name=None, task_type=TaskTypes.training, log_to_backend=True,
raise_on_validation_errors=True, force_create=False):
"""
Create a new task instance.
@ -65,7 +66,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:type project_name: str
:param task_name: Optional task name, used only if a new task is created.
:type project_name: str
:param task_type: Optional task type, used only if a new task is created. Default is custom task.
:param task_type: Optional task type, used only if a new task is created. Default is training task.
:type project_name: str (see tasks.TaskTypeEnum)
:param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API.
This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND.
@ -203,9 +204,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
self.reload()
if result.script.get('requirements'):
self._update_requirements(result.script.get('requirements'))
check_package_update_thread.join()
def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
project_id = None
@ -216,7 +219,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
req = tasks.CreateRequest(
name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
type=task_type,
type=tasks.TaskTypeEnum(task_type.value),
comment=created_msg,
project=project_id,
input={'view': {}},
@ -378,6 +381,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Signal that this task has stopped """
return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors)
def completed(self, ignore_errors=True):
""" Signal that this task has been completed """
if hasattr(tasks, 'CompletedRequest'):
return self.send(tasks.CompletedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
return self.send(tasks.StoppedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None):
""" Signal that this task has stopped """
return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message),
@ -453,7 +462,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return uri
def _conditionally_start_task(self):
if self.status == TaskStatusEnum.created:
if self.status == tasks.TaskStatusEnum.created:
self.started()
@property
@ -689,12 +698,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _edit(self, **kwargs):
with self._edit_lock:
# Since we ae using forced update, make sure he task status is valid
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created,
tasks.TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res
def _update_requirements(self, requirements):
if not isinstance(requirements, dict):
requirements = {'pip': requirements}
self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
def _update_script(self, script):
self.data.script = script
self._edit(script=script)
@classmethod
def create_new_task(cls, session, task_entry, log=None):
"""
@ -702,15 +722,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param session: Session object used for sending requests to the API
:type session: Session
:param task_entry: A task entry instance
:type task_entry: TaskEntry
:type task_entry: tasks.CreateRequest
:param log: Optional log
:type log: logging.Logger
:return: A new Task instance
"""
if isinstance(task_entry, dict):
task_entry = TaskEntry(**task_entry)
task_entry = tasks.CreateRequest(**task_entry)
assert isinstance(task_entry, TaskEntry)
assert isinstance(task_entry, tasks.CreateRequest)
res = cls._send(session=session, req=task_entry, log=log)
return cls(session, task_id=res.response.id)