Support automatic handling of pipeline steps if a pipeline controller task ID was passed to one of the tasks endpoints

This commit is contained in:
allegroai 2024-06-20 17:52:46 +03:00
parent cdc668e3c8
commit 2e19a18ee4
4 changed files with 236 additions and 115 deletions

View File

@ -101,6 +101,10 @@ class DequeueRequest(UpdateRequest):
new_status = StringField() new_status = StringField()
class StopRequest(UpdateRequest):
    # When True and the target task is a pipeline controller, the stop
    # operation is also applied to the controller's step tasks
    # (see stop_task in task_operations).
    include_pipeline_steps = BoolField(default=False)
class EnqueueRequest(UpdateRequest): class EnqueueRequest(UpdateRequest):
queue = StringField() queue = StringField()
queue_name = StringField() queue_name = StringField()
@ -112,6 +116,7 @@ class DeleteRequest(UpdateRequest):
return_file_urls = BoolField(default=False) return_file_urls = BoolField(default=False)
delete_output_models = BoolField(default=True) delete_output_models = BoolField(default=True)
delete_external_artifacts = BoolField(default=True) delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class SetRequirementsRequest(TaskRequest): class SetRequirementsRequest(TaskRequest):
@ -264,6 +269,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest):
class ArchiveRequest(MultiTaskRequest): class ArchiveRequest(MultiTaskRequest):
status_reason = StringField(default="") status_reason = StringField(default="")
status_message = StringField(default="") status_message = StringField(default="")
include_pipeline_steps = BoolField(default=False)
class ArchiveResponse(models.Base): class ArchiveResponse(models.Base):
@ -275,8 +281,17 @@ class TaskBatchRequest(BatchRequest):
status_message = StringField(default="") status_message = StringField(default="")
class ArchiveManyRequest(TaskBatchRequest):
    # When True, archiving a pipeline controller task also archives
    # its step tasks (tasks whose parent is the controller).
    include_pipeline_steps = BoolField(default=False)
class UnarchiveManyRequest(TaskBatchRequest):
    # When True, unarchiving a pipeline controller task also unarchives
    # its step tasks (tasks whose parent is the controller).
    include_pipeline_steps = BoolField(default=False)
class StopManyRequest(TaskBatchRequest): class StopManyRequest(TaskBatchRequest):
force = BoolField(default=False) force = BoolField(default=False)
include_pipeline_steps = BoolField(default=False)
class DequeueManyRequest(TaskBatchRequest): class DequeueManyRequest(TaskBatchRequest):
@ -297,6 +312,7 @@ class DeleteManyRequest(TaskBatchRequest):
delete_output_models = BoolField(default=True) delete_output_models = BoolField(default=True)
force = BoolField(default=False) force = BoolField(default=False)
delete_external_artifacts = BoolField(default=True) delete_external_artifacts = BoolField(default=True)
include_pipeline_steps = BoolField(default=False)
class ResetManyRequest(TaskBatchRequest): class ResetManyRequest(TaskBatchRequest):

View File

@ -22,7 +22,7 @@ from apiserver.database.model.task.task import (
TaskStatusMessage, TaskStatusMessage,
ArtifactModes, ArtifactModes,
Execution, Execution,
DEFAULT_LAST_ITERATION, DEFAULT_LAST_ITERATION, TaskType,
) )
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity from apiserver.service_repo.auth import Identity
@ -32,54 +32,79 @@ log = config.logger(__file__)
queue_bll = QueueBLL() queue_bll = QueueBLL()
def _get_pipeline_steps_for_controller_task(
    task: Task, company_id: str, only: Sequence[str] = None
) -> Sequence[Task]:
    """
    Return the step tasks of the given pipeline controller task.

    Steps are the tasks of the given company whose ``parent`` is the
    controller task's id. Returns an empty list when ``task`` is falsy
    or is not of type ``controller``.

    :param task: candidate pipeline controller task
    :param company_id: company to scope the steps query to
    :param only: optional field names to restrict the loaded task fields
        (passed to the query's ``only()`` projection)
    """
    if not task or task.type != TaskType.controller:
        return []
    query = Task.objects(company=company_id, parent=task.id)
    if only:
        query = query.only(*only)
    # materialize so callers get a plain list rather than a lazy queryset
    return list(query)
