mirror of
https://github.com/clearml/clearml-server
synced 2025-02-01 19:33:44 +00:00
949 lines
30 KiB
Python
949 lines
30 KiB
Python
from copy import deepcopy
|
|
from datetime import datetime
|
|
from operator import attrgetter
|
|
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
|
|
|
|
import attr
|
|
import dpath
|
|
import mongoengine
|
|
from mongoengine import EmbeddedDocument, Q
|
|
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
|
from pymongo import UpdateOne
|
|
|
|
from apierrors import errors, APIError
|
|
from apimodels.base import UpdateResponse, IdResponse
|
|
from apimodels.tasks import (
|
|
StartedResponse,
|
|
ResetResponse,
|
|
PublishRequest,
|
|
PublishResponse,
|
|
CreateRequest,
|
|
UpdateRequest,
|
|
SetRequirementsRequest,
|
|
TaskRequest,
|
|
DeleteRequest,
|
|
PingRequest,
|
|
EnqueueRequest,
|
|
EnqueueResponse,
|
|
DequeueResponse,
|
|
CloneRequest,
|
|
AddOrUpdateArtifactsRequest,
|
|
AddOrUpdateArtifactsResponse,
|
|
)
|
|
from bll.event import EventBLL
|
|
from bll.organization import OrgBLL
|
|
from bll.queue import QueueBLL
|
|
from bll.task import (
|
|
TaskBLL,
|
|
ChangeStatusRequest,
|
|
update_project_time,
|
|
split_by,
|
|
ParameterKeyEscaper,
|
|
)
|
|
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
|
from bll.util import SetFieldsResolver
|
|
from database.errors import translate_errors_context
|
|
from database.model.model import Model
|
|
from database.model.task.output import Output
|
|
from database.model.task.task import (
|
|
Task,
|
|
TaskStatus,
|
|
Script,
|
|
DEFAULT_LAST_ITERATION,
|
|
Execution,
|
|
)
|
|
from database.utils import get_fields, parse_from_call
|
|
from service_repo import APICall, endpoint
|
|
from services.utils import conform_tag_fields, conform_output_tags
|
|
from timing_context import TimingContext
|
|
from utilities import safe_get
|
|
|
|
# Names of the Task model fields, as reported by Task.get_fields()
task_fields = set(Task.get_fields())
# Names of the fields of the Script document, as reported by get_fields()
task_script_fields = set(get_fields(Script))

# Module-level business-logic objects shared by the endpoints in this file
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()

# Started once, when this module is imported.
# NOTE(review): presumably a background monitor for tasks that stopped reporting - confirm
NonResponsiveTasksWatchdog.start()
|
|
|
|
|
|
def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    """
    Change the status of the task referenced by an update request.

    :param request: update request carrying the task ID, status reason/message and force flag
    :param company_id: company used for the write-access check
    :param new_status: status to move the task to; the task's current status is used when falsy
    :param set_fields: extra field updates, resolved through SetFieldsResolver and
        passed on to the status change
    :return: the result of ChangeStatusRequest.execute()
    """
    resolver = SetFieldsResolver(set_fields)
    # Load only the fields needed for the status change and for resolving the extra updates
    required_fields = tuple({"status", "project"} | resolver.get_names())
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=required_fields,
        requires_write_access=True,
    )

    change_request = ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=request.status_reason,
        status_message=request.status_message,
        force=request.force,
    )
    return change_request.execute(**resolver.get_fields(task))
|
|
|
|
|
|
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """Return the requested task (public tasks are allowed) under the "task" result key."""
    task_dict = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    ).to_proper_dict()
    unprepare_from_saved(call, task_dict)
    call.result.data = {"task": task_dict}
