From b5219d98de9506efbdf1c9b0e8c49f1149912a48 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 6 Jul 2019 22:57:21 +0300 Subject: [PATCH] Add new API backward compatibility support --- trains/backend_interface/__init__.py | 2 +- trains/backend_interface/model.py | 16 ++++---- trains/backend_interface/task/__init__.py | 2 +- trains/backend_interface/task/task.py | 50 ++++++++++++++++------- 4 files changed, 46 insertions(+), 24 deletions(-) diff --git a/trains/backend_interface/__init__.py b/trains/backend_interface/__init__.py index 216ed683..e2adb1ba 100644 --- a/trains/backend_interface/__init__.py +++ b/trains/backend_interface/__init__.py @@ -1,2 +1,2 @@ """ High-level abstractions for backend API """ -from .task import Task, TaskStatusEnum, TaskEntry +from .task import Task diff --git a/trains/backend_interface/model.py b/trains/backend_interface/model.py index c6ffb37e..19aabe79 100644 --- a/trains/backend_interface/model.py +++ b/trains/backend_interface/model.py @@ -28,14 +28,16 @@ class _StorageUriMixin(object): self._upload_storage_uri = value.rstrip('/') if value else None -class DummyModel(models.Model, _StorageUriMixin): - def __init__(self, upload_storage_uri=None, *args, **kwargs): - super(DummyModel, self).__init__(*args, **kwargs) - self.upload_storage_uri = upload_storage_uri +def create_dummy_model(upload_storage_uri=None, *args, **kwargs): + class DummyModel(models.Model, _StorageUriMixin): + def __init__(self, upload_storage_uri=None, *args, **kwargs): + super(DummyModel, self).__init__(*args, **kwargs) + self.upload_storage_uri = upload_storage_uri - def update(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) + def update(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + return DummyModel(upload_storage_uri=upload_storage_uri, *args, **kwargs) class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): diff --git a/trains/backend_interface/task/__init__.py b/trains/backend_interface/task/__init__.py index c4ee113a..280e6a2a 100644 --- a/trains/backend_interface/task/__init__.py +++ b/trains/backend_interface/task/__init__.py @@ -1 +1 @@ -from .task import Task, TaskEntry, TaskStatusEnum +from .task import Task diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 9f196f6b..d5f28a6c 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -2,6 +2,7 @@ import collections import itertools import logging +from enum import Enum from threading import RLock, Thread from copy import copy from six.moves.urllib.parse import urlparse, urlunparse @@ -32,12 +33,6 @@ from .log import TaskHandler from .repo import ScriptInfo from ...config import config -TaskStatusEnum = tasks.TaskStatusEnum - - -class TaskEntry(tasks.CreateRequest): - pass - class Task(IdObjectBase, AccessMixin, SetupUploadMixin): """ Task manager providing task object access and management. Includes read/write access to task-associated @@ -46,8 +41,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): _anonymous_dataview_id = '__anonymous__' + class TaskTypes(Enum): + def __str__(self): + return str(self.value) + training = 'training' + testing = 'testing' + def __init__(self, session=None, task_id=None, log=None, project_name=None, - task_name=None, task_type=tasks.TaskTypeEnum.training, log_to_backend=True, + task_name=None, task_type=TaskTypes.training, log_to_backend=True, raise_on_validation_errors=True, force_create=False): """ Create a new task instance. @@ -65,7 +66,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): :type project_name: str :param task_name: Optional task name, used only if a new task is created. :type project_name: str - :param task_type: Optional task type, used only if a new task is created. Default is custom task. + :param task_type: Optional task type, used only if a new task is created. Default is training task. :type project_name: str (see tasks.TaskTypeEnum) :param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API. This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND. @@ -203,9 +204,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): # overwrite it before we have a chance to call edit) self._edit(script=result.script) self.reload() + if result.script.get('requirements'): + self._update_requirements(result.script.get('requirements')) check_package_update_thread.join() - def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training): + def _auto_generate(self, project_name=None, task_name=None, task_type=TaskTypes.training): created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s') project_id = None @@ -216,7 +219,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): req = tasks.CreateRequest( name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'), - type=task_type, + type=tasks.TaskTypeEnum(task_type.value), comment=created_msg, project=project_id, input={'view': {}}, @@ -378,6 +381,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): """ Signal that this task has stopped """ return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors) + def completed(self, ignore_errors=True): + """ Signal that this task has been completed """ + if hasattr(tasks, 'CompletedRequest'): + return self.send(tasks.CompletedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors) + return self.send(tasks.StoppedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors) + def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None): """ Signal that this task has stopped """ return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message), @@ -453,7 +462,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): return uri def _conditionally_start_task(self): - if self.status == TaskStatusEnum.created: + if self.status == tasks.TaskStatusEnum.created: self.started() @property @@ -689,12 +698,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): def _edit(self, **kwargs): with self._edit_lock: # Since we ae using forced update, make sure he task status is valid - if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)): + if not self._data or (self.data.status not in (tasks.TaskStatusEnum.created, + tasks.TaskStatusEnum.in_progress)): raise ValueError('Task object can only be updated if created or in_progress') res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False) return res + def _update_requirements(self, requirements): + if not isinstance(requirements, dict): + requirements = {'pip': requirements} + self.data.script.requirements = requirements + self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) + + def _update_script(self, script): + self.data.script = script + self._edit(script=script) + @classmethod def create_new_task(cls, session, task_entry, log=None): """ @@ -702,15 +722,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): :param session: Session object used for sending requests to the API :type session: Session :param task_entry: A task entry instance - :type task_entry: TaskEntry + :type task_entry: tasks.CreateRequest :param log: Optional log :type log: logging.Logger :return: A new Task instance """ if isinstance(task_entry, dict): - task_entry = TaskEntry(**task_entry) + task_entry = tasks.CreateRequest(**task_entry) - assert isinstance(task_entry, TaskEntry) + assert isinstance(task_entry, tasks.CreateRequest) res = cls._send(session=session, req=task_entry, log=log) return cls(session, task_id=res.response.id)