def archive_task( def archive_task(
task: Union[str, Task], task: Union[str, Task],
company_id: str, company_id: str,
identity: Identity, identity: Identity,
status_message: str, status_message: str,
status_reason: str, status_reason: str,
include_pipeline_steps: bool,
) -> int: ) -> int:
""" """
Deque and archive task Deque and archive task
Return 1 if successful Return 1 if successful
""" """
user_id = identity.user
fields = (
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
"type",
)
if isinstance(task, str): if isinstance(task, str):
task = get_task_with_write_access( task = get_task_with_write_access(
task, task,
company_id=company_id, company_id=company_id,
identity=identity, identity=identity,
only=( only=fields,
"id",
"company",
"execution",
"status",
"project",
"system_tags",
"enqueue_status",
),
) )
user_id = identity.user def archive_task_core(task_: Task) -> int:
try: try:
TaskBLL.dequeue_and_change_status( TaskBLL.dequeue_and_change_status(
task, task_,
company_id=company_id, company_id=company_id,
user_id=user_id, user_id=user_id,
status_message=status_message,
status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task_.update(
status_message=status_message, status_message=status_message,
status_reason=status_reason, status_reason=status_reason,
remove_from_all_queues=True, add_to_set__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=user_id,
) )
except APIError:
# dequeue may fail if the task was not enqueued
pass
return task.update( if include_pipeline_steps and (
status_message=status_message, step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
status_reason=status_reason, ):
add_to_set__system_tags=EntityVisibility.archived.value, for step in step_tasks:
last_change=datetime.utcnow(), archive_task_core(step)
last_changed_by=user_id,
) return archive_task_core(task)
def unarchive_task( def unarchive_task(
@ -88,24 +113,36 @@ def unarchive_task(
identity: Identity, identity: Identity,
status_message: str, status_message: str,
status_reason: str, status_reason: str,
include_pipeline_steps: bool,
) -> int: ) -> int:
""" """
Unarchive task. Return 1 if successful Unarchive task. Return 1 if successful
""" """
fields = ("id", "type")
task = get_task_with_write_access( task = get_task_with_write_access(
task_id, task_id,
company_id=company_id, company_id=company_id,
identity=identity, identity=identity,
only=("id",), only=fields,
)
return task.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
) )
def unarchive_task_core(task_: Task) -> int:
return task_.update(
status_message=status_message,
status_reason=status_reason,
pull__system_tags=EntityVisibility.archived.value,
last_change=datetime.utcnow(),
last_changed_by=identity.user,
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
unarchive_task_core(step)
return unarchive_task_core(task)
def dequeue_task( def dequeue_task(
task_id: str, task_id: str,
@ -262,6 +299,7 @@ def delete_task(
status_message: str, status_message: str,
status_reason: str, status_reason: str,
delete_external_artifacts: bool, delete_external_artifacts: bool,
include_pipeline_steps: bool = False,
) -> Tuple[int, Task, CleanupResult]: ) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user user_id = identity.user
task = get_task_with_write_access( task = get_task_with_write_access(
@ -280,36 +318,51 @@ def delete_task(
current=task.status, current=task.status,
) )
try: def delete_task_core(task_: Task, force_: bool):
TaskBLL.dequeue_and_change_status( try:
task, TaskBLL.dequeue_and_change_status(
company_id=company_id, task_,
user_id=user_id, company_id=company_id,
status_message=status_message, user_id=user_id,
status_reason=status_reason, status_message=status_message,
remove_from_all_queues=True, status_reason=status_reason,
remove_from_all_queues=True,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
res = cleanup_task(
company=company_id,
user=user_id,
task=task_,
force=force_,
return_file_urls=return_file_urls,
delete_output_models=delete_output_models,
delete_external_artifacts=delete_external_artifacts,
) )
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task( if move_to_trash:
company=company_id, # make sure that whatever changes were done to the task are saved
user=user_id, # the task itself will be deleted later in the move_tasks_to_trash operation
task=task, task_.last_update = datetime.utcnow()
force=force, task_.save()
return_file_urls=return_file_urls, else:
delete_output_models=delete_output_models, task_.delete()
delete_external_artifacts=delete_external_artifacts,
)
return res
task_ids = [task.id]
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id)
):
for step in step_tasks:
delete_task_core(step, True)
task_ids.append(step.id)
cleanup_res = delete_task_core(task, force)
if move_to_trash: if move_to_trash:
# make sure that whatever changes were done to the task are saved move_tasks_to_trash(task_ids)
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()
update_project_time(task.project) update_project_time(task.project)
return 1, task, cleanup_res return 1, task, cleanup_res
@ -465,6 +518,7 @@ def stop_task(
user_name: str, user_name: str,
status_reason: str, status_reason: str,
force: bool, force: bool,
include_pipeline_steps: bool = False,
) -> dict: ) -> dict:
""" """
Stop a running task. Requires task status 'in_progress' and Stop a running task. Requires task status 'in_progress' and
@ -475,19 +529,21 @@ def stop_task(
:return: updated task fields :return: updated task fields
""" """
user_id = identity.user user_id = identity.user
fields = (
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
"type",
)
task = get_task_with_write_access( task = get_task_with_write_access(
task_id, task_id,
company_id=company_id, company_id=company_id,
identity=identity, identity=identity,
only=( only=fields,
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
) )
def is_run_by_worker(t: Task) -> bool: def is_run_by_worker(t: Task) -> bool:
@ -499,32 +555,41 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
) )
is_queued = task.status == TaskStatus.queued def stop_task_core(task_: Task, force_: bool):
set_stopped = ( is_queued = task_.status == TaskStatus.queued
is_queued set_stopped = (
or TaskSystemTags.development in task.system_tags is_queued
or not is_run_by_worker(task) or TaskSystemTags.development in task_.system_tags
) or not is_run_by_worker(task_)
)
if set_stopped: if set_stopped:
if is_queued: if is_queued:
try: try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True) TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
pass pass
new_status = TaskStatus.stopped new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}" status_message = f"Stopped by {user_name}"
else: else:
new_status = task.status new_status = task_.status
status_message = TaskStatusMessage.stopping status_message = TaskStatusMessage.stopping
return ChangeStatusRequest( return ChangeStatusRequest(
task=task, task=task_,
new_status=new_status, new_status=new_status,
status_reason=status_reason, status_reason=status_reason,
status_message=status_message, status_message=status_message,
force=force, force=force_,
user_id=user_id, user_id=user_id,
).execute() ).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
):
for step in step_tasks:
stop_task_core(step, True)
return stop_task_core(task, force)

