Add clear_all flag to tasks.reset

This commit is contained in:
allegroai 2020-06-01 13:07:35 +03:00
parent e1309e30b7
commit 667964cc82
5 changed files with 45 additions and 15 deletions

View File

@ -114,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
class AddOrUpdateArtifactsResponse(models.Base): class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str]) added = ListField([str])
updated = ListField([str]) updated = ListField([str])
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)

View File

@ -182,11 +182,11 @@ class Task(AttributedDocument):
published = DateTimeField() published = DateTimeField()
parent = StringField() parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True) project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output) output: Output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution) execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True) tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True) system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script) script: Script = EmbeddedDocumentField(Script)
last_worker = StringField() last_worker = StringField()
last_worker_report = DateTimeField() last_worker_report = DateTimeField()
last_update = DateTimeField() last_update = DateTimeField()

View File

@ -943,6 +943,11 @@ reset {
properties.force = ${_references.force_arg} { properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is 'completed'" description: "If not true, call fails if the task status is 'completed'"
} }
properties.clear_all {
description: "Clear script and execution sections completely"
type: boolean
default: false
}
} ${_references.status_change_request} } ${_references.status_change_request}
response { response {
type: object type: object

View File

@ -30,6 +30,7 @@ from apimodels.tasks import (
AddOrUpdateArtifactsRequest, AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse, AddOrUpdateArtifactsResponse,
GetTypesRequest, GetTypesRequest,
ResetRequest,
) )
from bll.event import EventBLL from bll.event import EventBLL
from bll.organization import OrgBLL from bll.organization import OrgBLL
@ -669,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
@endpoint( @endpoint(
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
) )
def reset(call: APICall, company_id, req_model: UpdateRequest): def reset(call: APICall, company_id, request: ResetRequest):
task = TaskBLL.get_task_with_access( task = TaskBLL.get_task_with_access(
req_model.task, company_id=company_id, requires_write_access=True request.task, company_id=company_id, requires_write_access=True
) )
force = req_model.force force = request.force
if not force and task.status == TaskStatus.published: if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status) raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@ -692,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
else: else:
if dequeued: if dequeued:
api_results.update(dequeued=dequeued) api_results.update(dequeued=dequeued)
updates.update(unset__execution__queue=1)
cleaned_up = cleanup_task(task, force) cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up)) api_results.update(attr.asdict(cleaned_up))
@ -700,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
updates.update( updates.update(
set__last_iteration=DEFAULT_LAST_ITERATION, set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={}, set__last_metrics={},
set__metric_stats={},
unset__output__result=1, unset__output__result=1,
unset__output__model=1, unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}}, unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
) )
if request.clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
updates.update(
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
res = ResetResponse( res = ResetResponse(
**ChangeStatusRequest( **ChangeStatusRequest(
task=task, task=task,

View File

@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs "tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
) )
def _compare_script(self, task, script): def _compare_script(self, task_id, script):
for key, value in script.items(): task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.script[key], value) if not script:
self.assertFalse(task.get("script", None))
else:
for key, value in script.items():
self.assertEqual(task.script[key], value)
def test_not_deleted(self): def test_not_deleted(self):
task_id = self.new_task() task_id = self.new_task()
@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
) )
self.api.tasks.edit(task=task_id, script=script) self.api.tasks.edit(task=task_id, script=script)
self.api.tasks.started(task=task_id) self.api.tasks.started(task=task_id)
self.api.tasks.reset(task=task_id) self.api.tasks.reset(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task self._compare_script(task_id, script)
self._compare_script(task, script)
new_reqs = dict() new_reqs = dict()
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs) self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
script["requirements"] = new_reqs script["requirements"] = new_reqs
task = self.api.tasks.get_by_id(task=task_id).task self._compare_script(task_id, script)
self._compare_script(task, script)
self.api.tasks.reset(task=task_id, clear_all=True)
self._compare_script(task_id, {})