|
|
|
|
|
|
def escape_execution_parameters(call: APICall):
    """
    Escape execution parameter names used in the call's projection and ordering.

    Paths starting with "execution.parameters." (or "-execution.parameters." for
    ordering) get the remainder of the path escaped with ParameterKeyEscaper.
    A path equal to the bare prefix is rejected.

    :raise errors.bad_request.ValidationError: if a path consists of the prefix only
    """
    default_prefix = "execution.parameters."

    def escape_paths(paths, prefix=default_prefix):
        result = []
        for item in paths:
            if item == prefix:
                raise errors.bad_request.ValidationError(
                    "invalid task field", path=item
                )
            if item.startswith(prefix):
                item = prefix + ParameterKeyEscaper.escape(item[len(prefix) :])
            result.append(item)
        return result

    projection = Task.get_projection(call.data)
    if projection:
        Task.set_projection(call.data, escape_paths(projection))

    ordering = Task.get_ordering(call.data)
    if ordering:
        # First pass handles ascending keys; the value returned by set_ordering is
        # fed to the second pass, which handles descending ("-" prefixed) keys
        ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
        Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
|
|
|
|
|
|
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
    """Query tasks using Task.get_many_with_join and return them under "tasks"."""
    conform_tag_fields(call, call.data)
    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            # allow_public is required in case projection is requested
            # for public dataset/versions
            found = Task.get_many_with_join(
                company=company_id, query_dict=call.data, allow_public=True
            )
        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
|
|
|
|
|
|
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
    """Query tasks using Task.get_many and return them under "tasks"."""
    conform_tag_fields(call, call.data)
    escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all"):
            # allow_public is required in case projection is requested
            # for public dataset/versions
            found = Task.get_many(
                company=company_id,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
            )
        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