View File

@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import (
ResetManyRequest, ResetManyRequest,
DeleteManyRequest, DeleteManyRequest,
PublishManyRequest, PublishManyRequest,
TaskBatchRequest,
EnqueueManyResponse, EnqueueManyResponse,
EnqueueBatchItem, EnqueueBatchItem,
DequeueBatchItem, DequeueBatchItem,
@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import (
DequeueRequest, DequeueRequest,
DequeueManyRequest, DequeueManyRequest,
UpdateTagsRequest, UpdateTagsRequest,
StopRequest,
UnarchiveManyRequest,
ArchiveManyRequest,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import (
delete_task, delete_task,
publish_task, publish_task,
unarchive_task, unarchive_task,
move_tasks_to_trash,
) )
from apiserver.bll.task.utils import ( from apiserver.bll.task.utils import (
update_task, update_task,
@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
@endpoint( @endpoint(
"tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse "tasks.stop", response_data_model=UpdateResponse
) )
def stop(call: APICall, company_id, req_model: UpdateRequest): def stop(call: APICall, company_id, request: StopRequest):
""" """
stop stop
:summary: Stop a running task. Requires task status 'in_progress' and :summary: Stop a running task. Requires task status 'in_progress' and
@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
""" """
call.result.data_model = UpdateResponse( call.result.data_model = UpdateResponse(
**stop_task( **stop_task(
task_id=req_model.task, task_id=request.task,
company_id=company_id, company_id=company_id,
identity=call.identity, identity=call.identity,
user_name=call.identity.user_name, user_name=call.identity.user_name,
status_reason=req_model.status_reason, status_reason=request.status_reason,
force=req_model.force, force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
) )
) )
@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
user_name=call.identity.user_name, user_name=call.identity.user_name,
status_reason=request.status_reason, status_reason=request.status_reason,
force=request.force, force=request.force,
include_pipeline_steps=request.include_pipeline_steps,
), ),
ids=request.ids, ids=request.ids,
) )
@ -531,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest) @endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest): def validate(call: APICall, _, __: CreateRequest):
parent = call.data.get("parent") parent = call.data.get("parent")
if parent and parent.startswith(deleted_prefix): if parent and parent.startswith(deleted_prefix):
call.data.pop("parent") call.data.pop("parent")
@ -555,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint( @endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
) )
def create(call: APICall, company_id, req_model: CreateRequest): def create(call: APICall, company_id, _: CreateRequest):
task, fields = _validate_and_get_task_from_call(call) task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(): with translate_errors_context():
@ -1087,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
"project", "project",
"system_tags", "system_tags",
"enqueue_status", "enqueue_status",
"type",
), ),
) )
for task in tasks: for task in tasks:
@ -1096,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
task=task, task=task,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
) )
call.result.data_model = ArchiveResponse(archived=archived) call.result.data_model = ArchiveResponse(archived=archived)
@ -1103,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
@endpoint( @endpoint(
"tasks.archive_many", "tasks.archive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def archive_many(call: APICall, company_id, request: TaskBatchRequest): def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
results, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
archive_task, archive_task,
@ -1114,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity, identity=call.identity,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
), ),
ids=request.ids, ids=request.ids,
) )
@ -1125,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
@endpoint( @endpoint(
"tasks.unarchive_many", "tasks.unarchive_many",
request_data_model=TaskBatchRequest,
response_data_model=BatchResponse, response_data_model=BatchResponse,
) )
def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest):
results, failures = run_batch_operation( results, failures = run_batch_operation(
func=partial( func=partial(
unarchive_task, unarchive_task,
@ -1136,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
identity=call.identity, identity=call.identity,
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
include_pipeline_steps=request.include_pipeline_steps,
), ),
ids=request.ids, ids=request.ids,
) )
@ -1160,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts, delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
) )
if deleted: if deleted:
if request.move_to_trash:
move_tasks_to_trash([request.task])
_reset_cached_tags(company_id, projects=[task.project] if task.project else []) _reset_cached_tags(company_id, projects=[task.project] if task.project else [])
call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res)) call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@ -1182,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,
delete_external_artifacts=request.delete_external_artifacts, delete_external_artifacts=request.delete_external_artifacts,
include_pipeline_steps=request.include_pipeline_steps,
), ),
ids=request.ids, ids=request.ids,
) )
if results: if results:
if request.move_to_trash:
task_ids = set(task.id for _, (_, task, _) in results)
if task_ids:
move_tasks_to_trash(list(task_ids))
projects = set(task.project for _, (_, task, _) in results) projects = set(task.project for _, (_, task, _) in results)
_reset_cached_tags(company_id, projects=list(projects)) _reset_cached_tags(company_id, projects=list(projects))

