From bca3a6e55655d328e04873413659f9a329eaba87 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 18:05:44 +0200 Subject: [PATCH] Set default task active duration to None Move endpoints to new API version Support tasks.move and models.move for moving tasks and models into projects Support new project name in tasks.clone Improve task active duration migration --- apiserver/apimodels/base.py | 6 + apiserver/apimodels/tasks.py | 1 + apiserver/bll/project/project_bll.py | 111 +++++++++++++++++- apiserver/bll/task/task_bll.py | 26 +++- apiserver/database/model/task/task.py | 2 +- apiserver/mongo/migrations/0.17.1.py | 23 +++- apiserver/schema/services/events.conf | 4 +- apiserver/schema/services/models.conf | 52 ++++++-- apiserver/schema/services/tasks.conf | 40 ++++++- apiserver/services/events.py | 4 +- apiserver/services/models.py | 24 +++- apiserver/services/projects.py | 29 ++--- apiserver/services/tasks.py | 30 ++++- .../automated/test_move_under_project.py | 45 +++++++ .../tests/automated/test_task_debug_images.py | 2 +- 15 files changed, 351 insertions(+), 48 deletions(-) create mode 100644 apiserver/tests/automated/test_move_under_project.py diff --git a/apiserver/apimodels/base.py b/apiserver/apimodels/base.py index 8d7099d..ca89dcd 100644 --- a/apiserver/apimodels/base.py +++ b/apiserver/apimodels/base.py @@ -67,3 +67,9 @@ class IdResponse(models.Base): class MakePublicRequest(models.Base): ids = ListField(items_types=str, validators=[Length(minimum_value=1)]) + + +class MoveRequest(models.Base): + ids = ListField([str], validators=Length(minimum_value=1)) + project = fields.StringField() + project_name = fields.StringField() diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 7658018..27ca2a7 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -115,6 +115,7 @@ class CloneRequest(TaskRequest): new_configuration = DictField() execution_overrides = DictField() validate_references = BoolField(default=False) + new_project_name = StringField() class AddOrUpdateArtifactsRequest(TaskRequest): diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 120ed01..a5eb712 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -1,9 +1,13 @@ -from typing import Sequence, Optional +from datetime import datetime +from typing import Sequence, Optional, Type -from mongoengine import Q +from mongoengine import Q, Document +from apiserver import database +from apiserver.apierrors import errors from apiserver.config_repo import config from apiserver.database.model.model import Model +from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task from apiserver.timing_context import TimingContext @@ -31,3 +35,106 @@ class ProjectBLL: res |= set(cls_.objects(query).distinct(field="user")) return res + + @classmethod + def create( + cls, + user: str, + company: str, + name: str, + description: str, + tags: Sequence[str] = None, + system_tags: Sequence[str] = None, + default_output_destination: str = None, + ) -> str: + """ + Create a new project. + Returns project ID + """ + now = datetime.utcnow() + project = Project( + id=database.utils.id(), + user=user, + company=company, + name=name, + description=description, + tags=tags, + system_tags=system_tags, + default_output_destination=default_output_destination, + created=now, + last_update=now, + ) + project.save() + return project.id + + @classmethod + def find_or_create( + cls, + user: str, + company: str, + project_name: str, + description: str, + project_id: str = None, + tags: Sequence[str] = None, + system_tags: Sequence[str] = None, + default_output_destination: str = None, + ) -> str: + """ + Find a project named `project_name` or create a new one. + Returns project ID + """ + if not project_id and not project_name: + raise ValueError("project id or name required") + + if project_id: + project = Project.objects(company=company, id=project_id).only("id").first() + if not project: + raise errors.bad_request.InvalidProjectId(id=project_id) + return project_id + + project = Project.objects(company=company, name=project_name).only("id").first() + if project: + return project.id + + return cls.create( + user=user, + company=company, + name=project_name, + description=description, + tags=tags, + system_tags=system_tags, + default_output_destination=default_output_destination, + ) + + @classmethod + def move_under_project( + cls, + entity_cls: Type[Document], + user: str, + company: str, + ids: Sequence[str], + project: str = None, + project_name: str = None, + ): + """ + Move a batch of entities to `project` or a project named `project_name` (create if does not exist) + """ + with TimingContext("mongo", "move_under_project"): + project = cls.find_or_create( + user=user, + company=company, + project_id=project, + project_name=project_name, + description="Auto-generated during move", + ) + extra = ( + {"set__last_update": datetime.utcnow()} + if hasattr(entity_cls, "last_update") + else {} + ) + + entity_cls.objects(company=company, id__in=ids).update( + set__project=project, **extra + ) + + return project diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index b5afa81..93ff713 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -9,8 +9,9 @@ from six import string_types import apiserver.database.utils as dbutils from apiserver.apierrors import errors -from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.queue import QueueBLL +from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.model import Model @@ -36,9 +37,11 @@ from .utils import ChangeStatusRequest, validate_status_change log = config.logger(__file__) org_bll = OrgBLL() +queue_bll = QueueBLL() +project_bll = ProjectBLL() -class TaskBLL(object): +class TaskBLL: def __init__(self, events_es=None): self.events_es = ( events_es if events_es is not None else es_factory.connect("events") @@ -162,9 +165,9 @@ class TaskBLL(object): @classmethod def clone_task( cls, - company_id, - user_id, - task_id, + company_id: str, + user_id: str, + task_id: str, name: Optional[str] = None, comment: Optional[str] = None, parent: Optional[str] = None, @@ -175,6 +178,7 @@ class TaskBLL(object): configuration: Optional[dict] = None, execution_overrides: Optional[dict] = None, validate_references: bool = False, + new_project_name: str = None, ) -> Task: params_dict = { field: value @@ -210,6 +214,16 @@ class TaskBLL(object): for k, a in artifacts.items() if a.get("mode") != ArtifactModes.output } + + if not project and new_project_name: + # Use a project with the provided name, or create a new project + project = project_bll.find_or_create( + project_name=new_project_name, + user=user_id, + company=company_id, + description="Auto-generated while cloning", + ) + now = datetime.utcnow() with translate_errors_context(): @@ -675,7 +689,7 @@ class TaskBLL(object): ) return { - "removed": QueueBLL().remove_task( + "removed": queue_bll.remove_task( company_id=company_id, queue_id=task.execution.queue, task_id=task.id ) } diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 9c325d2..fb5c9d3 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -204,7 +204,7 @@ class Task(AttributedDocument): started = DateTimeField() completed = DateTimeField() published = DateTimeField() - active_duration = IntField(default=0) + active_duration = IntField(default=None) parent = StringField(reference_field="Task") project = StringField(reference_field=Project, user_set_allowed=True) output: Output = EmbeddedDocumentField(Output, default=Output) diff --git a/apiserver/mongo/migrations/0.17.1.py b/apiserver/mongo/migrations/0.17.1.py index 09edea5..6d936a8 100644 --- a/apiserver/mongo/migrations/0.17.1.py +++ b/apiserver/mongo/migrations/0.17.1.py @@ -1,3 +1,6 @@ +from datetime import datetime +from typing import Optional + from pymongo.database import Database @@ -6,17 +9,29 @@ def _add_active_duration(db: Database): query = {active_duration: {"$eq": None}} collection = db["task"] for doc in collection.find( - filter=query, projection=[active_duration, "started", "last_update"] + filter=query, projection=[active_duration, "status", "started", "completed"] ): started = doc.get("started") - last_update = doc.get("last_update") - if started and last_update and doc.get(active_duration) is None: + completed = doc.get("completed") + running = doc.get("status") == "running" + if started and doc.get(active_duration) is None: collection.update_one( {"_id": doc["_id"]}, - {"$set": {active_duration: (last_update - started).total_seconds()}}, + {"$set": {active_duration: _get_active_duration(completed, running, started)}}, ) +def _get_active_duration( + completed: datetime, running: bool, started: datetime +) -> Optional[float]: + if running: + return (datetime.utcnow() - started).total_seconds() + elif completed: + return (completed - started).total_seconds() + else: + return None + + def migrate_backend(db: Database): """ Add active_duration field to tasks diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 2f9faf4..8f2daca 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -417,7 +417,7 @@ } } get_debug_image_sample { - "2.11": { + "2.12": { description: "Return the debug image per metric and variant for the provided iteration" request { type: object @@ -449,7 +449,7 @@ } } next_debug_image_sample { - "2.11": { + "2.12": { description: "Get the image for the next variant for the same iteration or for the next iteration" request { type: object diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 1de82d9..6a6c1fd 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -55,13 +55,13 @@ _definitions { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -306,13 +306,13 @@ update_for_task { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } override_model_id { @@ -372,13 +372,13 @@ create { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -460,13 +460,13 @@ edit { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -542,13 +542,13 @@ update { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } ready { @@ -747,4 +747,34 @@ make_private { } } } -} \ No newline at end of file +} + +move { + "2.12" { + description: "Move models to a project" + request { + type: object + required: [ids] + properties { + ids { + description: "Models to move" + type: array + items { type: string } + } + project { + description: "Target project ID. If not provided, `project_name` must be provided." + type: string + } + project_name { + description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided." + type: string + } + } + } + response { + type: object + additionalProperties: true + } + } +} + diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index fe392fb..208ae0f 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -772,6 +772,16 @@ clone { } } } + "2.12": ${clone."2.5"} { + request { + properties { + new_project_name { + description: "Clone task to a new project by this name (only if `new_task_project` is not provided). If a project by this name already exists, task will be cloned to existing project." + type: string + } + } + } + } } create { "2.1" { @@ -1992,4 +2002,32 @@ delete_configuration { } } } -} \ No newline at end of file +} +move { + "2.12" { + description: "Move tasks to a project" + request { + type: object + required: [ids] + properties { + ids { + description: "Tasks to move" + type: array + items { type: string } + } + project { + description: "Target project ID. If not provided, `project_name` must be provided." + type: string + } + project_name { + description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided." + type: string + } + } + } + response { + type: object + additionalProperties: true + } + } +} diff --git a/apiserver/services/events.py b/apiserver/services/events.py index da08ac9..a6a565e 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -630,7 +630,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest): @endpoint( "events.get_debug_image_sample", - min_version="2.11", + min_version="2.12", request_data_model=GetDebugImageSampleRequest, ) def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest): @@ -650,7 +650,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest @endpoint( "events.next_debug_image_sample", - min_version="2.11", + min_version="2.12", request_data_model=NextDebugImageSampleRequest, ) def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest): diff --git a/apiserver/services/models.py b/apiserver/services/models.py index aa2535f..287c004 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -6,7 +6,7 @@ from mongoengine import Q, EmbeddedDocument from apiserver import database from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidModelId -from apiserver.apimodels.base import UpdateResponse, MakePublicRequest +from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest from apiserver.apimodels.models import ( CreateModelRequest, CreateModelResponse, @@ -17,6 +17,7 @@ from apiserver.apimodels.models import ( ) from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context @@ -36,6 +37,7 @@ from apiserver.timing_context import TimingContext log = config.logger(__file__) org_bll = OrgBLL() model_bll = ModelBLL() +project_bll = ProjectBLL() @endpoint("models.get_by_id", required_fields=["model"]) @@ -498,3 +500,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest): call.result.data = Model.set_public( company_id, request.ids, invalid_cls=InvalidModelId, enabled=False ) + + +@endpoint("models.move", request_data_model=MoveRequest) +def move(call: APICall, company_id: str, request: MoveRequest): + if not (request.project or request.project_name): + raise errors.bad_request.MissingRequiredFields( + "project or project_name is required" + ) + + with translate_errors_context(): + return { + "project_id": project_bll.move_under_project( + entity_cls=Model, + user=call.identity.user, + company=company_id, + ids=request.ids, + project=request.project, + project_name=request.project_name, + ) + } diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 3511369..4e5f8c7 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -6,16 +6,16 @@ from operator import itemgetter import dpath from mongoengine import Q -from apiserver import database from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidProjectId -from apiserver.apimodels.base import UpdateResponse, MakePublicRequest +from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.projects import ( GetHyperParamReq, ProjectReq, ProjectTagsRequest, ) from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility @@ -280,26 +280,23 @@ def get_all(call: APICall): call.result.data = {"projects": projects} -@endpoint("projects.create", required_fields=["name", "description"]) -def create(call): - assert isinstance(call, APICall) +@endpoint( + "projects.create", + required_fields=["name", "description"], + response_data_model=IdResponse, +) +def create(call: APICall): identity = call.identity with translate_errors_context(): fields = parse_from_call(call.data, create_fields, Project.get_fields()) conform_tag_fields(call, fields, validate=True) - now = datetime.utcnow() - project = Project( - id=database.utils.id(), - user=identity.user, - company=identity.company, - created=now, - last_update=now, - **fields + + return IdResponse( + id=ProjectBLL.create( + user=identity.user, company=identity.company, **fields, + ) ) - with TimingContext("mongo", "projects_save"): - project.save() - call.result.data = {"id": project.id} @endpoint( diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 77d9eb7..c10d217 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -12,7 +12,12 @@ from pymongo import UpdateOne from apiserver.apierrors import errors, APIError from apiserver.apierrors.errors.bad_request import InvalidTaskId -from apiserver.apimodels.base import UpdateResponse, IdResponse, MakePublicRequest +from apiserver.apimodels.base import ( + UpdateResponse, + IdResponse, + MakePublicRequest, + MoveRequest, +) from apiserver.apimodels.tasks import ( StartedResponse, ResetResponse, @@ -44,6 +49,7 @@ from apiserver.apimodels.tasks import ( ) from apiserver.bll.event import EventBLL from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.queue import QueueBLL from apiserver.bll.task import ( TaskBLL, @@ -95,6 +101,7 @@ task_bll = TaskBLL() event_bll = EventBLL() queue_bll = QueueBLL() org_bll = OrgBLL() +project_bll = ProjectBLL() NonResponsiveTasksWatchdog.start() @@ -427,6 +434,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest): configuration=request.new_configuration, execution_overrides=request.execution_overrides, validate_references=request.validate_references, + new_project_name=request.new_project_name, ) call.result.data_model = IdResponse(id=task.id) @@ -1188,3 +1196,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest): call.result.data = Task.set_public( company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False ) + + +@endpoint("tasks.move", request_data_model=MoveRequest) +def move(call: APICall, company_id: str, request: MoveRequest): + if not (request.project or request.project_name): + raise errors.bad_request.MissingRequiredFields( + "project or project_name is required" + ) + + with translate_errors_context(): + return { + "project_id": project_bll.move_under_project( + entity_cls=Task, + user=call.identity.user, + company=company_id, + ids=request.ids, + project=request.project, + project_name=request.project_name, + ) + } diff --git a/apiserver/tests/automated/test_move_under_project.py b/apiserver/tests/automated/test_move_under_project.py new file mode 100644 index 0000000..456412b --- /dev/null +++ b/apiserver/tests/automated/test_move_under_project.py @@ -0,0 +1,45 @@ +from apiserver.tests.automated import TestService + + +class TestMoveUnderProject(TestService): + entity_name = "test move" + + def setUp(self, version="2.12"): + super().setUp(version=version) + + def test_move(self): + # task move into the new project + task = self._temp_task() + project = self.api.tasks.move(ids=[task], project_name=self.entity_name).project_id + tasks = self.api.tasks.get_all_ex(id=[task]).tasks + self.assertEqual(project, tasks[0].project.id) + projects = self.api.projects.get_all_ex(id=[project]).projects + self.assertEqual(self.entity_name, projects[0].name) + + # task clone + p2_name = "project_for_clone" + task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id + tasks = self.api.tasks.get_all_ex(id=[task2]).tasks + project2 = tasks[0].project.id + self.assertTrue(project2) + projects = self.api.projects.get_all_ex(id=[project2]).projects + self.assertEqual(p2_name, projects[0].name) + self.api.projects.delete(project=project2, force=True) + + # model move into existing project referenced by name + model = self._temp_model() + self.api.models.move(ids=[model], project_name=self.entity_name) + models = self.api.models.get_all_ex(id=[model]).models + self.assertEqual(project, models[0].project.id) + + self.api.projects.delete(project=project, force=True) + + def _temp_task(self): + task_input = dict( + name=self.entity_name, type="training", input=dict(mapping={}, view=dict(entries=[])), + ) + return self.create_temp("tasks", **task_input) + + def _temp_model(self): + model_input = dict(name=self.entity_name, uri="file:///a/b", labels={}) + return self.create_temp("models", **model_input) diff --git a/apiserver/tests/automated/test_task_debug_images.py b/apiserver/tests/automated/test_task_debug_images.py index eebc97c..a582078 100644 --- a/apiserver/tests/automated/test_task_debug_images.py +++ b/apiserver/tests/automated/test_task_debug_images.py @@ -7,7 +7,7 @@ from apiserver.tests.automated import TestService class TestTaskDebugImages(TestService): - def setUp(self, version="2.11"): + def setUp(self, version="2.12"): super().setUp(version=version) def _temp_task(self, name="test task events"):