Set default task active duration to None

Move endpoints to new API version
Support tasks.move and models.move for moving tasks and models into projects
Support new project name in tasks.clone
Improve task active duration migration
This commit is contained in:
allegroai 2021-01-05 18:05:44 +02:00
parent 8b0afd47a6
commit bca3a6e556
15 changed files with 351 additions and 48 deletions

View File

@ -67,3 +67,9 @@ class IdResponse(models.Base):
class MakePublicRequest(models.Base):
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
class MoveRequest(models.Base):
ids = ListField([str], validators=Length(minimum_value=1))
project = fields.StringField()
project_name = fields.StringField()

View File

@ -115,6 +115,7 @@ class CloneRequest(TaskRequest):
new_configuration = DictField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()
class AddOrUpdateArtifactsRequest(TaskRequest):

View File

@ -1,9 +1,13 @@
from typing import Sequence, Optional
from datetime import datetime
from typing import Sequence, Optional, Type
from mongoengine import Q
from mongoengine import Q, Document
from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
@ -31,3 +35,106 @@ class ProjectBLL:
res |= set(cls_.objects(query).distinct(field="user"))
return res
@classmethod
def create(
cls,
user: str,
company: str,
name: str,
description: str,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
) -> str:
"""
Create a new project.
Returns project ID
"""
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
user=user,
company=company,
name=name,
description=description,
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
created=now,
last_update=now,
)
project.save()
return project.id
@classmethod
def find_or_create(
cls,
user: str,
company: str,
project_name: str,
description: str,
project_id: str = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
) -> str:
"""
Find a project named `project_name` or create a new one.
Returns project ID
"""
if not project_id and not project_name:
raise ValueError("project id or name required")
if project_id:
project = Project.objects(company=company, id=project_id).only("id").first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
return project_id
project = Project.objects(company=company, name=project_name).only("id").first()
if project:
return project.id
return cls.create(
user=user,
company=company,
name=project_name,
description=description,
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
)
@classmethod
def move_under_project(
cls,
entity_cls: Type[Document],
user: str,
company: str,
ids: Sequence[str],
project: str = None,
project_name: str = None,
):
"""
Move a batch of entities to `project` or a project named `project_name` (create if does not exist)
"""
with TimingContext("mongo", "move_under_project"):
project = cls.find_or_create(
user=user,
company=company,
project_id=project,
project_name=project_name,
description="Auto-generated during move",
)
extra = (
{"set__last_update": datetime.utcnow()}
if hasattr(entity_cls, "last_update")
else {}
)
entity_cls.objects(company=company, id__in=ids).update(
set__project=project, **extra
)
return project

View File

@ -9,8 +9,9 @@ from six import string_types
import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@ -36,9 +37,11 @@ from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
org_bll = OrgBLL()
queue_bll = QueueBLL()
project_bll = ProjectBLL()
class TaskBLL(object):
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
@ -162,9 +165,9 @@ class TaskBLL(object):
@classmethod
def clone_task(
cls,
company_id,
user_id,
task_id,
company_id: str,
user_id: str,
task_id: str,
name: Optional[str] = None,
comment: Optional[str] = None,
parent: Optional[str] = None,
@ -175,6 +178,7 @@ class TaskBLL(object):
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Task:
params_dict = {
field: value
@ -210,6 +214,16 @@ class TaskBLL(object):
for k, a in artifacts.items()
if a.get("mode") != ArtifactModes.output
}
if not project and new_project_name:
# Use a project with the provided name, or create a new project
project = project_bll.find_or_create(
project_name=new_project_name,
user=user_id,
company=company_id,
description="Auto-generated while cloning",
)
now = datetime.utcnow()
with translate_errors_context():
@ -675,7 +689,7 @@ class TaskBLL(object):
)
return {
"removed": QueueBLL().remove_task(
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
)
}

View File

@ -204,7 +204,7 @@ class Task(AttributedDocument):
started = DateTimeField()
completed = DateTimeField()
published = DateTimeField()
active_duration = IntField(default=0)
active_duration = IntField(default=None)
parent = StringField(reference_field="Task")
project = StringField(reference_field=Project, user_set_allowed=True)
output: Output = EmbeddedDocumentField(Output, default=Output)

