From 9403942ef7e17983fe1c230b58d010bdf68ceb5a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 1 Jun 2020 13:05:12 +0300 Subject: [PATCH] Add support for additional task types as well as tasks.get_types to obtain actual types used globally or per project --- server/apimodels/tasks.py | 4 +++ server/bll/task/task_bll.py | 13 +++++++++ server/database/model/task/task.py | 12 ++++++++ server/schema/services/tasks.conf | 34 ++++++++++++++++++++++ server/services/projects.py | 4 +-- server/services/tasks.py | 8 ++++++ server/tests/automated/__init__.py | 4 +++ server/tests/automated/test_tags.py | 12 +++----- server/tests/automated/test_tasks_edit.py | 35 +++++++++++++++++++---- 9 files changed, 111 insertions(+), 15 deletions(-) diff --git a/server/apimodels/tasks.py b/server/apimodels/tasks.py index 604bcaa..bc95049 100644 --- a/server/apimodels/tasks.py +++ b/server/apimodels/tasks.py @@ -92,6 +92,10 @@ class PingRequest(TaskRequest): pass +class GetTypesRequest(models.Base): + projects = ListField(items_types=[str]) + + class CloneRequest(TaskRequest): new_task_name = StringField() new_task_comment = StringField() diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index a65b466..d682aa3 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -28,6 +28,7 @@ from database.model.task.task import ( TaskSystemTags, ArtifactModes, Artifact, + external_task_types, ) from database.utils import get_company_or_none_constraint, id as create_id from service_repo import APICall @@ -46,6 +47,18 @@ class TaskBLL(object): events_es if events_es is not None else es_factory.connect("events") ) + @classmethod + def get_types(cls, company, project_ids: Optional[Sequence]) -> set: + """ + Return the list of unique task types used by company and public tasks + If project ids passed then only tasks from these projects are considered + """ + query = get_company_or_none_constraint(company) + if project_ids: + query &= Q(project__in=project_ids) + res = Task.objects(query).distinct(field="type") + return set(res).intersection(external_task_types) + @staticmethod def get_task_with_access( task_id, company_id, only=None, allow_public=False, requires_write_access=False diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index 49c1c69..bab1789 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -100,6 +100,18 @@ class Execution(EmbeddedDocument, ProperDictMixin): class TaskType(object): training = "training" testing = "testing" + inference = "inference" + data_processing = "data_processing" + application = "application" + monitor = "monitor" + controller = "controller" + optimizer = "optimizer" + service = "service" + qc = "qc" + custom = "custom" + + +external_task_types = set(get_options(TaskType)) class Task(AttributedDocument): diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index dbacc3b..b582eea 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -254,6 +254,15 @@ _definitions { enum: [ training testing + inference + data_processing + application + monitor + controller + optimizer + service + qc + custom ] } last_metrics_event { @@ -554,6 +563,31 @@ get_all { } } } +get_types { + "2.8" { + description: "Get the list of task types used in the specified projects" + request { + type: object + properties { + projects { + description: "The list of projects which tasks will be analyzed. If not passed or empty then all the company and public tasks will be analyzed" + type: array + items: {type: string} + } + } + } + response { + type: object + properties { + types { + description: "Unique list of the task types used in the requested projects" + type: array + items: {type: string} + } + } + } + } +} clone { "2.5" { description: "Clone an existing task" diff --git a/server/services/projects.py b/server/services/projects.py index b049b38..b952290 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -154,9 +154,9 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None # only count run time for these types of tasks { "$match": { - "type": {"$in": ["training", "testing", "annotation"]}, - "project": {"$in": project_ids}, + "type": {"$in": ["training", "testing"]}, "company": {"$in": [None, "", company_id]}, + "project": {"$in": project_ids}, } }, ensure_valid_fields(), diff --git a/server/services/tasks.py b/server/services/tasks.py index 1a14537..3da78c5 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -29,6 +29,7 @@ from apimodels.tasks import ( CloneRequest, AddOrUpdateArtifactsRequest, AddOrUpdateArtifactsResponse, + GetTypesRequest, ) from bll.event import EventBLL from bll.organization import OrgBLL @@ -164,6 +165,13 @@ def get_all(call: APICall, company_id, _): call.result.data = {"tasks": tasks} +@endpoint("tasks.get_types", request_data_model=GetTypesRequest) +def get_types(call: APICall, company_id, request: GetTypesRequest): + call.result.data = { + "types": list(task_bll.get_types(company_id, project_ids=request.projects)) + } + + @endpoint( "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse ) diff --git a/server/tests/automated/__init__.py b/server/tests/automated/__init__.py index cc72ac1..d6b8f8d 100644 --- a/server/tests/automated/__init__.py +++ b/server/tests/automated/__init__.py @@ -54,6 +54,10 @@ class TestService(TestCase, TestServiceInterface): ) return object_id + @staticmethod + def update_missing(target: dict, **update): + target.update({k: v for k, v in update.items() if k not in target}) + def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str: return self._create_temp_helper( service=service, diff --git a/server/tests/automated/test_tags.py b/server/tests/automated/test_tags.py index bc20bc2..889b802 100644 --- a/server/tests/automated/test_tags.py +++ b/server/tests/automated/test_tags.py @@ -208,25 +208,21 @@ class TestTags(TestService): self.api.tasks.stopped(task=task_id) def _temp_queue(self, **kwargs): - self._update_missing(kwargs, name="Test tags") + self.update_missing(kwargs, name="Test tags") return self.create_temp("queues", **kwargs) def _temp_project(self, **kwargs): - self._update_missing(kwargs, name="Test tags", description="test") + self.update_missing(kwargs, name="Test tags", description="test") return self.create_temp("projects", **kwargs) def _temp_model(self, **kwargs): - self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={}) + self.update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={}) return self.create_temp("models", **kwargs) def _temp_task(self, **kwargs): - self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict())) + self.update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict())) return self.create_temp("tasks", **kwargs) - @staticmethod - def _update_missing(target: dict, **update): - target.update({k: v for k, v in update.items() if k not in target}) - def _send(self, service, action, **kwargs): api = kwargs.pop("api", self.api) return AttrDict( diff --git a/server/tests/automated/test_tasks_edit.py b/server/tests/automated/test_tasks_edit.py index 5e9cf18..0819f43 100644 --- a/server/tests/automated/test_tasks_edit.py +++ b/server/tests/automated/test_tasks_edit.py @@ -1,4 +1,4 @@ -from apierrors.errors.bad_request import InvalidModelId +from apierrors.errors.bad_request import InvalidModelId, ValidationError from config import config from tests.automated import TestService @@ -11,12 +11,37 @@ class TestTasksEdit(TestService): super().setUp(version=2.5) def new_task(self, **kwargs): - return self.create_temp( - "tasks", type="testing", name="test", input=dict(view=dict()), **kwargs + self.update_missing( + kwargs, type="testing", name="test", input=dict(view=dict()) ) + return self.create_temp("tasks", **kwargs) - def new_model(self): - return self.create_temp("models", name="test", uri="file:///a/b", labels={}) + def new_model(self, **kwargs): + self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) + return self.create_temp("models", **kwargs) + + def test_task_types(self): + with self.api.raises(ValidationError): + task = self.new_task(type="Unsupported") + + types = ["controller", "optimizer"] + p1 = self.create_temp("projects", name="Test tasks1", description="test") + task1 = self.new_task(project=p1, type=types[0]) + p2 = self.create_temp("projects", name="Test tasks2", description="test") + task2 = self.new_task(project=p2, type=types[1]) + + # all company types + res = self.api.tasks.get_types() + self.assertTrue(set(types).issubset(set(res["types"]))) + + # projects array + res = self.api.tasks.get_types(projects=[p1, p2]) + self.assertEqual(set(types), set(res["types"])) + + # single project + for p, t in zip((p1, p2), types): + res = self.api.tasks.get_types(projects=[p]) + self.assertEqual([t], res["types"]) def test_edit_model_ready(self): task = self.new_task()