|
|
|
|
|
|
@endpoint(
    "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop a running task. Requires task status 'in_progress' and
              execution_progress 'running', or force=True.
              Development task is stopped immediately. For a non-development task
              only its status message is set to 'stopping'

    The work is delegated to TaskBLL.stop_task; its result populates the response.
    """
    stop_result = TaskBLL.stop_task(
        task_id=req_model.task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=req_model.status_reason,
        force=req_model.force,
    )
    call.result.data_model = UpdateResponse(**stop_result)
|
|
|
|
|
|
@endpoint(
    "tasks.stopped",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
def stopped(call: APICall, company_id, req_model: UpdateRequest):
    """Move the task to the 'stopped' status and set its completion time to now (UTC)."""
    status_change = set_task_status_from_call(
        req_model,
        company_id,
        new_status=TaskStatus.stopped,
        completed=datetime.utcnow(),
    )
    call.result.data_model = UpdateResponse(**status_change)
|
|
|
|
|
|
@endpoint(
    "tasks.started",
    request_data_model=UpdateRequest,
    response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
    """Move the task to the 'in_progress' status and report it in the response."""
    status_change = set_task_status_from_call(
        req_model,
        company_id,
        new_status=TaskStatus.in_progress,
        # "min" operator: don't override a previous, smaller "started" field value
        min__started=datetime.utcnow(),
    )
    response = StartedResponse(**status_change)
    response.started = response.updated
    call.result.data_model = response
|
|
|
|
|
|
@endpoint(
    "tasks.failed", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def failed(call: APICall, company_id, req_model: UpdateRequest):
    """Move the task to the 'failed' status."""
    status_change = set_task_status_from_call(
        req_model, company_id, new_status=TaskStatus.failed
    )
    call.result.data_model = UpdateResponse(**status_change)
|
|
|
|
|
|
@endpoint(
    "tasks.close", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def close(call: APICall, company_id, req_model: UpdateRequest):
    """Move the task to the 'closed' status."""
    status_change = set_task_status_from_call(
        req_model, company_id, new_status=TaskStatus.closed
    )
    call.result.data_model = UpdateResponse(**status_change)
|
|
|
|
|
|
# Fields accepted from the caller when creating a task. This mapping is handed to
# parse_from_call as the set of valid fields; it is also the base for the field sets
# used when editing and updating a task.
# NOTE(review): the values (None, list, Task) look like per-field type/reference hints
# consumed by parse_from_call - confirm against its implementation.
create_fields = {
    "name": None,
    "tags": list,
    "system_tags": list,
    "type": None,
    "error": None,
    "comment": None,
    "parent": Task,
    "project": None,
    "input": None,
    "output_dest": None,
    "execution": None,
    "script": None,
}
|
|
|
|
|
|
def prepare_for_save(call: APICall, fields: dict):
    """
    Normalize task fields received from a call before they are stored.

    Validates/conforms tag fields, strips whitespace from string script fields and
    escapes execution parameter names. The dictionary is modified in place.

    :param call: the API call the fields came from
    :param fields: task fields to normalize
    :return: the same ``fields`` dictionary
    """
    conform_tag_fields(call, fields, validate=True)

    # Strip all script fields (remove leading and trailing whitespace chars)
    # to avoid unusable names and paths
    for script_field in task_script_fields:
        script_path = f"script/{script_field}"
        try:
            current = dpath.get(fields, script_path)
            stripped = current.strip() if isinstance(current, str) else current
            dpath.set(fields, script_path, stripped)
        except KeyError:
            # the field was not provided in this call
            pass

    params = safe_get(fields, "execution/parameters")
    if params is not None:
        # Escape keys to make them mongo-safe
        escaped = {ParameterKeyEscaper.escape(name): val for name, val in params.items()}
        dpath.set(fields, "execution/parameters", escaped)

    return fields
|
|
|
|
|
|
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """
    Convert stored task data back to the form returned to the caller.

    Conforms output tags and restores (unescapes) execution parameter names.
    The task dictionaries are modified in place.

    :param call: the API call being answered
    :param tasks_data: a single task dictionary or a sequence of them
    """
    if isinstance(tasks_data, dict):
        tasks_data = [tasks_data]

    conform_output_tags(call, tasks_data)

    for task_data in tasks_data:
        stored = safe_get(task_data, "execution/parameters")
        if stored is None:
            continue
        # Undo the key escaping that was applied when the parameters were saved
        restored = {ParameterKeyEscaper.unescape(name): val for name, val in stored.items()}
        dpath.set(task_data, "execution/parameters", restored)
|
|
|
|
|
|
def prepare_create_fields(
    call: APICall, valid_fields=None, output=None, previous_task: Task = None
):
    """
    Parse and normalize the task fields of a create/edit call.

    :param call: the API call whose data holds the task fields
    :param valid_fields: mapping of fields accepted from the call; defaults to ``create_fields``
    :param output: existing Output document to update with "output_dest", if any
    :param previous_task: not used by this function; kept for interface compatibility
    :return: the parsed fields after ``prepare_for_save``
    """
    valid_fields = valid_fields if valid_fields is not None else create_fields
    # Build a new set instead of adding to the module-level ``task_fields``:
    # mutating the shared set would leak "output_dest" into every later request
    t_fields = task_fields | {"output_dest"}

    fields = parse_from_call(call.data, valid_fields, t_fields)

    # Move output_dest to output.destination
    output_dest = fields.get("output_dest")
    if output_dest is not None:
        fields.pop("output_dest")
        if output:
            output.destination = output_dest
        else:
            output = Output(destination=output_dest)
        fields["output"] = output

    return prepare_for_save(call, fields)
|
|
|
|
|
|
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
    """
    Build a task object from the call data and validate it.

    :param call: the API call holding the task fields
    :param kwargs: forwarded to ``prepare_create_fields``
    :return: a ``(task, fields)`` tuple with the created task object and the parsed fields
    """
    parse_errors = translate_errors_context(
        field_does_not_exist_cls=errors.bad_request.ValidationError
    )
    with parse_errors, TimingContext("code", "parse_call"):
        fields = prepare_create_fields(call, **kwargs)
        task = task_bll.create(call, fields)

    with TimingContext("code", "validate"):
        task_bll.validate(task)

    return task, fields
|
|
|
|
|
|
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
    """Parse and validate the task described by the call; nothing is saved or returned."""
    _validate_and_get_task_from_call(call)
|
|
|
|
|
|
def _update_org_tags(company, fields: dict):
    """Pass the "tags" and "system_tags" values of ``fields`` (None when absent) to the org BLL."""
    tags = fields.get("tags")
    system_tags = fields.get("system_tags")
    org_bll.update_org_tags(company, tags=tags, system_tags=system_tags)
|
|
|
|
|
|
@endpoint(
    "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
    """Validate and save a new task, then return its ID."""
    new_task, parsed_fields = _validate_and_get_task_from_call(call)

    with translate_errors_context(), TimingContext("mongo", "save_task"):
        new_task.save()
        _update_org_tags(company_id, parsed_fields)
        update_project_time(new_task.project)

    call.result.data_model = IdResponse(id=new_task.id)
|
|
|
|
|
|
@endpoint(
    "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
def clone_task(call: APICall, company_id, request: CloneRequest):
    """Clone the requested task through the task BLL and return the new task's ID."""
    clone_args = dict(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        name=request.new_task_name,
        comment=request.new_task_comment,
        parent=request.new_task_parent,
        project=request.new_task_project,
        tags=request.new_task_tags,
        system_tags=request.new_task_system_tags,
        execution_overrides=request.execution_overrides,
        validate_references=request.validate_references,
    )
    cloned = task_bll.clone_task(**clone_args)
    call.result.data_model = IdResponse(id=cloned.id)
|
|
|
|
|
|
def prepare_update_fields(call: APICall, task, call_data):
    """
    Parse and normalize the task fields of an update call.

    Only ``create_fields`` entries that the task class allows users to set are
    accepted, plus "output__error".

    :param call: the API call being served
    :param task: the task being updated; its class supplies ``user_set_allowed()``
    :param call_data: the data to parse the fields from
    :return: a ``(fields, valid_fields)`` tuple with the parsed fields (after
        ``prepare_for_save``) and a copy of the user-settable field collection
    """
    valid_fields = deepcopy(task.__class__.user_set_allowed())
    update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
    update_fields["output__error"] = None
    # Build a new set instead of adding to the module-level ``task_fields``:
    # mutating the shared set would leak "output__error" into every later request
    t_fields = task_fields | {"output__error"}
    fields = parse_from_call(call_data, update_fields, t_fields)
    return prepare_for_save(call, fields), valid_fields
|
|
|
|
|
|
@endpoint(
    "tasks.update", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Update user-settable fields of a task.

    :return: an UpdateResponse with the update count and the updated fields
        (``updated=0`` when the call holds no updatable fields)
    :raise errors.bad_request.InvalidTaskId: if the task is not available for writing
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        changes, _ = prepare_update_fields(call, task, call.data)
        if not changes:
            return UpdateResponse(updated=0)

        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=changes,
            injected_update=dict(last_update=datetime.utcnow()),
        )
        if updated_count:
            _update_org_tags(company_id, updated_fields)
            update_project_time(updated_fields.get("project"))

        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
|
|
|
|
|
|
@endpoint(
    "tasks.set_requirements",
    request_data_model=SetRequirementsRequest,
    response_data_model=UpdateResponse,
)
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    """
    Set the script requirements of a task.

    :raise errors.bad_request.MissingTaskFields: if the task has no script
    """
    requirements = req_model.requirements
    with translate_errors_context():
        task = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not task.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=task.id
            )

        updated = task.update(
            script__requirements=requirements, last_update=datetime.utcnow()
        )
        response = UpdateResponse(updated=updated)
        call.result.data_model = response
        if updated:
            response.fields = {"script.requirements": requirements}