View File

@ -1,3 +1,6 @@
from datetime import datetime
from typing import Optional
from pymongo.database import Database
@ -6,17 +9,29 @@ def _add_active_duration(db: Database):
query = {active_duration: {"$eq": None}}
collection = db["task"]
for doc in collection.find(
filter=query, projection=[active_duration, "started", "last_update"]
filter=query, projection=[active_duration, "status", "started", "completed"]
):
started = doc.get("started")
last_update = doc.get("last_update")
if started and last_update and doc.get(active_duration) is None:
completed = doc.get("completed")
running = doc.get("status") == "running"
if started and doc.get(active_duration) is None:
collection.update_one(
{"_id": doc["_id"]},
{"$set": {active_duration: (last_update - started).total_seconds()}},
{"$set": {active_duration: _get_active_duration(completed, running, started)}},
)
def _get_active_duration(
completed: datetime, running: bool, started: datetime
) -> Optional[float]:
if running:
return (datetime.utcnow() - started).total_seconds()
elif completed:
return (completed - started).total_seconds()
else:
return None
def migrate_backend(db: Database):
"""
Add active_duration field to tasks

View File

@ -417,7 +417,7 @@
}
}
get_debug_image_sample {
"2.11": {
"2.12": {
description: "Return the debug image per metric and variant for the provided iteration"
request {
type: object
@ -449,7 +449,7 @@
}
}
next_debug_image_sample {
"2.11": {
"2.12": {
description: "Get the image for the next variant for the same iteration or for the next iteration"
request {
type: object

View File

@ -55,13 +55,13 @@ _definitions {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@ -306,13 +306,13 @@ update_for_task {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
override_model_id {
@ -372,13 +372,13 @@ create {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@ -460,13 +460,13 @@ edit {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@ -542,13 +542,13 @@ update {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
ready {
@ -747,4 +747,34 @@ make_private {
}
}
}
}
}
move {
"2.12" {
description: "Move models to a project"
request {
type: object
required: [ids]
properties {
ids {
description: "Models to move"
type: array
items { type: string }
}
project {
description: "Target project ID. If not provided, `project_name` must be provided."
type: string
}
project_name {
description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided."
type: string
}
}
}
response {
type: object
additionalProperties: true
}
}
}

View File

@ -772,6 +772,16 @@ clone {
}
}
}
"2.12": ${clone."2.5"} {
request {
properties {
new_project_name {
description: "Clone task to a new project by this name (only if `new_task_project` is not provided). If a project by this name already exists, task will be cloned to existing project."
type: string
}
}
}
}
}
create {
"2.1" {
@ -1992,4 +2002,32 @@ delete_configuration {
}
}
}
}
}
move {
"2.12" {
description: "Move tasks to a project"
request {
type: object
required: [ids]
properties {
ids {
description: "Tasks to move"
type: array
items { type: string }
}
project {
description: "Target project ID. If not provided, `project_name` must be provided."
type: string
}
project_name {
description: "Target project name. If provided and a project with this name does not exist, a new project will be created. If not provided, `project` must be provided."
type: string
}
}
}
response {
type: object
additionalProperties: true
}
}
}

View File

@ -630,7 +630,7 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint(
"events.get_debug_image_sample",
min_version="2.11",
min_version="2.12",
request_data_model=GetDebugImageSampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest):
@ -650,7 +650,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
@endpoint(
"events.next_debug_image_sample",
min_version="2.11",
min_version="2.12",
request_data_model=NextDebugImageSampleRequest,
)
def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):

View File

@ -6,7 +6,7 @@ from mongoengine import Q, EmbeddedDocument
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
from apiserver.apimodels.models import (
CreateModelRequest,
CreateModelResponse,
@ -17,6 +17,7 @@ from apiserver.apimodels.models import (
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@ -36,6 +37,7 @@ from apiserver.timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
model_bll = ModelBLL()
project_bll = ProjectBLL()
@endpoint("models.get_by_id", required_fields=["model"])
@ -498,3 +500,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)
@endpoint("models.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
if not (request.project or request.project_name):
raise errors.bad_request.MissingRequiredFields(
"project or project_name is required"
)
with translate_errors_context():
return {
"project_id": project_bll.move_under_project(
entity_cls=Model,
user=call.identity.user,
company=company_id,
ids=request.ids,
project=request.project,
project_name=request.project_name,
)
}

View File

@ -6,16 +6,16 @@ from operator import itemgetter
import dpath
from mongoengine import Q
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamReq,
ProjectReq,
ProjectTagsRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
@ -280,26 +280,23 @@ def get_all(call: APICall):
call.result.data = {"projects": projects}
@endpoint("projects.create", required_fields=["name", "description"])
def create(call):
assert isinstance(call, APICall)
@endpoint(
"projects.create",
required_fields=["name", "description"],
response_data_model=IdResponse,
)
def create(call: APICall):
identity = call.identity
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields, validate=True)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
user=identity.user,
company=identity.company,
created=now,
last_update=now,
**fields
return IdResponse(
id=ProjectBLL.create(
user=identity.user, company=identity.company, **fields,
)
)
with TimingContext("mongo", "projects_save"):
project.save()
call.result.data = {"id": project.id}
@endpoint(

View File

@ -12,7 +12,12 @@ from pymongo import UpdateOne
from apiserver.apierrors import errors, APIError
from apiserver.apierrors.errors.bad_request import InvalidTaskId
from apiserver.apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
from apiserver.apimodels.base import (
UpdateResponse,
IdResponse,
MakePublicRequest,
MoveRequest,
)
from apiserver.apimodels.tasks import (
StartedResponse,
ResetResponse,
@ -44,6 +49,7 @@ from apiserver.apimodels.tasks import (
)
from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
@ -95,6 +101,7 @@ task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
project_bll = ProjectBLL()
NonResponsiveTasksWatchdog.start()
@ -427,6 +434,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
configuration=request.new_configuration,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
new_project_name=request.new_project_name,
)
call.result.data_model = IdResponse(id=task.id)
@ -1188,3 +1196,23 @@ def make_public(call: APICall, company_id, request: MakePublicRequest):
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
@endpoint("tasks.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
if not (request.project or request.project_name):
raise errors.bad_request.MissingRequiredFields(
"project or project_name is required"
)
with translate_errors_context():
return {
"project_id": project_bll.move_under_project(
entity_cls=Task,
user=call.identity.user,
company=company_id,
ids=request.ids,
project=request.project,
project_name=request.project_name,
)
}

View File

@ -0,0 +1,45 @@
from apiserver.tests.automated import TestService
class TestMoveUnderProject(TestService):
entity_name = "test move"
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_move(self):
# task move into the new project
task = self._temp_task()
project = self.api.tasks.move(ids=[task], project_name=self.entity_name).project_id
tasks = self.api.tasks.get_all_ex(id=[task]).tasks
self.assertEqual(project, tasks[0].project.id)
projects = self.api.projects.get_all_ex(id=[project]).projects
self.assertEqual(self.entity_name, projects[0].name)
# task clone
p2_name = "project_for_clone"
task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id
tasks = self.api.tasks.get_all_ex(id=[task2]).tasks
project2 = tasks[0].project.id
self.assertTrue(project2)
projects = self.api.projects.get_all_ex(id=[project2]).projects
self.assertEqual(p2_name, projects[0].name)
self.api.projects.delete(project=project2, force=True)
# model move into existing project referenced by name
model = self._temp_model()
self.api.models.move(ids=[model], project_name=self.entity_name)
models = self.api.models.get_all_ex(id=[model]).models
self.assertEqual(project, models[0].project.id)
self.api.projects.delete(project=project, force=True)
def _temp_task(self):
task_input = dict(
name=self.entity_name, type="training", input=dict(mapping={}, view=dict(entries=[])),
)
return self.create_temp("tasks", **task_input)
def _temp_model(self):
model_input = dict(name=self.entity_name, uri="file:///a/b", labels={})
return self.create_temp("models", **model_input)

View File

@ -7,7 +7,7 @@ from apiserver.tests.automated import TestService
class TestTaskDebugImages(TestService):
def setUp(self, version="2.11"):
def setUp(self, version="2.12"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):