diff --git a/apiserver/apimodels/base.py b/apiserver/apimodels/base.py index 8d7099d..ca89dcd 100644 --- a/apiserver/apimodels/base.py +++ b/apiserver/apimodels/base.py @@ -67,3 +67,9 @@ class IdResponse(models.Base): class MakePublicRequest(models.Base): ids = ListField(items_types=str, validators=[Length(minimum_value=1)]) + + +class MoveRequest(models.Base): + ids = ListField([str], validators=Length(minimum_value=1)) + project = fields.StringField() + project_name = fields.StringField() diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 7658018..27ca2a7 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -115,6 +115,7 @@ class CloneRequest(TaskRequest): new_configuration = DictField() execution_overrides = DictField() validate_references = BoolField(default=False) + new_project_name = StringField() class AddOrUpdateArtifactsRequest(TaskRequest): diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 120ed01..a5eb712 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -1,9 +1,13 @@ -from typing import Sequence, Optional +from datetime import datetime +from typing import Sequence, Optional, Type -from mongoengine import Q +from mongoengine import Q, Document +from apiserver import database +from apiserver.apierrors import errors from apiserver.config_repo import config from apiserver.database.model.model import Model +from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task from apiserver.timing_context import TimingContext @@ -31,3 +35,106 @@ class ProjectBLL: res |= set(cls_.objects(query).distinct(field="user")) return res + + @classmethod + def create( + cls, + user: str, + company: str, + name: str, + description: str, + tags: Sequence[str] = None, + system_tags: Sequence[str] = None, + default_output_destination: str = None, + ) -> str: + """ + Create a new project. + Returns project ID + """ + now = datetime.utcnow() + project = Project( + id=database.utils.id(), + user=user, + company=company, + name=name, + description=description, + tags=tags, + system_tags=system_tags, + default_output_destination=default_output_destination, + created=now, + last_update=now, + ) + project.save() + return project.id + + @classmethod + def find_or_create( + cls, + user: str, + company: str, + project_name: str, + description: str, + project_id: str = None, + tags: Sequence[str] = None, + system_tags: Sequence[str] = None, + default_output_destination: str = None, + ) -> str: + """ + Find a project named `project_name` or create a new one. + Returns project ID + """ + if not project_id and not project_name: + raise ValueError("project id or name required") + + if project_id: + project = Project.objects(company=company, id=project_id).only("id").first() + if not project: + raise errors.bad_request.InvalidProjectId(id=project_id) + return project_id + + project = Project.objects(company=company, name=project_name).only("id").first() + if project: + return project.id + + return cls.create( + user=user, + company=company, + name=project_name, + description=description, + tags=tags, + system_tags=system_tags, + default_output_destination=default_output_destination, + ) + + @classmethod + def move_under_project( + cls, + entity_cls: Type[Document], + user: str, + company: str, + ids: Sequence[str], + project: str = None, + project_name: str = None, + ): + """ + Move a batch of entities to `project` or a project named `project_name` (create if does not exist) + """ + with TimingContext("mongo", "move_under_project"): + project = cls.find_or_create( + user=user, + company=company, + project_id=project, + project_name=project_name, + description="Auto-generated during move", + ) + extra = ( + {"set__last_update": datetime.utcnow()} + if hasattr(entity_cls, "last_update") + else {} + ) + + entity_cls.objects(company=company, id__in=ids).update( + set__project=project, **extra + ) + + return project diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index b5afa81..93ff713 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -9,8 +9,9 @@ from six import string_types import apiserver.database.utils as dbutils from apiserver.apierrors import errors -from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.queue import QueueBLL +from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.model import Model @@ -36,9 +37,11 @@ from .utils import ChangeStatusRequest, validate_status_change log = config.logger(__file__) org_bll = OrgBLL() +queue_bll = QueueBLL() +project_bll = ProjectBLL() -class TaskBLL(object): +class TaskBLL: def __init__(self, events_es=None): self.events_es = ( events_es if events_es is not None else es_factory.connect("events") @@ -162,9 +165,9 @@ class TaskBLL(object): @classmethod def clone_task( cls, - company_id, - user_id, - task_id, + company_id: str, + user_id: str, + task_id: str, name: Optional[str] = None, comment: Optional[str] = None, parent: Optional[str] = None, @@ -175,6 +178,7 @@ class TaskBLL(object): configuration: Optional[dict] = None, execution_overrides: Optional[dict] = None, validate_references: bool = False, + new_project_name: str = None, ) -> Task: params_dict = { field: value @@ -210,6 +214,16 @@ class TaskBLL(object): for k, a in artifacts.items() if a.get("mode") != ArtifactModes.output } + + if not project and new_project_name: + # Use a project with the provided name, or create a new project + project = project_bll.find_or_create( + project_name=new_project_name, + user=user_id, + company=company_id, + description="Auto-generated while cloning", + ) + now = datetime.utcnow() with translate_errors_context(): @@ -675,7 +689,7 @@ class TaskBLL(object): ) return { - "removed": QueueBLL().remove_task( + "removed": queue_bll.remove_task( company_id=company_id, queue_id=task.execution.queue, task_id=task.id ) } diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 9c325d2..fb5c9d3 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -204,7 +204,7 @@ class Task(AttributedDocument): started = DateTimeField() completed = DateTimeField() published = DateTimeField() - active_duration = IntField(default=0) + active_duration = IntField(default=None) parent = StringField(reference_field="Task") project = StringField(reference_field=Project, user_set_allowed=True) output: Output = EmbeddedDocumentField(Output, default=Output) diff --git a/apiserver/mongo/migrations/0.17.1.py b/apiserver/mongo/migrations/0.17.1.py index 09edea5..6d936a8 100644 --- a/apiserver/mongo/migrations/0.17.1.py +++ b/apiserver/mongo/migrations/0.17.1.py @@ -1,3 +1,6 @@ +from datetime import datetime +from typing import Optional + from pymongo.database import Database @@ -6,17 +9,29 @@ def _add_active_duration(db: Database): query = {active_duration: {"$eq": None}} collection = db["task"] for doc in collection.find( - filter=query, projection=[active_duration, "started", "last_update"] + filter=query, projection=[active_duration, "status", "started", "completed"] ): started = doc.get("started") - last_update = doc.get("last_update") - if started and last_update and doc.get(active_duration) is None: + completed = doc.get("completed") + running = doc.get("status") == "running" + if started and doc.get(active_duration) is None: collection.update_one( {"_id": doc["_id"]}, - {"$set": {active_duration: (last_update - started).total_seconds()}}, + {"$set": {active_duration: _get_active_duration(completed, running, started)}}, ) +def _get_active_duration( + completed: datetime, running: bool, started: datetime +) -> Optional[float]: + if running: + return (datetime.utcnow() - started).total_seconds() + elif completed: + return (completed - started).total_seconds() + else: + return None + + def migrate_backend(db: Database): """ Add active_duration field to tasks diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 2f9faf4..8f2daca 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -417,7 +417,7 @@ } } get_debug_image_sample { - "2.11": { + "2.12": { description: "Return the debug image per metric and variant for the provided iteration" request { type: object @@ -449,7 +449,7 @@ } } next_debug_image_sample { - "2.11": { + "2.12": { description: "Get the image for the next variant for the same iteration or for the next iteration" request { type: object diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 1de82d9..6a6c1fd 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -55,13 +55,13 @@ _definitions { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -306,13 +306,13 @@ update_for_task { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } override_model_id { @@ -372,13 +372,13 @@ create { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -460,13 +460,13 @@ edit { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } framework { @@ -542,13 +542,13 @@ update { type: string } tags { + description: "User-defined tags list" type: array - description: "User-defined tags" items { type: string } } system_tags { + description: "System tags list. This field is reserved for system use, please don't use it." type: array - description: "System tags. This field is reserved for system use, please don't use it." items {type: string} } ready { @@ -747,4 +747,34 @@ make_private { } } } -} \ No newline at end of file +} + +move { + "2.12" { + description: "Move models to a project" + request { + type: object + required: [ids] + properties { + ids { + description: "Models to move" + type: array + items { type: string } + } + project { + description: "Target project ID. If not provided, `project_name` must be provided." + type: string + } + project_name { + description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided." + type: string + } + } + } + response { + type: object + additionalProperties: true + } + } +} + diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index fe392fb..208ae0f 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -772,6 +772,16 @@ clone { } } } + "2.12": ${clone."2.5"} { + request { + properties { + new_project_name { + description: "Clone task to a new project by this name (only if `new_task_project` is not provided). If a project by this name already exists, task will be cloned to existing project." + type: string + } + } + } + } } create { "2.1" { @@ -1992,4 +2002,32 @@ delete_configuration { } } } -} \ No newline at end of file +} +move { + "2.12" { + description: "Move tasks to a project" + request { + type: object + required: [ids] + properties { + ids { + description: "Tasks to move" + type: array + items { type: string } + } + project { + description: "Target project ID. If not provided, `project_name` must be provided." + type: string + } + project_name { + description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided." + type: string + } + } + } + response { + type: object + additionalProperties: true + } + } +} diff --git a/apiserver/services/events.py b/apiserver/services/events.py index da08ac9..a6a565e 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -630,7 +630,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest): @endpoint( "events.get_debug_image_sample", - min_version="2.11", + min_version="2.12", request_data_model=GetDebugImageSampleRequest, ) def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest): @@ -650,7 +650,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest @endpoint( "events.next_debug_image_sample", - min_version="2.11", + min_version="2.12", request_data_model=NextDebugImageSampleRequest, ) def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest): diff --git a/apiserver/services/models.py b/apiserver/services/models.py index aa2535f..287c004 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -6,7 +6,7 @@ from mongoengine import Q, EmbeddedDocument from apiserver import database from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidModelId -from apiserver.apimodels.base import UpdateResponse, MakePublicRequest +from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest from apiserver.apimodels.models import ( CreateModelRequest, CreateModelResponse, @@ -17,6 +17,7 @@ from apiserver.apimodels.models import ( ) from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context @@ -36,6 +37,7 @@ from apiserver.timing_context import TimingContext log = config.logger(__file__) org_bll = OrgBLL() model_bll = ModelBLL() +project_bll = ProjectBLL() @endpoint("models.get_by_id", required_fields=["model"]) @@ -498,3 +500,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest): call.result.data = Model.set_public( company_id, request.ids, invalid_cls=InvalidModelId, enabled=False ) + + +@endpoint("models.move", request_data_model=MoveRequest) +def move(call: APICall, company_id: str, request: MoveRequest): + if not (request.project or request.project_name): + raise errors.bad_request.MissingRequiredFields( + "project or project_name is required" + ) + + with translate_errors_context(): + return { + "project_id": project_bll.move_under_project( + entity_cls=Model, + user=call.identity.user, + company=company_id, + ids=request.ids, + project=request.project, + project_name=request.project_name, + ) + } diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 3511369..4e5f8c7 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -6,16 +6,16 @@ from operator import itemgetter import dpath from mongoengine import Q -from apiserver import database from apiserver.apierrors import errors from apiserver.apierrors.errors.bad_request import InvalidProjectId -from apiserver.apimodels.base import UpdateResponse, MakePublicRequest +from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.projects import ( GetHyperParamReq, ProjectReq, ProjectTagsRequest, ) from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility @@ -280,26 +280,23 @@ def get_all(call: APICall): call.result.data = {"projects": projects} -@endpoint("projects.create", required_fields=["name", "description"]) -def create(call): - assert isinstance(call, APICall) +@endpoint( + "projects.create", + required_fields=["name", "description"], + response_data_model=IdResponse, +) +def create(call: APICall): identity = call.identity with translate_errors_context(): fields = parse_from_call(call.data, create_fields, Project.get_fields()) conform_tag_fields(call, fields, validate=True) - now = datetime.utcnow() - project = Project( - id=database.utils.id(), - user=identity.user, - company=identity.company, - created=now, - last_update=now, - **fields + + return IdResponse( + id=ProjectBLL.create( + user=identity.user, company=identity.company, **fields, + ) ) - with TimingContext("mongo", "projects_save"): - project.save() - call.result.data = {"id": project.id} @endpoint( diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 77d9eb7..c10d217 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -12,7 +12,12 @@ from pymongo import UpdateOne from apiserver.apierrors import errors, APIError from apiserver.apierrors.errors.bad_request import InvalidTaskId -from apiserver.apimodels.base import UpdateResponse, IdResponse, MakePublicRequest +from apiserver.apimodels.base import ( + UpdateResponse, + IdResponse, + MakePublicRequest, + MoveRequest, +) from apiserver.apimodels.tasks import ( StartedResponse, ResetResponse, @@ -44,6 +49,7 @@ from apiserver.apimodels.tasks import ( ) from apiserver.bll.event import EventBLL from apiserver.bll.organization import OrgBLL, Tags +from apiserver.bll.project import ProjectBLL from apiserver.bll.queue import QueueBLL from apiserver.bll.task import ( TaskBLL, @@ -95,6 +101,7 @@ task_bll = TaskBLL() event_bll = EventBLL() queue_bll = QueueBLL() org_bll = OrgBLL() +project_bll = ProjectBLL() NonResponsiveTasksWatchdog.start() @@ -427,6 +434,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest): configuration=request.new_configuration, execution_overrides=request.execution_overrides, validate_references=request.validate_references, + new_project_name=request.new_project_name, ) call.result.data_model = IdResponse(id=task.id) @@ -1188,3 +1196,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest): call.result.data = Task.set_public( company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False ) + + +@endpoint("tasks.move", request_data_model=MoveRequest) +def move(call: APICall, company_id: str, request: MoveRequest): + if not (request.project or request.project_name): + raise errors.bad_request.MissingRequiredFields( + "project or project_name is required" + ) + + with translate_errors_context(): + return { + "project_id": project_bll.move_under_project( + entity_cls=Task, + user=call.identity.user, + company=company_id, + ids=request.ids, + project=request.project, + project_name=request.project_name, + ) + } diff --git a/apiserver/tests/automated/test_move_under_project.py b/apiserver/tests/automated/test_move_under_project.py new file mode 100644 index 0000000..456412b --- /dev/null +++ b/apiserver/tests/automated/test_move_under_project.py @@ -0,0 +1,45 @@ +from apiserver.tests.automated import TestService + + +class TestMoveUnderProject(TestService): + entity_name = "test move" + + def setUp(self, version="2.12"): + super().setUp(version=version) + + def test_move(self): + # task move into the new project + task = self._temp_task() + project = self.api.tasks.move(ids=[task], project_name=self.entity_name).project_id + tasks = self.api.tasks.get_all_ex(id=[task]).tasks + self.assertEqual(project, tasks[0].project.id) + projects = self.api.projects.get_all_ex(id=[project]).projects + self.assertEqual(self.entity_name, projects[0].name) + + # task clone + p2_name = "project_for_clone" + task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id + tasks = self.api.tasks.get_all_ex(id=[task2]).tasks + project2 = tasks[0].project.id + self.assertTrue(project2) + projects = self.api.projects.get_all_ex(id=[project2]).projects + self.assertEqual(p2_name, projects[0].name) + self.api.projects.delete(project=project2, force=True) + + # model move into existing project referenced by name + model = self._temp_model() + self.api.models.move(ids=[model], project_name=self.entity_name) + models = self.api.models.get_all_ex(id=[model]).models + self.assertEqual(project, models[0].project.id) + + self.api.projects.delete(project=project, force=True) + + def _temp_task(self): + task_input = dict( + name=self.entity_name, type="training", input=dict(mapping={}, view=dict(entries=[])), + ) + return self.create_temp("tasks", **task_input) + + def _temp_model(self): + model_input = dict(name=self.entity_name, uri="file:///a/b", labels={}) + return self.create_temp("models", **model_input) diff --git a/apiserver/tests/automated/test_task_debug_images.py b/apiserver/tests/automated/test_task_debug_images.py index eebc97c..a582078 100644 --- a/apiserver/tests/automated/test_task_debug_images.py +++ b/apiserver/tests/automated/test_task_debug_images.py @@ -7,7 +7,7 @@ from apiserver.tests.automated import TestService class TestTaskDebugImages(TestService): - def setUp(self, version="2.11"): + def setUp(self, version="2.12"): super().setUp(version=version) def _temp_task(self, name="test task events"):