diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index bb30dac..50a07d8 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -101,6 +101,10 @@ class DequeueRequest(UpdateRequest): new_status = StringField() +class StopRequest(UpdateRequest): + include_pipeline_steps = BoolField(default=False) + + class EnqueueRequest(UpdateRequest): queue = StringField() queue_name = StringField() @@ -112,6 +116,7 @@ class DeleteRequest(UpdateRequest): return_file_urls = BoolField(default=False) delete_output_models = BoolField(default=True) delete_external_artifacts = BoolField(default=True) + include_pipeline_steps = BoolField(default=False) class SetRequirementsRequest(TaskRequest): @@ -264,6 +269,7 @@ class DeleteConfigurationRequest(TaskUpdateRequest): class ArchiveRequest(MultiTaskRequest): status_reason = StringField(default="") status_message = StringField(default="") + include_pipeline_steps = BoolField(default=False) class ArchiveResponse(models.Base): @@ -275,8 +281,17 @@ class TaskBatchRequest(BatchRequest): status_message = StringField(default="") +class ArchiveManyRequest(TaskBatchRequest): + include_pipeline_steps = BoolField(default=False) + + +class UnarchiveManyRequest(TaskBatchRequest): + include_pipeline_steps = BoolField(default=False) + + class StopManyRequest(TaskBatchRequest): force = BoolField(default=False) + include_pipeline_steps = BoolField(default=False) class DequeueManyRequest(TaskBatchRequest): @@ -297,6 +312,7 @@ class DeleteManyRequest(TaskBatchRequest): delete_output_models = BoolField(default=True) force = BoolField(default=False) delete_external_artifacts = BoolField(default=True) + include_pipeline_steps = BoolField(default=False) class ResetManyRequest(TaskBatchRequest): diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index 6b65161..d4a86e8 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -22,7 +22,7 @@ from apiserver.database.model.task.task import ( TaskStatusMessage, ArtifactModes, Execution, - DEFAULT_LAST_ITERATION, + DEFAULT_LAST_ITERATION, TaskType, ) from apiserver.database.utils import get_options from apiserver.service_repo.auth import Identity @@ -32,54 +32,79 @@ log = config.logger(__file__) queue_bll = QueueBLL() +def _get_pipeline_steps_for_controller_task( + task: Task, company_id: str, only: Sequence[str] = None +) -> Sequence[Task]: + if not task or task.type != TaskType.controller: + return [] + + query = Task.objects(company=company_id, parent=task.id) + if only: + query = query.only(*only) + + return list(query) + + def archive_task( task: Union[str, Task], company_id: str, identity: Identity, status_message: str, status_reason: str, + include_pipeline_steps: bool, ) -> int: """ Deque and archive task Return 1 if successful """ + user_id = identity.user + fields = ( + "id", + "company", + "execution", + "status", + "project", + "system_tags", + "enqueue_status", + "type", + ) if isinstance(task, str): task = get_task_with_write_access( task, company_id=company_id, identity=identity, - only=( - "id", - "company", - "execution", - "status", - "project", - "system_tags", - "enqueue_status", - ), + only=fields, ) - user_id = identity.user - try: - TaskBLL.dequeue_and_change_status( - task, - company_id=company_id, - user_id=user_id, + def archive_task_core(task_: Task) -> int: + try: + TaskBLL.dequeue_and_change_status( + task_, + company_id=company_id, + user_id=user_id, + status_message=status_message, + status_reason=status_reason, + remove_from_all_queues=True, + ) + except APIError: + # dequeue may fail if the task was not enqueued + pass + + return task_.update( status_message=status_message, status_reason=status_reason, - remove_from_all_queues=True, + add_to_set__system_tags=EntityVisibility.archived.value, + last_change=datetime.utcnow(), + last_changed_by=user_id, ) - except APIError: - # dequeue may fail if the task was not enqueued - pass - return task.update( - status_message=status_message, - status_reason=status_reason, - add_to_set__system_tags=EntityVisibility.archived.value, - last_change=datetime.utcnow(), - last_changed_by=user_id, - ) + if include_pipeline_steps and ( + step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + ): + for step in step_tasks: + archive_task_core(step) + + return archive_task_core(task) def unarchive_task( @@ -88,24 +113,36 @@ def unarchive_task( identity: Identity, status_message: str, status_reason: str, + include_pipeline_steps: bool, ) -> int: """ Unarchive task. Return 1 if successful """ + fields = ("id", "type") task = get_task_with_write_access( task_id, company_id=company_id, identity=identity, - only=("id",), - ) - return task.update( - status_message=status_message, - status_reason=status_reason, - pull__system_tags=EntityVisibility.archived.value, - last_change=datetime.utcnow(), - last_changed_by=identity.user, + only=fields, ) + def unarchive_task_core(task_: Task) -> int: + return task_.update( + status_message=status_message, + status_reason=status_reason, + pull__system_tags=EntityVisibility.archived.value, + last_change=datetime.utcnow(), + last_changed_by=identity.user, + ) + + if include_pipeline_steps and ( + step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + ): + for step in step_tasks: + unarchive_task_core(step) + + return unarchive_task_core(task) + def dequeue_task( task_id: str, @@ -262,6 +299,7 @@ def delete_task( status_message: str, status_reason: str, delete_external_artifacts: bool, + include_pipeline_steps: bool = False, ) -> Tuple[int, Task, CleanupResult]: user_id = identity.user task = get_task_with_write_access( @@ -280,36 +318,51 @@ def delete_task( current=task.status, ) - try: - TaskBLL.dequeue_and_change_status( - task, - company_id=company_id, - user_id=user_id, - status_message=status_message, - status_reason=status_reason, - remove_from_all_queues=True, + def delete_task_core(task_: Task, force_: bool): + try: + TaskBLL.dequeue_and_change_status( + task_, + company_id=company_id, + user_id=user_id, + status_message=status_message, + status_reason=status_reason, + remove_from_all_queues=True, + ) + except APIError: + # dequeue may fail if the task was not enqueued + pass + + res = cleanup_task( + company=company_id, + user=user_id, + task=task_, + force=force_, + return_file_urls=return_file_urls, + delete_output_models=delete_output_models, + delete_external_artifacts=delete_external_artifacts, ) - except APIError: - # dequeue may fail if the task was not enqueued - pass - cleanup_res = cleanup_task( - company=company_id, - user=user_id, - task=task, - force=force, - return_file_urls=return_file_urls, - delete_output_models=delete_output_models, - delete_external_artifacts=delete_external_artifacts, - ) + if move_to_trash: + # make sure that whatever changes were done to the task are saved + # the task itself will be deleted later in the move_tasks_to_trash operation + task_.last_update = datetime.utcnow() + task_.save() + else: + task_.delete() + return res + + task_ids = [task.id] + if include_pipeline_steps and ( + step_tasks := _get_pipeline_steps_for_controller_task(task, company_id) + ): + for step in step_tasks: + delete_task_core(step, True) + task_ids.append(step.id) + + cleanup_res = delete_task_core(task, force) if move_to_trash: - # make sure that whatever changes were done to the task are saved - # the task itself will be deleted later in the move_tasks_to_trash operation - task.last_update = datetime.utcnow() - task.save() - else: - task.delete() + move_tasks_to_trash(task_ids) update_project_time(task.project) return 1, task, cleanup_res @@ -465,6 +518,7 @@ def stop_task( user_name: str, status_reason: str, force: bool, + include_pipeline_steps: bool = False, ) -> dict: """ Stop a running task. Requires task status 'in_progress' and @@ -475,19 +529,21 @@ def stop_task( :return: updated task fields """ user_id = identity.user + fields = ( + "status", + "project", + "tags", + "system_tags", + "last_worker", + "last_update", + "execution.queue", + "type", + ) task = get_task_with_write_access( task_id, company_id=company_id, identity=identity, - only=( - "status", - "project", - "tags", - "system_tags", - "last_worker", - "last_update", - "execution.queue", - ), + only=fields, ) def is_run_by_worker(t: Task) -> bool: @@ -499,32 +555,41 @@ def stop_task( and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout ) - is_queued = task.status == TaskStatus.queued - set_stopped = ( - is_queued - or TaskSystemTags.development in task.system_tags - or not is_run_by_worker(task) - ) + def stop_task_core(task_: Task, force_: bool): + is_queued = task_.status == TaskStatus.queued + set_stopped = ( + is_queued + or TaskSystemTags.development in task_.system_tags + or not is_run_by_worker(task_) + ) - if set_stopped: - if is_queued: - try: - TaskBLL.dequeue(task, company_id=company_id, silent_fail=True) - except APIError: - # dequeue may fail if the task was not enqueued - pass + if set_stopped: + if is_queued: + try: + TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True) + except APIError: + # dequeue may fail if the task was not enqueued + pass - new_status = TaskStatus.stopped - status_message = f"Stopped by {user_name}" - else: - new_status = task.status - status_message = TaskStatusMessage.stopping + new_status = TaskStatus.stopped + status_message = f"Stopped by {user_name}" + else: + new_status = task_.status + status_message = TaskStatusMessage.stopping - return ChangeStatusRequest( - task=task, - new_status=new_status, - status_reason=status_reason, - status_message=status_message, - force=force, - user_id=user_id, - ).execute() + return ChangeStatusRequest( + task=task_, + new_status=new_status, + status_reason=status_reason, + status_message=status_message, + force=force_, + user_id=user_id, + ).execute() + + if include_pipeline_steps and ( + step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + ): + for step in step_tasks: + stop_task_core(step, True) + + return stop_task_core(task, force) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 540c772..7d02662 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import ( ResetManyRequest, DeleteManyRequest, PublishManyRequest, - TaskBatchRequest, EnqueueManyResponse, EnqueueBatchItem, DequeueBatchItem, @@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import ( DequeueRequest, DequeueManyRequest, UpdateTagsRequest, + StopRequest, + UnarchiveManyRequest, + ArchiveManyRequest, ) from apiserver.bll.event import EventBLL from apiserver.bll.model import ModelBLL @@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import ( delete_task, publish_task, unarchive_task, - move_tasks_to_trash, ) from apiserver.bll.task.utils import ( update_task, @@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest): @endpoint( - "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse + "tasks.stop", response_data_model=UpdateResponse ) -def stop(call: APICall, company_id, req_model: UpdateRequest): +def stop(call: APICall, company_id, request: StopRequest): """ stop :summary: Stop a running task. Requires task status 'in_progress' and @@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest): """ call.result.data_model = UpdateResponse( **stop_task( - task_id=req_model.task, + task_id=request.task, company_id=company_id, identity=call.identity, user_name=call.identity.user_name, - status_reason=req_model.status_reason, - force=req_model.force, + status_reason=request.status_reason, + force=request.force, + include_pipeline_steps=request.include_pipeline_steps, ) ) @@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest): user_name=call.identity.user_name, status_reason=request.status_reason, force=request.force, + include_pipeline_steps=request.include_pipeline_steps, ), ids=request.ids, ) @@ -531,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic @endpoint("tasks.validate", request_data_model=CreateRequest) -def validate(call: APICall, company_id, req_model: CreateRequest): +def validate(call: APICall, _, __: CreateRequest): parent = call.data.get("parent") if parent and parent.startswith(deleted_prefix): call.data.pop("parent") @@ -555,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]): @endpoint( "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse ) -def create(call: APICall, company_id, req_model: CreateRequest): +def create(call: APICall, company_id, _: CreateRequest): task, fields = _validate_and_get_task_from_call(call) with translate_errors_context(): @@ -1087,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): "project", "system_tags", "enqueue_status", + "type", ), ) for task in tasks: @@ -1096,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): task=task, status_message=request.status_message, status_reason=request.status_reason, + include_pipeline_steps=request.include_pipeline_steps, ) call.result.data_model = ArchiveResponse(archived=archived) @@ -1103,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest): @endpoint( "tasks.archive_many", - request_data_model=TaskBatchRequest, response_data_model=BatchResponse, ) -def archive_many(call: APICall, company_id, request: TaskBatchRequest): +def archive_many(call: APICall, company_id, request: ArchiveManyRequest): results, failures = run_batch_operation( func=partial( archive_task, @@ -1114,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, + include_pipeline_steps=request.include_pipeline_steps, ), ids=request.ids, ) @@ -1125,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest): @endpoint( "tasks.unarchive_many", - request_data_model=TaskBatchRequest, response_data_model=BatchResponse, ) -def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): +def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest): results, failures = run_batch_operation( func=partial( unarchive_task, @@ -1136,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest): identity=call.identity, status_message=request.status_message, status_reason=request.status_reason, + include_pipeline_steps=request.include_pipeline_steps, ), ids=request.ids, ) @@ -1160,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest): status_message=request.status_message, status_reason=request.status_reason, delete_external_artifacts=request.delete_external_artifacts, + include_pipeline_steps=request.include_pipeline_steps, ) if deleted: - if request.move_to_trash: - move_tasks_to_trash([request.task]) _reset_cached_tags(company_id, projects=[task.project] if task.project else []) call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res)) @@ -1182,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest): status_message=request.status_message, status_reason=request.status_reason, delete_external_artifacts=request.delete_external_artifacts, + include_pipeline_steps=request.include_pipeline_steps, ), ids=request.ids, ) if results: - if request.move_to_trash: - task_ids = set(task.id for _, (_, task, _) in results) - if task_ids: - move_tasks_to_trash(list(task_ids)) projects = set(task.project for _, (_, task, _) in results) _reset_cached_tags(company_id, projects=list(projects)) diff --git a/apiserver/tests/automated/test_pipelines.py b/apiserver/tests/automated/test_pipelines.py index ac4144a..3c1e0f9 100644 --- a/apiserver/tests/automated/test_pipelines.py +++ b/apiserver/tests/automated/test_pipelines.py @@ -5,6 +5,45 @@ from apiserver.tests.automated import TestService class TestPipelines(TestService): + def test_controller_operations(self): + task_name = "pipelines test" + project, task = self._temp_project_and_task(name=task_name) + steps = [ + self.api.tasks.create( + name=f"Pipeline step {i}", + project=project, + type="training", + system_tags=["pipeline"], + parent=task + ).id + for i in range(2) + ] + ids = [task, *steps] + res = self.api.tasks.get_all_ex(id=ids, search_hidden=True) + self.assertEqual(len(res.tasks), len(ids)) + + # stop + partial_ids = [task, steps[0]] + self.api.tasks.enqueue_many(ids=partial_ids) + res = self.api.tasks.get_all_ex(id=partial_ids, search_hidden=True) + self.assertTrue(t.stats == "in_progress" for t in res.tasks) + self.api.tasks.stop(task=task, include_pipeline_steps=True) + res = self.api.tasks.get_all_ex(id=ids, search_hidden=True) + self.assertTrue(t.stats == "created" for t in res.tasks) + + # archive/unarchive + self.api.tasks.archive(tasks=[task], include_pipeline_steps=True) + res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"]) + self.assertEqual(len(res.tasks), 0) + self.api.tasks.unarchive_many(ids=[task], include_pipeline_steps=True) + res = self.api.tasks.get_all_ex(id=ids, search_hidden=True, system_tags=["-archived"]) + self.assertEqual(len(res.tasks), len(ids)) + + # delete + self.api.tasks.delete(task=task, force=True, include_pipeline_steps=True) + res = self.api.tasks.get_all_ex(id=ids, search_hidden=True) + self.assertEqual(len(res.tasks), 0) + def test_delete_runs(self): queue = self.api.queues.get_default().id task_name = "pipelines test"