View File

@ -5,6 +5,45 @@ from apiserver.tests.automated import TestService
class TestPipelines(TestService): class TestPipelines(TestService):
def test_controller_operations(self):
task_name = "pipelines test"
project, task = self._temp_project_and_task(name=task_name)
steps = [
self.api.tasks.create(
name=f"Pipeline step {i}",
project=project,
type="training",
system_tags=["pipeline"],
parent=task
).id
for i in range(2)
]
ids = [task, *steps]
res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
self.assertEqual(len(res.tasks), len(ids))
# stop
partial_ids = [task, steps[0]]
self.api.tasks.enqueue_many(ids=partial_ids)
res = self.api.tasks.get_all_ex(id=partial_ids, search_hidden=True)
self.assertTrue(t.stats == "in_progress" for t in res.tasks)
self.api.tasks.stop(task=task, include_pipeline_steps=True)
res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
self.assertTrue(t.stats == "created" for t in res.tasks)
# archive/unarchive
self.api.tasks.archive(tasks=[task], include_pipeline_steps=True)
res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"])
self.assertEqual(len(res.tasks), 0)
self.api.tasks.unarchive_many(ids=[task], include_pipeline_steps=True)
res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"])
self.assertEqual(len(res.tasks), len(ids))
# delete
self.api.tasks.delete(task=task, force=True, include_pipeline_steps=True)
res = self.api.tasks.get_all_ex(id=ids, search_hidden=True)
self.assertEqual(len(res.tasks), 0)
def test_delete_runs(self): def test_delete_runs(self):
queue = self.api.queues.get_default().id queue = self.api.queues.get_default().id
task_name = "pipelines test" task_name = "pipelines test"