|
|
|
|
|
|
@endpoint("tasks.update_batch")
def update_batch(call: APICall, company_id, _):
    """
    Update several tasks with a single bulk write.

    Each batched item is a dictionary with a "task" ID and the fields to update.

    :raise errors.bad_request.BatchContainsNoItems: if the call has no batched data
    :raise errors.bad_request.InvalidTaskId: if some tasks are not available for writing
    """
    batch = call.batched_data
    if batch is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items_by_task = {item["task"]: item for item in batch}
        writable = Task.get_many_for_writing(
            company=company_id, query=Q(id__in=list(items_by_task))
        )
        tasks_by_id = {t.id: t for t in writable}

        if len(tasks_by_id) < len(items_by_task):
            missing = tuple(set(items_by_task).difference(tasks_by_id))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        for task_id, data in items_by_task.items():
            fields, _valid_fields = prepare_update_fields(call, tasks_by_id[task_id], data)
            changes = Task.get_safe_update_dict(fields)
            if not changes:
                continue
            changes.update(last_update=now)
            bulk_ops.append(
                UpdateOne({"_id": task_id, "company": company_id}, {"$set": changes})
            )

        updated = 0
        if bulk_ops:
            write_result = Task._get_collection().bulk_write(bulk_ops)
            updated = write_result.modified_count

        if updated:
            org_bll.update_org_tags(company_id, reset=True)
        call.result.data = {"updated": updated}
