mirror of
https://github.com/clearml/clearml-server
synced 2025-06-23 08:45:30 +00:00
Add clear_all flag to tasks.reset
This commit is contained in:
parent
e1309e30b7
commit
667964cc82
@ -114,3 +114,7 @@ class AddOrUpdateArtifactsRequest(TaskRequest):
|
|||||||
class AddOrUpdateArtifactsResponse(models.Base):
|
class AddOrUpdateArtifactsResponse(models.Base):
|
||||||
added = ListField([str])
|
added = ListField([str])
|
||||||
updated = ListField([str])
|
updated = ListField([str])
|
||||||
|
|
||||||
|
|
||||||
|
class ResetRequest(UpdateRequest):
|
||||||
|
clear_all = BoolField(default=False)
|
||||||
|
@ -182,11 +182,11 @@ class Task(AttributedDocument):
|
|||||||
published = DateTimeField()
|
published = DateTimeField()
|
||||||
parent = StringField()
|
parent = StringField()
|
||||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||||
output = EmbeddedDocumentField(Output, default=Output)
|
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||||
script = EmbeddedDocumentField(Script)
|
script: Script = EmbeddedDocumentField(Script)
|
||||||
last_worker = StringField()
|
last_worker = StringField()
|
||||||
last_worker_report = DateTimeField()
|
last_worker_report = DateTimeField()
|
||||||
last_update = DateTimeField()
|
last_update = DateTimeField()
|
||||||
|
@ -943,6 +943,11 @@ reset {
|
|||||||
properties.force = ${_references.force_arg} {
|
properties.force = ${_references.force_arg} {
|
||||||
description: "If not true, call fails if the task status is 'completed'"
|
description: "If not true, call fails if the task status is 'completed'"
|
||||||
}
|
}
|
||||||
|
properties.clear_all {
|
||||||
|
description: "Clear script and execution sections completely"
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
}
|
||||||
} ${_references.status_change_request}
|
} ${_references.status_change_request}
|
||||||
response {
|
response {
|
||||||
type: object
|
type: object
|
||||||
|
@ -30,6 +30,7 @@ from apimodels.tasks import (
|
|||||||
AddOrUpdateArtifactsRequest,
|
AddOrUpdateArtifactsRequest,
|
||||||
AddOrUpdateArtifactsResponse,
|
AddOrUpdateArtifactsResponse,
|
||||||
GetTypesRequest,
|
GetTypesRequest,
|
||||||
|
ResetRequest,
|
||||||
)
|
)
|
||||||
from bll.event import EventBLL
|
from bll.event import EventBLL
|
||||||
from bll.organization import OrgBLL
|
from bll.organization import OrgBLL
|
||||||
@ -669,14 +670,14 @@ def _dequeue(task: Task, company_id: str, silent_fail=False):
|
|||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
|
"tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
|
||||||
)
|
)
|
||||||
def reset(call: APICall, company_id, req_model: UpdateRequest):
|
def reset(call: APICall, company_id, request: ResetRequest):
|
||||||
task = TaskBLL.get_task_with_access(
|
task = TaskBLL.get_task_with_access(
|
||||||
req_model.task, company_id=company_id, requires_write_access=True
|
request.task, company_id=company_id, requires_write_access=True
|
||||||
)
|
)
|
||||||
|
|
||||||
force = req_model.force
|
force = request.force
|
||||||
|
|
||||||
if not force and task.status == TaskStatus.published:
|
if not force and task.status == TaskStatus.published:
|
||||||
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
|
||||||
@ -692,7 +693,6 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
|||||||
else:
|
else:
|
||||||
if dequeued:
|
if dequeued:
|
||||||
api_results.update(dequeued=dequeued)
|
api_results.update(dequeued=dequeued)
|
||||||
updates.update(unset__execution__queue=1)
|
|
||||||
|
|
||||||
cleaned_up = cleanup_task(task, force)
|
cleaned_up = cleanup_task(task, force)
|
||||||
api_results.update(attr.asdict(cleaned_up))
|
api_results.update(attr.asdict(cleaned_up))
|
||||||
@ -700,11 +700,25 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
|
|||||||
updates.update(
|
updates.update(
|
||||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||||
set__last_metrics={},
|
set__last_metrics={},
|
||||||
|
set__metric_stats={},
|
||||||
unset__output__result=1,
|
unset__output__result=1,
|
||||||
unset__output__model=1,
|
unset__output__model=1,
|
||||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
unset__output__error=1,
|
||||||
|
unset__last_worker=1,
|
||||||
|
unset__last_worker_report=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if request.clear_all:
|
||||||
|
updates.update(
|
||||||
|
set__execution=Execution(),
|
||||||
|
unset__script=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
updates.update(unset__execution__queue=1)
|
||||||
|
updates.update(
|
||||||
|
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||||
|
)
|
||||||
|
|
||||||
res = ResetResponse(
|
res = ResetResponse(
|
||||||
**ChangeStatusRequest(
|
**ChangeStatusRequest(
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -14,9 +14,13 @@ class TestTasksDiff(TestService):
|
|||||||
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
|
"tasks", name="test", type="testing", input=dict(view=dict()), **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
def _compare_script(self, task, script):
|
def _compare_script(self, task_id, script):
|
||||||
for key, value in script.items():
|
task = self.api.tasks.get_by_id(task=task_id).task
|
||||||
self.assertEqual(task.script[key], value)
|
if not script:
|
||||||
|
self.assertFalse(task.get("script", None))
|
||||||
|
else:
|
||||||
|
for key, value in script.items():
|
||||||
|
self.assertEqual(task.script[key], value)
|
||||||
|
|
||||||
def test_not_deleted(self):
|
def test_not_deleted(self):
|
||||||
task_id = self.new_task()
|
task_id = self.new_task()
|
||||||
@ -28,11 +32,14 @@ class TestTasksDiff(TestService):
|
|||||||
)
|
)
|
||||||
self.api.tasks.edit(task=task_id, script=script)
|
self.api.tasks.edit(task=task_id, script=script)
|
||||||
self.api.tasks.started(task=task_id)
|
self.api.tasks.started(task=task_id)
|
||||||
|
|
||||||
self.api.tasks.reset(task=task_id)
|
self.api.tasks.reset(task=task_id)
|
||||||
task = self.api.tasks.get_by_id(task=task_id).task
|
self._compare_script(task_id, script)
|
||||||
self._compare_script(task, script)
|
|
||||||
new_reqs = dict()
|
new_reqs = dict()
|
||||||
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
|
self.api.tasks.set_requirements(task=task_id, requirements=new_reqs)
|
||||||
script["requirements"] = new_reqs
|
script["requirements"] = new_reqs
|
||||||
task = self.api.tasks.get_by_id(task=task_id).task
|
self._compare_script(task_id, script)
|
||||||
self._compare_script(task, script)
|
|
||||||
|
self.api.tasks.reset(task=task_id, clear_all=True)
|
||||||
|
self._compare_script(task_id, {})
|
||||||
|
Loading…
Reference in New Issue
Block a user