from apiserver.apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId from apiserver.apierrors.errors.forbidden import NoWritePermission from apiserver.config_repo import config from apiserver.tests.api_client import APIError from apiserver.tests.automated import TestService log = config.logger(__file__) class TestTasksEdit(TestService): def setUp(self, **kwargs): super().setUp(version="2.12") def new_task(self, **kwargs): self.update_missing( kwargs, type="testing", name="test" ) return self.create_temp("tasks", **kwargs) def new_model(self, **kwargs): self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) return self.create_temp("models", **kwargs) def new_queue(self, **kwargs): self.update_missing(kwargs, name="test") return self.create_temp("queues", **kwargs) def test_task_types(self): with self.api.raises(ValidationError): task = self.new_task(type="Unsupported") types = ["controller", "optimizer"] p1 = self.create_temp("projects", name="Test tasks1", description="test") task1 = self.new_task(project=p1, type=types[0]) p2 = self.create_temp("projects", name="Test tasks2", description="test") task2 = self.new_task(project=p2, type=types[1]) # all company types res = self.api.tasks.get_types() self.assertTrue(set(types).issubset(set(res["types"]))) # projects array res = self.api.tasks.get_types(projects=[p1, p2]) self.assertEqual(set(types), set(res["types"])) # single project for p, t in zip((p1, p2), types): res = self.api.tasks.get_types(projects=[p]) self.assertEqual([t], res["types"]) def test_edit_model_ready(self): task = self.new_task() model = self.new_model() self.api.tasks.edit(task=task, execution=dict(model=model)) def test_edit_model_not_ready(self): task = self.new_task() model = self.new_model() self.api.models.edit(model=model, ready=False) self.assertFalse(self.api.models.get_by_id(model=model).model.ready) self.api.tasks.edit(task=task, execution=dict(model=model)) def test_edit_had_model_model_not_ready(self): ready_model = self.new_model() self.assert_(self.api.models.get_by_id(model=ready_model).model.ready) task = self.new_task(execution=dict(model=ready_model)) not_ready_model = self.new_model() self.api.models.edit(model=not_ready_model, ready=False) self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready) self.api.tasks.edit(task=task, execution=dict(model=not_ready_model)) def test_task_with_model_reset(self): # on task reset output model deleted task = self.new_task() self.api.tasks.started(task=task) model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"] self.api.tasks.reset(task=task) with self.api.raises(InvalidModelId): self.api.models.get_by_id(model=model_id) # unless it is input of some task task = self.new_task() self.api.tasks.started(task=task) model_id = self.api.models.update_for_task(task=task, uri="file:///b")["id"] task_2 = self.new_task(execution=dict(model=model_id)) self.api.tasks.reset(task=task) self.api.models.get_by_id(model=model_id) def test_clone_task(self): script = dict( binary="python", requirements=dict(pip=["six"]), repository="https://example.come/foo/bar", entry_point="test.py", diff="foo", ) execution = dict(parameters=dict(test="Test")) tags = ["hello"] system_tags = ["development", "test"] task = self.new_task( script=script, execution=execution, tags=tags, system_tags=system_tags ) new_name = "new test" new_tags = ["by"] execution_overrides = dict(framework="Caffe", model_labels={"test": 1.0}) new_task_id = self._clone_task( task=task, new_task_name=new_name, new_task_tags=new_tags, execution_overrides=execution_overrides, new_task_parent=task, ) new_task = self.api.tasks.get_by_id(task=new_task_id).task self.assertEqual(new_task.name, new_name) self.assertEqual(new_task.type, "testing") self.assertEqual(new_task.tags, new_tags) self.assertEqual(new_task.status, "created") self.assertEqual(new_task.script, script) self.assertEqual(new_task.parent, task) # self.assertEqual(new_task.execution.parameters, execution["parameters"]) self.assertEqual(new_task.execution.framework, execution_overrides["framework"]) self.assertEqual(new_task.execution.model_labels, {"test": 1}) self.assertEqual(new_task.system_tags, ["test"]) def test_model_check_in_clone(self): model = self.new_model() task = self.new_task(execution=dict(model=model)) # task with deleted model still can be copied self.api.models.delete(model=model, force=True) self._clone_task(task=task, new_task_name="clone test") # unless check for refs is done with self.api.raises(InvalidModelId): self._clone_task( task=task, new_task_name="clone test2", validate_references=True ) # if the model is overriden then it is always checked with self.api.raises(InvalidModelId): self._clone_task( task=task, new_task_name="clone test3", execution_overrides=dict(model="not existing"), ) def _clone_task(self, task, **kwargs): new_task = self.api.tasks.clone(task=task, **kwargs).id self.defer( self.api.tasks.delete, task=new_task, move_to_trash=False, force=True ) return new_task def test_make_public(self): task = self.new_task() # task is created as private and can be updated self.api.tasks.started(task=task) # task with company_origin not set to the current company cannot be converted to private with self.api.raises(InvalidTaskId): self.api.tasks.make_private(ids=[task]) # public task can be retrieved but not updated res = self.api.tasks.make_public(ids=[task]) self.assertEqual(res.updated, 1) res = self.api.tasks.get_all_ex(id=[task]) self.assertEqual([t.id for t in res.tasks], [task]) with self.api.raises(NoWritePermission): self.api.tasks.stopped(task=task) # task made private again and can be both retrieved and updated res = self.api.tasks.make_private(ids=[task]) self.assertEqual(res.updated, 1) res = self.api.tasks.get_all_ex(id=[task]) self.assertEqual([t.id for t in res.tasks], [task]) self.api.tasks.stopped(task=task) def test_archive_task(self): # non-existing task throws an exception with self.assertRaises(APIError): self.api.tasks.archive(tasks=["fake-task-id"]) system_tag = "existing-system-tag" status_message = "test-status-message" status_reason = "test-status-reason" queue_id = self.new_queue() # Create two tasks with system_tags and enqueue one of them dequeued_task_id = self.new_task(system_tags=[system_tag]) enqueued_task_id = self.new_task(system_tags=[system_tag]) self.api.tasks.enqueue(task=enqueued_task_id, queue=queue_id) self.api.tasks.archive( tasks=[enqueued_task_id, dequeued_task_id], status_message=status_message, status_reason=status_reason, ) tasks = self.api.tasks.get_all_ex(id=[enqueued_task_id, dequeued_task_id]).tasks for task in tasks: self.assertIn(system_tag, task.system_tags) self.assertIn("archived", task.system_tags) self.assertIn(status_message, task.status_message) self.assertIn(status_reason, task.status_reason) # Check that the queue does not contain the enqueued task anymore queue = self.api.queues.get_by_id(queue=queue_id).queue task_in_queue = next( (True for entry in queue.entries if entry["task"] == enqueued_task_id), False, ) self.assertFalse(task_in_queue) def test_stopped_task_enqueue(self): queue_id = self.new_queue() task_id = self.new_task() self.api.tasks.started(task=task_id) self.api.tasks.stopped(task=task_id) projection = ["*", "execution.*"] task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0] self.assertEqual(task.status, "stopped") self.api.tasks.enqueue(task=task_id, queue=queue_id) task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0] self.assertEqual(task.status, "queued") self.assertEqual(task.execution.queue.id, queue_id)