Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints

This commit is contained in:
allegroai 2024-06-20 17:52:46 +03:00
parent cdc668e3c8
commit 2e19a18ee4
4 changed files with 236 additions and 115 deletions

View File

@ -101,6 +101,10 @@ class DequeueRequest(UpdateRequest):
new_status = StringField()
class StopRequest(UpdateRequest):
include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
@ -112,6 +116,7 @@ class DeleteRequest(UpdateRequest):
return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest):
@ -264,6 +269,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest):
class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="")
status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base):
@ -275,8 +281,17 @@ class TaskBatchRequest(BatchRequest):
status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest):
@ -297,6 +312,7 @@ class DeleteManyRequest(TaskBatchRequest):
delete_output_models = BoolField(default=True)
force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest):

View File

@ -22,7 +22,7 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION,
DEFAULT_LAST_ITERATION, TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@ -32,54 +32,79 @@ log = config.logger(__file__)
queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
    task: Task, company_id: str, only: Optional[Sequence[str]] = None
) -> Sequence[Task]:
    """
    Return the step tasks of a pipeline controller task.

    Steps are defined as the tasks whose ``parent`` is the controller task id,
    scoped to the given company.

    :param task: candidate controller task; may be None
    :param company_id: company scope used for the step lookup
    :param only: optional field names used to limit the loaded documents
    :return: list of step tasks, or an empty list when ``task`` is missing
        or is not a pipeline controller
    """
    # Only pipeline controllers have steps; anything else yields no results
    if not task or task.type != TaskType.controller:
        return []
    query = Task.objects(company=company_id, parent=task.id)
    if only:
        # Restrict loaded fields to keep the query lightweight
        query = query.only(*only)
    return list(query)
def archive_task(
task: Union[str, Task],
company_id: str,
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Dequeue and archive task
Return 1 if successful
"""
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str):
task = get_task_with_write_access(
task,
company_id=company_id,
identity=identity,
only=(
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
only=fields,
)
user_id = identity.user
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
def archive_task_core(task_: Task) -> int:
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update(
status_message=status_message,
status_reason=status_reason,
add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
archive_task_core(step)
return archive_task_core(task)
def unarchive_task(
@ -88,24 +113,36 @@ def unarchive_task(
identity: Identity,
status_message: str,
status_reason: str,
include_pipeline_steps: bool,
) -> int:
"""
Unarchive task. Return 1 if successful
"""
fields = ("id", "type")
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=("id",),
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
only=fields,
)
def unarchive_task_core(task_: Task) -> int:
return task_.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task(
task_id: str,
@ -262,6 +299,7 @@ def delete_task(
status_message: str,
status_reason: str,
delete_external_artifacts: bool,
include_pipeline_steps: bool = False,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(
@ -280,36 +318,51 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
def delete_task_core(task_: Task, force_: bool):
try:
TaskBLL.dequeue_and_change_status(
task_,
company_id=company_id,
user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
company=company_id,
user=user_id,
task=task,
force=force,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task_.last_update = datetime.utcnow()
task_.save()
else:
task_.delete()
return res
task_ids = [task.id]
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()
move_tasks_to_trash(task_ids)
update_project_time(task.project)
return 1, task, cleanup_res
@ -465,6 +518,7 @@ def stop_task(
user_name: str,
status_reason: str,
force: bool,
include_pipeline_steps: bool = False,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
@ -475,19 +529,21 @@ def stop_task(
:return: updated task fields
"""
user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access(
task_id,
company_id=company_id,
identity=identity,
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
only=fields,
)
def is_run_by_worker(t: Task) -> bool:
@ -499,32 +555,41 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
def stop_task_core(task_: Task, force_: bool):
is_queued = task_.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task_.system_tags
or not is_run_by_worker(task_)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task_.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
user_id=user_id,
).execute()
return ChangeStatusRequest(
task=task_,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force_,
user_id=user_id,
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import (
ResetManyRequest,
DeleteManyRequest,
PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse,
EnqueueBatchItem,
DequeueBatchItem,
@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import (
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
StopRequest,
UnarchiveManyRequest,
ArchiveManyRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import (
delete_task,
publish_task,
unarchive_task,
move_tasks_to_trash,
)
from apiserver.bll.task.utils import (
update_task,
@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
@endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
"tasks.stop", response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
def stop(call: APICall, company_id, request: StopRequest):
"""
stop
:summary: Stop a running task. Requires task status 'in_progress' and
@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
"""
call.result.data_model = UpdateResponse(
**stop_task(
task_id=req_model.task,
task_id=request.task,
company_id=company_id,
identity=call.identity,
user_name=call.identity.user_name,
status_reason=req_model.status_reason,
force=req_model.force,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
)
)
@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
user_name=call.identity.user_name,
status_reason=request.status_reason,
force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@ -531,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
def validate(call: APICall, _, __: CreateRequest):
parent = call.data.get("parent")
if parent and parent.startswith(deleted_prefix):
call.data.pop("parent")
@ -555,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
def create(call: APICall, company_id, _: CreateRequest):
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context():
@ -1087,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"project",
"system_tags",
"enqueue_status",
"type",
),
)
for task in tasks:
@ -1096,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
task=task,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
)
call.result.data_model = ArchiveResponse(archived=archived)
@ -1103,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
@endpoint(
"tasks.archive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: TaskBatchRequest):
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
archive_task,
@ -1114,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@ -1125,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
@endpoint(
"tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest):
results, failures = run_batch_operation(
func=partial(
unarchive_task,
@ -1136,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity,
status_message=request.status_message,
status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
@ -1160,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
)
if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@ -1182,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
status_message=request.status_message,
status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
),
ids=request.ids,
)
if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects))

View File

@ -5,6 +5,45 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService):
def test_controller_operations(self):
    """
    Verify that task endpoints propagate operations from a pipeline
    controller task to its step tasks when include_pipeline_steps is
    passed: stop, archive/unarchive and delete.
    """
    task_name = "pipelines test"
    project, task = self._temp_project_and_task(name=task_name)
    steps = [
        self.api.tasks.create(
            name=f"Pipeline step {i}",
            project=project,
            type="training",
            system_tags=["pipeline"],
            parent=task,
        ).id
        for i in range(2)
    ]
    ids = [task, *steps]
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), len(ids))

    # stop
    partial_ids = [task, steps[0]]
    self.api.tasks.enqueue_many(ids=partial_ids)
    res = self.api.tasks.get_all_ex(id=partial_ids, search_hidden=True)
    # NOTE(review): the original passed a bare generator to assertTrue,
    # which is always truthy, so the status checks never actually ran
    # (and used a `t.stats` typo). Wrapped in all() so they assert for
    # real; expected statuses kept from the author's intent — confirm
    # against the enqueue/stop status semantics.
    self.assertTrue(all(t.status == "in_progress" for t in res.tasks))
    self.api.tasks.stop(task=task, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertTrue(all(t.status == "created" for t in res.tasks))

    # archive/unarchive
    self.api.tasks.archive(tasks=[task], include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(
        id=ids, search_hidden=True, system_tags=["-archived"]
    )
    self.assertEqual(len(res.tasks), 0)
    self.api.tasks.unarchive_many(ids=[task], include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(
        id=ids, search_hidden=True, system_tags=["-archived"]
    )
    self.assertEqual(len(res.tasks), len(ids))

    # delete
    self.api.tasks.delete(task=task, force=True, include_pipeline_steps=True)
    res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
    self.assertEqual(len(res.tasks), 0)
def test_delete_runs(self):
queue = self.api.queues.get_default().id
task_name = "pipelines test"