Add support for additional task types as well as tasks.get_types to obtain actual types used globally or per project

This commit is contained in:
allegroai
2020-06-01 13:05:12 +03:00
parent 84a75d9e70
commit 9403942ef7
9 changed files with 111 additions and 15 deletions

View File

@@ -54,6 +54,10 @@ class TestService(TestCase, TestServiceInterface):
)
return object_id
@staticmethod
def update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def create_temp(self, service, *, client=None, delete_params=None, **kwargs) -> str:
return self._create_temp_helper(
service=service,

View File

@@ -208,25 +208,21 @@ class TestTags(TestService):
self.api.tasks.stopped(task=task_id)
def _temp_queue(self, **kwargs):
self._update_missing(kwargs, name="Test tags")
self.update_missing(kwargs, name="Test tags")
return self.create_temp("queues", **kwargs)
def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
self.update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)
def _temp_model(self, **kwargs):
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
self.update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def _temp_task(self, **kwargs):
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
self.update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
return self.create_temp("tasks", **kwargs)
@staticmethod
def _update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def _send(self, service, action, **kwargs):
api = kwargs.pop("api", self.api)
return AttrDict(

View File

@@ -1,4 +1,4 @@
from apierrors.errors.bad_request import InvalidModelId
from apierrors.errors.bad_request import InvalidModelId, ValidationError
from config import config
from tests.automated import TestService
@@ -11,12 +11,37 @@ class TestTasksEdit(TestService):
super().setUp(version=2.5)
def new_task(self, **kwargs):
return self.create_temp(
"tasks", type="testing", name="test", input=dict(view=dict()), **kwargs
self.update_missing(
kwargs, type="testing", name="test", input=dict(view=dict())
)
return self.create_temp("tasks", **kwargs)
def new_model(self):
return self.create_temp("models", name="test", uri="file:///a/b", labels={})
def new_model(self, **kwargs):
self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def test_task_types(self):
with self.api.raises(ValidationError):
task = self.new_task(type="Unsupported")
types = ["controller", "optimizer"]
p1 = self.create_temp("projects", name="Test tasks1", description="test")
task1 = self.new_task(project=p1, type=types[0])
p2 = self.create_temp("projects", name="Test tasks2", description="test")
task2 = self.new_task(project=p2, type=types[1])
# all company types
res = self.api.tasks.get_types()
self.assertTrue(set(types).issubset(set(res["types"])))
# projects array
res = self.api.tasks.get_types(projects=[p1, p2])
self.assertEqual(set(types), set(res["types"]))
# single project
for p, t in zip((p1, p2), types):
res = self.api.tasks.get_types(projects=[p])
self.assertEqual([t], res["types"])
def test_edit_model_ready(self):
task = self.new_task()