Add clear_all flag to tasks.reset

This commit is contained in:
allegroai 2020-06-01 13:07:35 +03:00
parent e1309e30b7
commit 667964cc82
5 changed files with 45 additions and 15 deletions

View File

@ -114,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str])
updated = ListField([str])
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)

View File

@ -182,11 +182,11 @@ class Task(AttributedDocument):
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
output: Output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
script: Script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()

View File

@ -943,6 +943,11 @@ reset {
properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'"
}
properties.clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
} ${_references.status_change_request}
response {
type: object

View File

@ -30,6 +30,7 @@ from apimodels.tasks import (
AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL
@ -669,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
@endpoint(
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, req_model: UpdateRequest):
def reset(call: APICall, company_id, request: ResetRequest):
task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, requires_write_access=True
request.task, company_id=company_id, requires_write_access=True
)
force = req_model.force
force = request.force
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@ -692,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
else:
if dequeued:
api_results.update(dequeued=dequeued)
updates.update(unset__execution__queue=1)
cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up))
@ -700,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
unset__output__result=1,
unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
)
if request.clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
updates.update(
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
res = ResetResponse(
**ChangeStatusRequest(
task=task,

View File

@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
)
def _compare_script(self, task, script):
for key, value in script.items():
self.assertEqual(task.script[key], value)
def _compare_script(self, task_id, script):
task = self.api.tasks.get_by_id(task=task_id).task
if not script:
self.assertFalse(task.get("script", None))
else:
for key, value in script.items():
self.assertEqual(task.script[key], value)
def test_not_deleted(self):
task_id = self.new_task()
@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
)
self.api.tasks.edit(task=task_id, script=script)
self.api.tasks.started(task=task_id)
self.api.tasks.reset(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self._compare_script(task, script)
self._compare_script(task_id, script)
new_reqs = dict()
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
script["requirements"] = new_reqs
task = self.api.tasks.get_by_id(task=task_id).task
self._compare_script(task, script)
self._compare_script(task_id, script)
self.api.tasks.reset(task=task_id, clear_all=True)
self._compare_script(task_id, {})