|
|
|
|
|
|
@endpoint(
    "tasks.edit", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit a task. Unless force is set, the task must be in the 'created' status.

    :raise errors.bad_request.InvalidTaskId: if the task is not available for writing
    :raise errors.bad_request.InvalidTaskStatus: if the task is not in the 'created'
        status and force was not requested
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if task.status != TaskStatus.created and not force:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status
            )

        # Editing accepts everything creation does, plus the status field
        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        parse_errors = translate_errors_context(
            field_does_not_exist_cls=errors.bad_request.ValidationError
        )
        with parse_errors, TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(
                call, valid_fields=edit_fields, output=task.output, previous_task=task
            )

        # For embedded documents given as dictionaries, overlay the new values
        # on top of the values currently stored in the task
        for key, value in fields.items():
            current = getattr(task, key, None)
            if not (current and isinstance(value, dict)):
                continue
            if isinstance(current, EmbeddedDocument):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(value)
                fields[key] = merged

        task_bll.validate(task_bll.create(call, fields))

        # make sure field names do not end in mongoengine comparison operators
        fixed_fields = {}
        for key, value in fields.items():
            safe_key = "%s__" % key if key in COMPARISON_OPERATORS else key
            fixed_fields[safe_key] = value

        if not fixed_fields:
            call.result.data_model = UpdateResponse(updated=0)
        else:
            now = datetime.utcnow()
            fields.update(last_update=now)
            fixed_fields.update(last_update=now)
            updated = task.update(upsert=False, **fixed_fields)
            if updated:
                _update_org_tags(company_id, fixed_fields)
                update_project_time(fields.get("project"))
            unprepare_from_saved(call, fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
|
|
|
|
|
@endpoint(
    "tasks.enqueue",
    request_data_model=EnqueueRequest,
    response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
    """
    Move a task to the 'queued' status and add it to a queue.

    The company's default queue is used when the request names no queue. If adding
    the task to the queue fails, the status change is reverted and the error re-raised.

    :raise errors.bad_request.InvalidTaskId: if the task is not available for writing
    """
    task_id = req_model.task
    # fall back to the default queue when none was requested
    queue_id = req_model.queue or queue_bll.get_default(company_id).id

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            _only=("type", "script", "execution", "status", "project", "id"), **query
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        to_queued = ChangeStatusRequest(
            task=task,
            new_status=TaskStatus.queued,
            status_reason=req_model.status_reason,
            status_message=req_model.status_message,
            allow_same_state_transition=False,
        )
        res = EnqueueResponse(**to_queued.execute())

        try:
            queue_bll.add_task(
                company_id=company_id, queue_id=queue_id, task_id=task.id
            )
        except Exception:
            # failed enqueueing, revert to previous state
            ChangeStatusRequest(
                task=task,
                current_status_override=TaskStatus.queued,
                new_status=task.status,
                force=True,
                status_reason="failed enqueueing",
            ).execute()
            raise

        # set the current queue ID in the task
        if task.execution:
            queue_update = dict(execution__queue=queue_id)
        else:
            queue_update = dict(execution=Execution(queue=queue_id))
        Task.objects(**query).update(multi=False, **queue_update)

        res.queued = 1
        res.fields.update(**{"execution.queue": queue_id})

        call.result.data_model = res
|
|
|
|
|
|
@endpoint(
    "tasks.dequeue",
    request_data_model=UpdateRequest,
    response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, req_model: UpdateRequest):
    """
    Remove a queued task from its queue and move it back to the 'created' status.

    :raise errors.bad_request.InvalidTaskId: if the task is not in the 'queued' status
    """
    task = TaskBLL.get_task_with_access(
        req_model.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )
    # NOTE(review): a status mismatch is reported as InvalidTaskId here, while other
    # endpoints in this file use InvalidTaskStatus - confirm whether this is intended
    if task.status != TaskStatus.queued:
        raise errors.bad_request.InvalidTaskId(
            status=task.status, expected=TaskStatus.queued
        )

    _dequeue(task, company_id)

    to_created = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        status_reason=req_model.status_reason,
        status_message=req_model.status_message,
    )
    response = DequeueResponse(**to_created.execute(unset__execution__queue=1))
    response.dequeued = 1

    call.result.data_model = response
|
|
|
|
|
|
def _dequeue(task: Task, company_id: str, silent_fail=False):
    """
    Dequeue the task from the queue
    :param task: task to dequeue
    :param company_id: company passed to the queue removal call
    :param silent_fail: do not throw exceptions. APIError is still thrown
    :raise errors.bad_request.MissingRequiredFields: if the task is not queued
    :raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
    :return: the result of queues.remove_task call. None in case of silent failure
    """
    execution = task.execution
    if execution and execution.queue:
        removed = queue_bll.remove_task(
            company_id=company_id, queue_id=execution.queue, task_id=task.id
        )
        return {"removed": removed}

    # The task holds no queue reference
    if silent_fail:
        return None
    raise errors.bad_request.MissingRequiredFields(
        "task has no queue value", field="execution.queue"
    )
|
|
|
|
|
|
@endpoint(
    "tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, req_model: UpdateRequest):
    """
    Reset a task to the 'created' status.

    The task is dequeued if needed, its outputs are cleaned up and its run-related
    fields are cleared. Published tasks can be reset only with force.

    :raise errors.bad_request.InvalidTaskStatus: if the task is published and force
        was not requested
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    force = req_model.force

    if task.status == TaskStatus.published and not force:
        raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)

    api_results = {}
    updates = {}

    try:
        dequeued = _dequeue(task, company_id, silent_fail=True)
    except APIError:
        # dequeue may fail if the task was not enqueued
        pass
    else:
        if dequeued:
            api_results.update(dequeued=dequeued)
        updates.update(unset__execution__queue=1)

    api_results.update(attr.asdict(cleanup_task(task, force)))

    updates.update(
        set__last_iteration=DEFAULT_LAST_ITERATION,
        set__last_metrics={},
        unset__output__result=1,
        unset__output__model=1,
        # drop every artifact whose mode is not "input"
        __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
    )

    to_created = ChangeStatusRequest(
        task=task,
        new_status=TaskStatus.created,
        force=force,
        status_reason="reset",
        status_message="reset",
    )
    response = ResetResponse(
        **to_created.execute(started=None, completed=None, published=None, **updates)
    )

    # Report the dequeue and cleanup results as part of the response
    for name, value in api_results.items():
        setattr(response, name, value)

    call.result.data_model = response
|
|
|
|
|
|
class DocumentGroup(list):
    """
    A list of documents that can also be queried like a query result.

    ``objects()`` builds a query on the document type, restricted to the IDs
    of the documents held in the list.
    """

    def __init__(self, document_type, documents):
        super().__init__(documents)
        self.type = document_type

    def objects(self, *args, **kwargs):
        """Query ``self.type`` for the documents in this group, with optional extra filters."""
        member_ids = [document.id for document in self]
        return self.type.objects(*args, id__in=member_ids, **kwargs)
|
|
|
|
|
|
T = TypeVar("T")


class TaskOutputs(object):
    """
    Task outputs of a single type, separated into published and draft groups
    """

    published = None  # type: DocumentGroup
    draft = None  # type: DocumentGroup

    def __init__(self, is_published, document_type, children):
        # type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> ()
        """
        :param is_published: predicate telling whether an item counts as published
        :param document_type: document type of the outputs
        :param children: the output documents to split
        """
        published_docs, draft_docs = split_by(is_published, children)
        self.published = DocumentGroup(document_type, published_docs)
        self.draft = DocumentGroup(document_type, draft_docs)
|
|
|
|
|
|
@attr.s
class CleanupResult(object):
    """
    Counts of objects modified in task cleanup operation
    """

    # result of re-parenting the task's child tasks (0 when there are none)
    updated_children = attr.ib(type=int)
    # result of re-assigning the task's published models (0 when there are none)
    updated_models = attr.ib(type=int)
    # result of deleting the task's draft models (0 when there are none)
    deleted_models = attr.ib(type=int)
|
|
|
|
|
|
def cleanup_task(task: Task, force: bool = False):
    """
    Validate that a task can be cleaned up and delete/modify all its outputs.

    Child tasks and published models are re-assigned to a "deleted" placeholder
    task ID, draft models are deleted and the task's events are deleted.

    :param task: task object
    :param force: whether to proceed for a task with published outputs
    :return: a CleanupResult with the counts of deleted and modified items
    """
    models, child_tasks = get_outputs_for_deletion(task, force)
    placeholder_id = trash_task_id(task.id)

    updated_children = 0
    if child_tasks:
        with TimingContext("mongo", "update_task_children"):
            updated_children = child_tasks.update(parent=placeholder_id)

    deleted_models = 0
    if models.draft:
        with TimingContext("mongo", "delete_models"):
            deleted_models = models.draft.objects().delete()

    updated_models = 0
    if models.published:
        with TimingContext("mongo", "update_task_models"):
            updated_models = models.published.objects().update(task=placeholder_id)

    event_bll.delete_task_events(task.company, task.id, allow_locked=force)

    return CleanupResult(
        updated_children=updated_children,
        updated_models=updated_models,
        deleted_models=deleted_models,
    )
|
|
|
|
|
|
def get_outputs_for_deletion(task, force=False):
    """
    Collect the models and child tasks affected by cleaning up a task.

    :param task: the task being cleaned up
    :param force: when False, published models or published child tasks block the operation
    :return: a ``(models, tasks)`` tuple: the task's models as a TaskOutputs object
        (draft models referenced by other tasks' execution are excluded) and the
        query of the task's child tasks
    :raise errors.bad_request.TaskCannotBeDeleted: if force is False and the task has
        published models or published children
    """
    with TimingContext("mongo", "get_task_models"):
        task_models = Model.objects(task=task.id).only("id", "task", "ready")
        models = TaskOutputs(attrgetter("ready"), Model, task_models)
        if models.published and not force:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    if task.output.model:
        output_model = get_output_model(task, force)
        if output_model:
            target_group = models.published if output_model.ready else models.draft
            target_group.append(output_model)

    if models.draft:
        with TimingContext("mongo", "get_execution_models"):
            draft_ids = [model.id for model in models.draft]
            dependent_tasks = Task.objects(execution__model__in=draft_ids).only(
                "id", "execution.model"
            )
            # Draft models used as another task's execution model are kept
            busy_models = [dependent.execution.model for dependent in dependent_tasks]
            models.draft[:] = [
                model for model in models.draft if model.id not in busy_models
            ]

    with TimingContext("mongo", "get_task_children"):
        tasks = Task.objects(parent=task.id).only("id", "parent", "status")
        published_tasks = [
            child for child in tasks if child.status == TaskStatus.published
        ]
        if published_tasks and not force:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has children, use force=True", task=task.id, children=published_tasks
            )
    return models, tasks
|
|
|
|
|
|
def get_output_model(task, force=False):
    """
    Load the model referenced by ``task.output.model``.

    :param task: the task whose output model is requested
    :param force: when False, a ready (published) output model blocks the operation
    :return: the model document, or None if it was not found
    :raise errors.bad_request.TaskCannotBeDeleted: if the model is ready and force is False
    """
    with TimingContext("mongo", "get_task_output_model"):
        found = Model.objects(id=task.output.model).first()

    blocked = found and found.ready and not force
    if blocked:
        raise errors.bad_request.TaskCannotBeDeleted(
            "has output model, use force=True", task=task.id, model=task.output.model
        )
    return found
|
|
|
|
|
|
def trash_task_id(task_id):
    """Return the placeholder ID for a deleted task: the task ID prefixed with "__DELETED__"."""
    return f"__DELETED__{task_id}"
|
|
|
|
|
|
@endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
    """
    Delete a task after cleaning up its outputs.

    Unless force is set, the task must be in the 'created' status. When
    move_to_trash is requested, a copy of the task is first inserted into the
    "<collection>__trash" collection (best effort).

    :raise errors.bad_request.TaskCannotBeDeleted: if the task is not in the
        'created' status and force was not requested
    """
    task = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, requires_write_access=True
    )
    move_to_trash = req_model.move_to_trash
    force = req_model.force

    if not force and task.status != TaskStatus.created:
        raise errors.bad_request.TaskCannotBeDeleted(
            "due to status, use force=True",
            task=task.id,
            expected=TaskStatus.created,
            current=task.status,
        )

    with translate_errors_context():
        cleanup_result = cleanup_task(task, force)

        if move_to_trash:
            live_collection = task._get_collection_name()
            task.switch_collection(f"{live_collection}__trash")
            try:
                # A simple save() won't do due to mongoengine caching (nothing will be
                # saved), so we have to force an insert. However, if for some reason
                # such an ID exists, let's make sure we'll keep going.
                with TimingContext("mongo", "save_task"):
                    task.save(force_insert=True)
            except Exception:
                pass
            task.switch_collection(live_collection)

        task.delete()
        org_bll.update_org_tags(company_id, reset=True)
        call.result.data = dict(deleted=True, **attr.asdict(cleanup_result))
|
|
|
|
|
|
@endpoint(
    "tasks.publish",
    request_data_model=PublishRequest,
    response_data_model=PublishResponse,
)
def publish(call: APICall, company_id, req_model: PublishRequest):
    """Publish a task through TaskBLL.publish_task; its result populates the response."""
    publish_result = TaskBLL.publish_task(
        task_id=req_model.task,
        company_id=company_id,
        publish_model=req_model.publish_model,
        force=req_model.force,
        status_reason=req_model.status_reason,
        status_message=req_model.status_message,
    )
    call.result.data_model = PublishResponse(**publish_result)
|
|
|
|
|
|
@endpoint(
    "tasks.completed",
    min_version="2.2",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
def completed(call: APICall, company_id, request: UpdateRequest):
    """
    Move the task to the 'completed' status and set its completion time to now (UTC).

    The request parameter is annotated as UpdateRequest to match the endpoint's
    declared request_data_model.
    """
    call.result.data_model = UpdateResponse(
        **set_task_status_from_call(
            request,
            company_id,
            new_status=TaskStatus.completed,
            completed=datetime.utcnow(),
        )
    )
|
|
|
|
|
|
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
    """Set the task's last update time to now (UTC)."""
    now = datetime.utcnow()
    TaskBLL.set_last_update(
        task_ids=[request.task], company_id=company_id, last_update=now
    )
|
|
|
|
|
|
@endpoint(
    "tasks.add_or_update_artifacts",
    min_version="2.6",
    request_data_model=AddOrUpdateArtifactsRequest,
    response_data_model=AddOrUpdateArtifactsResponse,
)
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    """Add or update task artifacts through the task BLL and report the added/updated values."""
    counts = TaskBLL.add_or_update_artifacts(
        task_id=request.task, company_id=company_id, artifacts=request.artifacts
    )
    added, updated = counts
    call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
|