Support automatic handling of pipeline steps when a pipeline controller task ID is passed to one of the tasks endpoints

allegroai committed 2024-06-20 17:52:46 +03:00
parent cdc668e3c8
commit 2e19a18ee4
4 changed files with 236 additions and 115 deletions
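
In practice, the new include_pipeline_steps flag lets a caller pass a pipeline controller's task ID to endpoints such as tasks.stop and have the operation fan out to the pipeline's step tasks. A minimal client-side sketch, assuming a local apiserver and an authenticated requests session (the URL, the session handling, and the helper name are placeholders; the endpoint and field names come from this commit):

    import requests

    API_SERVER = "http://localhost:8008"  # assumed apiserver address

    def stop_pipeline(session: requests.Session, controller_task_id: str) -> dict:
        # With include_pipeline_steps set, the server stops the controller's
        # step tasks in addition to the controller itself.
        resp = session.post(
            API_SERVER + "/tasks.stop",
            json={
                "task": controller_task_id,
                "force": True,
                "include_pipeline_steps": True,
            },
        )
        resp.raise_for_status()
        return resp.json()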


@@ -55,7 +55,6 @@ from apiserver.apimodels.tasks import (
     ResetManyRequest,
     DeleteManyRequest,
     PublishManyRequest,
-    TaskBatchRequest,
     EnqueueManyResponse,
     EnqueueBatchItem,
     DequeueBatchItem,
@@ -68,6 +67,9 @@ from apiserver.apimodels.tasks import (
     DequeueRequest,
     DequeueManyRequest,
     UpdateTagsRequest,
+    StopRequest,
+    UnarchiveManyRequest,
+    ArchiveManyRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.model import ModelBLL
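
The dedicated StopRequest, ArchiveManyRequest, and UnarchiveManyRequest models replace the generic UpdateRequest/TaskBatchRequest parameters below, so that each operation can carry the new flag. A rough sketch of the shape these models add, written with plain dataclasses for illustration only (the real models live in apiserver/apimodels/tasks.py and use the server's own schema framework; the default value is an assumption):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class TaskBatchRequest:  # pre-existing generic batch shape (illustrative)
        ids: List[str] = field(default_factory=list)
        status_message: str = ""
        status_reason: str = ""

    @dataclass
    class ArchiveManyRequest(TaskBatchRequest):
        # New per-operation field: opt in to expanding a pipeline
        # controller ID into its step tasks as well.
        include_pipeline_steps: bool = False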
@@ -98,7 +100,6 @@ from apiserver.bll.task.task_operations import (
     delete_task,
     publish_task,
     unarchive_task,
-    move_tasks_to_trash,
 )
 from apiserver.bll.task.utils import (
     update_task,
@@ -295,9 +296,9 @@ def get_types(call: APICall, company_id, request: GetTypesRequest):
 
 
 @endpoint(
-    "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
+    "tasks.stop", response_data_model=UpdateResponse
 )
-def stop(call: APICall, company_id, req_model: UpdateRequest):
+def stop(call: APICall, company_id, request: StopRequest):
     """
     stop
     :summary: Stop a running task. Requires task status 'in_progress' and
@@ -308,12 +309,13 @@ def stop(call: APICall, company_id, req_model: UpdateRequest):
     """
     call.result.data_model = UpdateResponse(
         **stop_task(
-            task_id=req_model.task,
+            task_id=request.task,
             company_id=company_id,
             identity=call.identity,
             user_name=call.identity.user_name,
-            status_reason=req_model.status_reason,
-            force=req_model.force,
+            status_reason=request.status_reason,
+            force=request.force,
+            include_pipeline_steps=request.include_pipeline_steps,
         )
     )
 
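
The handler only forwards the flag; the fan-out itself happens inside stop_task, which this hunk does not show. A hedged sketch of the general pattern the flag enables, with every helper name hypothetical (the actual logic lives under apiserver/bll/task):

    from typing import Callable, Iterable

    def apply_with_pipeline_steps(
        task_id: str,
        operation: Callable[[str], None],
        is_pipeline_controller: Callable[[str], bool],
        get_pipeline_steps: Callable[[str], Iterable[str]],
        include_pipeline_steps: bool = False,
    ) -> None:
        # When requested and the target is a pipeline controller, apply the
        # operation to each step task before the controller itself.
        if include_pipeline_steps and is_pipeline_controller(task_id):
            for step_id in get_pipeline_steps(task_id):
                operation(step_id)
        operation(task_id)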
@@ -332,6 +334,7 @@ def stop_many(call: APICall, company_id, request: StopManyRequest):
             user_name=call.identity.user_name,
             status_reason=request.status_reason,
             force=request.force,
+            include_pipeline_steps=request.include_pipeline_steps,
         ),
         ids=request.ids,
     )
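
stop_many and the other *_many handlers share one pattern: partial binds the arguments common to the whole request (now including the new flag), and run_batch_operation maps the bound function over the IDs, collecting per-task failures instead of aborting the batch. A simplified stand-in for that runner, not the apiserver's actual implementation:

    from functools import partial
    from typing import Any, Callable, List, Sequence, Tuple

    def run_batch_operation(
        func: Callable[[str], Any], ids: Sequence[str]
    ) -> Tuple[List[Tuple[str, Any]], List[Tuple[str, str]]]:
        results: List[Tuple[str, Any]] = []
        failures: List[Tuple[str, str]] = []
        for task_id in ids:
            try:
                results.append((task_id, func(task_id)))
            except Exception as ex:  # one bad ID must not fail the rest
                failures.append((task_id, str(ex)))
        return results, failures

    # Toy usage mirroring the handlers: bind shared kwargs, map over IDs.
    def toy_stop(task_id: str, force: bool = False) -> str:
        return f"stopped {task_id} (force={force})"

    results, failures = run_batch_operation(partial(toy_stop, force=True), ["a", "b"])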
@@ -531,7 +534,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
 
 
 @endpoint("tasks.validate", request_data_model=CreateRequest)
-def validate(call: APICall, company_id, req_model: CreateRequest):
+def validate(call: APICall, _, __: CreateRequest):
     parent = call.data.get("parent")
     if parent and parent.startswith(deleted_prefix):
         call.data.pop("parent")
@@ -555,7 +558,7 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
 @endpoint(
     "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
 )
-def create(call: APICall, company_id, req_model: CreateRequest):
+def create(call: APICall, company_id, _: CreateRequest):
     task, fields = _validate_and_get_task_from_call(call)
 
     with translate_errors_context():
@@ -1087,6 +1090,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
             "project",
             "system_tags",
             "enqueue_status",
+            "type",
         ),
     )
     for task in tasks:
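
Fetching "type" alongside the other fields is presumably what lets the archive operation recognize a pipeline controller before deciding whether to fan out to its steps: in ClearML, pipelines run as tasks of type "controller". An illustrative predicate only; the real check lives in the BLL layer:

    def is_pipeline_controller(task) -> bool:
        # Pipeline controllers are tasks whose type is "controller",
        # which is why "type" is now included in the fetched fields.
        return getattr(task, "type", None) == "controller"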
@@ -1096,6 +1100,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
             task=task,
             status_message=request.status_message,
             status_reason=request.status_reason,
+            include_pipeline_steps=request.include_pipeline_steps,
         )
 
     call.result.data_model = ArchiveResponse(archived=archived)
@@ -1103,10 +1108,9 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
 
 @endpoint(
     "tasks.archive_many",
-    request_data_model=TaskBatchRequest,
     response_data_model=BatchResponse,
 )
-def archive_many(call: APICall, company_id, request: TaskBatchRequest):
+def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
     results, failures = run_batch_operation(
         func=partial(
             archive_task,
@@ -1114,6 +1118,7 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
             identity=call.identity,
             status_message=request.status_message,
             status_reason=request.status_reason,
+            include_pipeline_steps=request.include_pipeline_steps,
         ),
         ids=request.ids,
     )
@@ -1125,10 +1130,9 @@ def archive_many(call: APICall, company_id, request: TaskBatchRequest):
 
 @endpoint(
     "tasks.unarchive_many",
-    request_data_model=TaskBatchRequest,
     response_data_model=BatchResponse,
 )
-def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
+def unarchive_many(call: APICall, company_id, request: UnarchiveManyRequest):
     results, failures = run_batch_operation(
         func=partial(
             unarchive_task,
@@ -1136,6 +1140,7 @@ def unarchive_many(call: APICall, company_id, request: TaskBatchRequest):
             identity=call.identity,
             status_message=request.status_message,
             status_reason=request.status_reason,
+            include_pipeline_steps=request.include_pipeline_steps,
         ),
         ids=request.ids,
     )
@@ -1160,10 +1165,9 @@ def delete(call: APICall, company_id, request: DeleteRequest):
         status_message=request.status_message,
         status_reason=request.status_reason,
         delete_external_artifacts=request.delete_external_artifacts,
+        include_pipeline_steps=request.include_pipeline_steps,
     )
     if deleted:
-        if request.move_to_trash:
-            move_tasks_to_trash([request.task])
         _reset_cached_tags(company_id, projects=[task.project] if task.project else [])
 
     call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup_res))
@@ -1182,15 +1186,12 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
             status_message=request.status_message,
             status_reason=request.status_reason,
             delete_external_artifacts=request.delete_external_artifacts,
+            include_pipeline_steps=request.include_pipeline_steps,
         ),
         ids=request.ids,
     )
 
     if results:
-        if request.move_to_trash:
-            task_ids = set(task.id for _, (_, task, _) in results)
-            if task_ids:
-                move_tasks_to_trash(list(task_ids))
         projects = set(task.project for _, (_, task, _) in results)
         _reset_cached_tags(company_id, projects=list(projects))
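
Both delete handlers drop their inline move_to_trash branches, matching the move_tasks_to_trash import removed at the top of the file. Trash handling has presumably been folded into delete_task itself, so that step tasks deleted through the new flag get the same move-to-trash treatment as their controller. A hedged sketch of that consolidated shape (all names except the two flags are illustrative):

    from typing import Callable, List, Sequence

    def delete_with_steps(
        controller_id: str,
        step_ids: Sequence[str],
        move_to_trash: bool,
        trash: Callable[[List[str]], None],
        include_pipeline_steps: bool = False,
    ) -> List[str]:
        # Delete the controller and, when requested, its steps, then apply
        # move-to-trash once over everything that was deleted.
        deleted = [controller_id]
        if include_pipeline_steps:
            deleted.extend(step_ids)
        if move_to_trash:
            trash(deleted)
        return deleted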