From 667964cc82c24b2283040f7e9d9c5f70ae4588aa Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 1 Jun 2020 13:07:35 +0300 Subject: [PATCH] Add clear_all flag to tasks.reset --- server/apimodels/tasks.py | 4 ++++ server/database/model/task/task.py | 4 ++-- server/schema/services/tasks.conf | 5 +++++ server/services/tasks.py | 26 +++++++++++++++++------ server/tests/automated/test_tasks_diff.py | 21 ++++++++++++------ 5 files changed, 45 insertions(+), 15 deletions(-) diff --git a/server/apimodels/tasks.py b/server/apimodels/tasks.py index bc95049..d864552 100644 --- a/server/apimodels/tasks.py +++ b/server/apimodels/tasks.py @@ -114,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest): class AddOrUpdateArtifactsResponse(models.Base): added = ListField([str]) updated = ListField([str]) + + +class ResetRequest(UpdateRequest): + clear_all = BoolField(default=False) diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index bab1789..8601c8e 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -182,11 +182,11 @@ class Task(AttributedDocument): published = DateTimeField() parent = StringField() project = StringField(reference_field=Project, user_set_allowed=True) - output = EmbeddedDocumentField(Output, default=Output) + output: Output = EmbeddedDocumentField(Output, default=Output) execution: Execution = EmbeddedDocumentField(Execution, default=Execution) tags = SafeSortedListField(StringField(required=True), user_set_allowed=True) system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True) - script = EmbeddedDocumentField(Script) + script: Script = EmbeddedDocumentField(Script) last_worker = StringField() last_worker_report = DateTimeField() last_update = DateTimeField() diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index b582eea..2f0d10b 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -943,6 +943,11 @@ reset { properties.force = ${_references.force_arg} { description: "If not true, call fails if the task status is 'completed'" } + properties.clear_all { + description: "Clear script and execution sections completely" + type: boolean + default: false + } } ${_references.status_change_request} response { type: object diff --git a/server/services/tasks.py b/server/services/tasks.py index 3da78c5..9cf0f8b 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -30,6 +30,7 @@ from apimodels.tasks import ( AddOrUpdateArtifactsRequest, AddOrUpdateArtifactsResponse, GetTypesRequest, + ResetRequest, ) from bll.event import EventBLL from bll.organization import OrgBLL @@ -669,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False): @endpoint( - "tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse + "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse ) -def reset(call: APICall, company_id, req_model: UpdateRequest): +def reset(call: APICall, company_id, request: ResetRequest): task = TaskBLL.get_task_with_access( - req_model.task, company_id=company_id, requires_write_access=True + request.task, company_id=company_id, requires_write_access=True ) - force = req_model.force + force = request.force if not force and task.status == TaskStatus.published: raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status) @@ -692,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest): else: if dequeued: api_results.update(dequeued=dequeued) - updates.update(unset__execution__queue=1) cleaned_up = cleanup_task(task, force) api_results.update(attr.asdict(cleaned_up)) @@ -700,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest): updates.update( set__last_iteration=DEFAULT_LAST_ITERATION, set__last_metrics={}, + set__metric_stats={}, unset__output__result=1, unset__output__model=1, - __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}}, + unset__output__error=1, + unset__last_worker=1, + unset__last_worker_report=1, ) + if request.clear_all: + updates.update( + set__execution=Execution(), + unset__script=1, + ) + else: + updates.update(unset__execution__queue=1) + updates.update( + __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}}, + ) + res = ResetResponse( **ChangeStatusRequest( task=task, diff --git a/server/tests/automated/test_tasks_diff.py b/server/tests/automated/test_tasks_diff.py index a059d98..c4f9e88 100644 --- a/server/tests/automated/test_tasks_diff.py +++ b/server/tests/automated/test_tasks_diff.py @@ -14,9 +14,13 @@ class TestTasksDiff(TestService): "tasks", name="test", type="testing", input=dict(view=dict()), **kwargs ) - def _compare_script(self, task, script): - for key, value in script.items(): - self.assertEqual(task.script[key], value) + def _compare_script(self, task_id, script): + task = self.api.tasks.get_by_id(task=task_id).task + if not script: + self.assertFalse(task.get("script", None)) + else: + for key, value in script.items(): + self.assertEqual(task.script[key], value) def test_not_deleted(self): task_id = self.new_task() @@ -28,11 +32,14 @@ class TestTasksDiff(TestService): ) self.api.tasks.edit(task=task_id, script=script) self.api.tasks.started(task=task_id) + self.api.tasks.reset(task=task_id) - task = self.api.tasks.get_by_id(task=task_id).task - self._compare_script(task, script) + self._compare_script(task_id, script) + new_reqs = dict() self.api.tasks.set_requirements(task=task_id, requirements=new_reqs) script["requirements"] = new_reqs - task = self.api.tasks.get_by_id(task=task_id).task - self._compare_script(task, script) + self._compare_script(task_id, script) + + self.api.tasks.reset(task=task_id, clear_all=True) + self._compare_script(task_id, {})