mirror of
https://github.com/clearml/clearml-server
synced 2025-04-23 15:44:16 +00:00
Add clear_all flag to tasks.reset
This commit is contained in:
parent
e1309e30b7
commit
667964cc82
@ -114,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
|
@ -182,11 +182,11 @@ class Task(AttributedDocument):
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
script: Script = EmbeddedDocumentField(Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
|
@ -943,6 +943,11 @@ reset {
|
||||
properties.force = ${_references.force_arg} {
|
||||
description: "If not true, call fails if the task status is 'completed'"
|
||||
}
|
||||
properties.clear_all {
|
||||
description: "Clear script and execution sections completely"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
} ${_references.status_change_request}
|
||||
response {
|
||||
type: object
|
||||
|
@ -30,6 +30,7 @@ from apimodels.tasks import (
|
||||
AddOrUpdateArtifactsRequest,
|
||||
AddOrUpdateArtifactsResponse,
|
||||
GetTypesRequest,
|
||||
ResetRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.organization import OrgBLL
|
||||
@ -669,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
|
||||
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||
)
|
||||
def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
def reset(call: APICall, company_id, request: ResetRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
req_model.task, company_id=company_id, requires_write_access=True
|
||||
request.task, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
force = req_model.force
|
||||
force = request.force
|
||||
|
||||
if not force and task.status == TaskStatus.published:
|
||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||
@ -692,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
else:
|
||||
if dequeued:
|
||||
api_results.update(dequeued=dequeued)
|
||||
updates.update(unset__execution__queue=1)
|
||||
|
||||
cleaned_up = cleanup_task(task, force)
|
||||
api_results.update(attr.asdict(cleaned_up))
|
||||
@ -700,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
||||
updates.update(
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__metric_stats={},
|
||||
unset__output__result=1,
|
||||
unset__output__model=1,
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
)
|
||||
|
||||
if request.clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(),
|
||||
unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
updates.update(
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
)
|
||||
|
||||
res = ResetResponse(
|
||||
**ChangeStatusRequest(
|
||||
task=task,
|
||||
|
@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
|
||||
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
|
||||
)
|
||||
|
||||
def _compare_script(self, task, script):
|
||||
for key, value in script.items():
|
||||
self.assertEqual(task.script[key], value)
|
||||
def _compare_script(self, task_id, script):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
if not script:
|
||||
self.assertFalse(task.get("script", None))
|
||||
else:
|
||||
for key, value in script.items():
|
||||
self.assertEqual(task.script[key], value)
|
||||
|
||||
def test_not_deleted(self):
|
||||
task_id = self.new_task()
|
||||
@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
|
||||
)
|
||||
self.api.tasks.edit(task=task_id, script=script)
|
||||
self.api.tasks.started(task=task_id)
|
||||
|
||||
self.api.tasks.reset(task=task_id)
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self._compare_script(task, script)
|
||||
self._compare_script(task_id, script)
|
||||
|
||||
new_reqs = dict()
|
||||
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
|
||||
script["requirements"] = new_reqs
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self._compare_script(task, script)
|
||||
self._compare_script(task_id, script)
|
||||
|
||||
self.api.tasks.reset(task=task_id, clear_all=True)
|
||||
self._compare_script(task_id, {})
|
||||
|
Loading…
Reference in New Issue
Block a user