mirror of
https://github.com/clearml/clearml-server
synced 2025-02-01 11:26:43 +00:00
1255 lines
40 KiB
Python
1255 lines
40 KiB
Python
from copy import deepcopy
|
|
from datetime import datetime
|
|
from functools import partial
|
|
from typing import Sequence, Union, Tuple, Set
|
|
|
|
import attr
|
|
import dpath
|
|
from mongoengine import EmbeddedDocument, Q
|
|
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
|
from pymongo import UpdateOne
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apierrors.errors.bad_request import InvalidTaskId
|
|
from apiserver.apimodels.base import (
|
|
UpdateResponse,
|
|
IdResponse,
|
|
MakePublicRequest,
|
|
MoveRequest,
|
|
)
|
|
from apiserver.apimodels.tasks import (
|
|
StartedResponse,
|
|
ResetResponse,
|
|
PublishRequest,
|
|
PublishResponse,
|
|
CreateRequest,
|
|
UpdateRequest,
|
|
SetRequirementsRequest,
|
|
TaskRequest,
|
|
DeleteRequest,
|
|
PingRequest,
|
|
EnqueueRequest,
|
|
EnqueueResponse,
|
|
DequeueResponse,
|
|
CloneRequest,
|
|
AddOrUpdateArtifactsRequest,
|
|
GetTypesRequest,
|
|
ResetRequest,
|
|
GetHyperParamsRequest,
|
|
EditHyperParamsRequest,
|
|
DeleteHyperParamsRequest,
|
|
GetConfigurationsRequest,
|
|
EditConfigurationRequest,
|
|
DeleteConfigurationRequest,
|
|
GetConfigurationNamesRequest,
|
|
DeleteArtifactsRequest,
|
|
ArchiveResponse,
|
|
ArchiveRequest,
|
|
AddUpdateModelRequest,
|
|
DeleteModelsRequest,
|
|
ModelItemType,
|
|
StopManyResponse,
|
|
StopManyRequest,
|
|
EnqueueManyRequest,
|
|
EnqueueManyResponse,
|
|
ResetManyRequest,
|
|
ArchiveManyRequest,
|
|
ArchiveManyResponse,
|
|
DeleteManyRequest,
|
|
PublishManyRequest,
|
|
)
|
|
from apiserver.bll.event import EventBLL
|
|
from apiserver.bll.model import ModelBLL
|
|
from apiserver.bll.organization import OrgBLL, Tags
|
|
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
|
from apiserver.bll.queue import QueueBLL
|
|
from apiserver.bll.task import (
|
|
TaskBLL,
|
|
ChangeStatusRequest,
|
|
update_project_time,
|
|
)
|
|
from apiserver.bll.task.artifacts import (
|
|
artifacts_prepare_for_save,
|
|
artifacts_unprepare_from_saved,
|
|
Artifacts,
|
|
)
|
|
from apiserver.bll.task.hyperparams import HyperParams
|
|
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
|
from apiserver.bll.task.param_utils import (
|
|
params_prepare_for_save,
|
|
params_unprepare_from_saved,
|
|
escape_paths,
|
|
)
|
|
from apiserver.bll.task.task_cleanup import CleanupResult
|
|
from apiserver.bll.task.task_operations import (
|
|
stop_task,
|
|
enqueue_task,
|
|
reset_task,
|
|
archive_task,
|
|
delete_task,
|
|
publish_task,
|
|
)
|
|
from apiserver.bll.task.utils import update_task, deleted_prefix, get_task_for_update
|
|
from apiserver.bll.util import SetFieldsResolver, run_batch_operation
|
|
from apiserver.database.errors import translate_errors_context
|
|
from apiserver.database.model.task.output import Output
|
|
from apiserver.database.model.task.task import (
|
|
Task,
|
|
TaskStatus,
|
|
Script,
|
|
ModelItem,
|
|
)
|
|
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
|
from apiserver.service_repo import APICall, endpoint
|
|
from apiserver.services.utils import (
|
|
conform_tag_fields,
|
|
conform_output_tags,
|
|
ModelsBackwardsCompatibility,
|
|
DockerCmdBackwardsCompatibility,
|
|
escape_dict_field,
|
|
unescape_dict_field,
|
|
)
|
|
from apiserver.timing_context import TimingContext
|
|
from apiserver.utilities.partial_version import PartialVersion
|
|
|
|
# Every field name declared on the Task document
task_fields = set(Task.get_fields())

# Script fields whose "strip" attribute is truthy; prepare_for_save strips
# leading/trailing whitespace from their string values
task_script_stripped_fields = {
    field_name
    for field_name, should_strip in get_fields_attr(Script, "strip").items()
    if should_strip
}

# Module-wide business-logic helpers shared by the endpoints in this module
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
project_bll = ProjectBLL()

# Started once, at module import time
NonResponsiveTasksWatchdog.start()
|
|
|
|
|
|
def set_task_status_from_call(
    request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
    """
    Change the status of the task referenced by the request and set extra fields.

    :param request: update request carrying the task id, status reason/message
        and the force flag
    :param company_id: company that owns the task
    :param new_status: target status; when omitted the task keeps its current status
    :param set_fields: extra fields to set, in SetFieldsResolver syntax
        (e.g. ``min__started=...``)
    :return: the result of ``ChangeStatusRequest.execute``
    """
    fields_resolver = SetFieldsResolver(set_fields)
    task = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=tuple(
            {"status", "project", "started", "duration"} | fields_resolver.get_names()
        ),
        requires_write_access=True,
    )

    if "duration" not in fields_resolver.get_names():
        # Compare against the status value; the previous comparison with the
        # Task.started field descriptor could never match a status string
        if new_status == TaskStatus.in_progress:
            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
        elif new_status in (
            TaskStatus.completed,
            TaskStatus.failed,
            TaskStatus.stopped,
        ):
            # Elapsed run time is "now - started"; the reversed subtraction
            # produced a negative duration
            fields_resolver.add_fields(
                duration=int((datetime.utcnow() - task.started).total_seconds())
                if task.started
                else 0
            )

    return ChangeStatusRequest(
        task=task,
        new_status=new_status or task.status,
        status_reason=request.status_reason,
        status_message=request.status_message,
        force=request.force,
    ).execute(**fields_resolver.get_fields(task))
|
|
|
|
|
|
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
def get_by_id(call: APICall, company_id, req_model: TaskRequest):
    """Return one task (public tasks included) under the "task" result key."""
    found = TaskBLL.get_task_with_access(
        req_model.task, company_id=company_id, allow_public=True
    )
    as_dict = found.to_proper_dict()
    unprepare_from_saved(call, as_dict)
    call.result.data = dict(task=as_dict)
|
|
|
|
|
|
def escape_execution_parameters(call: APICall) -> dict:
    """
    Return a copy of the call data whose top-level keys, projection paths and
    ordering paths have been passed through escape_paths.

    Empty (falsy) call data is returned as-is.
    """
    if not call.data:
        return call.data

    original_keys = list(call.data)
    escaped_keys = escape_paths(original_keys)
    escaped = {}
    for original, safe in zip(original_keys, escaped_keys):
        escaped[safe] = call.data[original]

    # Projection first, then ordering, exactly as stored in the query dict
    for read_paths, write_paths in (
        (Task.get_projection, Task.set_projection),
        (Task.get_ordering, Task.set_ordering),
    ):
        paths = read_paths(escaped)
        if paths:
            write_paths(escaped, escape_paths(paths))

    return escaped
|
|
|
|
|
|
def _process_include_subprojects(call_data: dict):
|
|
include_subprojects = call_data.pop("include_subprojects", False)
|
|
project_ids = call_data.get("project")
|
|
if not project_ids or not include_subprojects:
|
|
return
|
|
|
|
if not isinstance(project_ids, list):
|
|
project_ids = [project_ids]
|
|
call_data["project"] = project_ids_with_children(project_ids)
|
|
|
|
|
|
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
    """Query tasks (public included) with joined data; honours include_subprojects."""
    conform_tag_fields(call, call.data)
    query = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all_ex"):
            _process_include_subprojects(query)
            found = Task.get_many_with_join(
                company=company_id, query_dict=query, allow_public=True,
            )
        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
|
|
|
|
|
|
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
    """Fetch tasks by id (public included) with joined data."""
    conform_tag_fields(call, call.data)
    query = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_by_id_ex"):
            found = Task.get_many_with_join(
                company=company_id, query_dict=query, allow_public=True,
            )

        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
|
|
|
|
|
|
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
    """Query tasks (public included); the call data serves as both parameters and query."""
    conform_tag_fields(call, call.data)
    query = escape_execution_parameters(call)

    with translate_errors_context():
        with TimingContext("mongo", "task_get_all"):
            found = Task.get_many(
                company=company_id,
                parameters=query,
                query_dict=query,
                allow_public=True,
            )
        unprepare_from_saved(call, found)
        call.result.data = {"tasks": found}
|
|
|
|
|
|
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
    """List the task types in use, optionally restricted to the requested projects."""
    task_types = project_bll.get_task_types(company_id, project_ids=request.projects)
    call.result.data = {"types": list(task_types)}
|
|
|
|
|
|
@endpoint(
    "tasks.stop", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def stop(call: APICall, company_id, req_model: UpdateRequest):
    """
    stop
    :summary: Stop a running task. Requires task status 'in_progress' and
              execution_progress 'running', or force=True.
              Development task is stopped immediately. For a non-development task
              only its status message is set to 'stopping'

    """
    # The actual status handling lives in task_operations.stop_task
    outcome = stop_task(
        task_id=req_model.task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=req_model.status_reason,
        force=req_model.force,
    )
    call.result.data_model = UpdateResponse(**outcome)
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class StopRes:
    """Batch accumulator for stop_many: counts one per successful stop_task call."""

    stopped: int = 0

    def __add__(self, other: dict):
        # The per-task result itself is not inspected, only counted
        return StopRes(stopped=1 + self.stopped)
|
|
|
|
|
|
@endpoint(
    "tasks.stop_many",
    request_data_model=StopManyRequest,
    response_data_model=StopManyResponse,
)
def stop_many(call: APICall, company_id, request: StopManyRequest):
    """Stop several tasks; failures collected by run_batch_operation are returned too."""
    stop_one = partial(
        stop_task,
        company_id=company_id,
        user_name=call.identity.user_name,
        status_reason=request.status_reason,
        force=request.force,
    )
    totals, failures = run_batch_operation(
        func=stop_one, ids=request.ids, init_res=StopRes(),
    )
    call.result.data_model = StopManyResponse(stopped=totals.stopped, failures=failures)
|
|
|
|
|
|
@endpoint(
    "tasks.stopped",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
def stopped(call: APICall, company_id, req_model: UpdateRequest):
    """Set the task status to stopped and record the completion time."""
    change = set_task_status_from_call(
        req_model,
        company_id,
        new_status=TaskStatus.stopped,
        completed=datetime.utcnow(),
    )
    call.result.data_model = UpdateResponse(**change)
|
|
|
|
|
|
@endpoint(
    "tasks.started",
    request_data_model=UpdateRequest,
    response_data_model=StartedResponse,
)
def started(call: APICall, company_id, req_model: UpdateRequest):
    """Set the task status to in_progress; the response's "started" mirrors "updated"."""
    change = set_task_status_from_call(
        req_model,
        company_id,
        new_status=TaskStatus.in_progress,
        # min__ keeps a previous, smaller "started" value instead of overriding it
        min__started=datetime.utcnow(),
    )
    response = StartedResponse(**change)
    response.started = response.updated
    call.result.data_model = response
|
|
|
|
|
|
@endpoint(
    "tasks.failed", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def failed(call: APICall, company_id, req_model: UpdateRequest):
    """Set the task status to failed."""
    change = set_task_status_from_call(
        req_model, company_id, new_status=TaskStatus.failed
    )
    call.result.data_model = UpdateResponse(**change)
|
|
|
|
|
|
@endpoint(
    "tasks.close", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def close(call: APICall, company_id, req_model: UpdateRequest):
    """Set the task status to closed."""
    change = set_task_status_from_call(
        req_model, company_id, new_status=TaskStatus.closed
    )
    call.result.data_model = UpdateResponse(**change)
|
|
|
|
|
|
# Fields accepted from callers when creating a task. Each value is handed to
# parse_from_call alongside the field name (list / Task / None); how it is
# interpreted is up to parse_from_call.
create_fields = dict(
    name=None,
    tags=list,
    system_tags=list,
    type=None,
    error=None,
    comment=None,
    parent=Task,
    project=None,
    input=None,
    models=None,
    container=None,
    output_dest=None,
    execution=None,
    hyperparams=None,
    configuration=None,
    script=None,
)

# Dict-valued task fields whose keys are escaped on save and unescaped on load
dict_fields_paths = [("execution", "model_labels"), "container"]
|
|
|
|
|
|
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
    """
    Convert API-level task fields into their stored form, in place.

    Applies tag conformance, params/artifacts/models/docker-cmd conversions,
    escapes dict-field keys, and strips whitespace from selected script fields.

    :return: the same ``fields`` dict
    """
    conform_tag_fields(call, fields, validate=True)
    params_prepare_for_save(fields, previous_task=previous_task)
    artifacts_prepare_for_save(fields)
    for compat in (ModelsBackwardsCompatibility, DockerCmdBackwardsCompatibility):
        compat.prepare_for_save(call, fields)
    for dict_path in dict_fields_paths:
        escape_dict_field(fields, dict_path)

    # Trim leading/trailing whitespace so script names and paths stay usable;
    # script fields absent from the input are skipped
    for script_field in task_script_stripped_fields:
        dpath_key = f"script/{script_field}"
        try:
            current = dpath.get(fields, dpath_key)
            if isinstance(current, str):
                dpath.set(fields, dpath_key, current.strip())
        except KeyError:
            pass

    return fields
|
|
|
|
|
|
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
    """
    Convert stored task dicts back into their API form, in place.

    :param tasks_data: a single task dict or a sequence of them
    """
    tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

    conform_output_tags(call, tasks)

    for task_data in tasks:
        for dict_path in dict_fields_paths:
            unescape_dict_field(task_data, dict_path)

    for compat in (ModelsBackwardsCompatibility, DockerCmdBackwardsCompatibility):
        compat.unprepare_from_saved(call, tasks)

    # Callers of endpoint versions below 2.9 also get params copied to the legacy form
    copy_to_legacy = call.requested_endpoint_version < PartialVersion("2.9")
    for task_data in tasks:
        params_unprepare_from_saved(fields=task_data, copy_to_legacy=copy_to_legacy)
        artifacts_unprepare_from_saved(fields=task_data)
|
|
|
|
|
|
def prepare_create_fields(
    call: APICall, valid_fields=None, output=None, previous_task: Task = None,
):
    """
    Parse task creation/edit fields from the call data and prepare them for saving.

    :param call: the API call whose data is parsed
    :param valid_fields: accepted field mapping (defaults to ``create_fields``)
    :param output: existing Output document to receive ``output_dest``, if any
    :param previous_task: the task being edited, forwarded to prepare_for_save
    :return: the prepared fields dict
    """
    valid_fields = valid_fields if valid_fields is not None else create_fields
    # Build a new set: adding to the module-level task_fields would permanently
    # widen it for every later caller
    t_fields = task_fields | {"output_dest"}

    fields = parse_from_call(call.data, valid_fields, t_fields)

    # Move output_dest to output.destination. The pseudo-field is always
    # removed (even when None) so it never reaches the Task document.
    output_dest = fields.pop("output_dest", None)
    if output_dest is not None:
        if output:
            output.destination = output_dest
        else:
            output = Output(destination=output_dest)
        fields["output"] = output

    # Stamp the update time on every input/output model entry
    models = fields.get("models")
    if models:
        now = datetime.utcnow()
        for field in ("input", "output"):
            for model in models.get(field) or []:
                model["updated"] = now

    return prepare_for_save(call, fields, previous_task=previous_task)
|
|
|
|
|
|
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
    """
    Build a task object from the call data and validate it (it is not saved here).

    :param kwargs: forwarded to prepare_create_fields
    :return: (task, prepared fields)
    """
    with translate_errors_context(
        field_does_not_exist_cls=errors.bad_request.ValidationError
    ):
        with TimingContext("code", "parse_call"):
            prepared = prepare_create_fields(call, **kwargs)
            new_task = task_bll.create(call, prepared)

    with TimingContext("code", "validate"):
        task_bll.validate(new_task)

    return new_task, prepared
|
|
|
|
|
|
@endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest):
    """Validate task creation data without saving; a deleted parent reference is dropped."""
    parent = call.data.get("parent")
    if parent and parent.startswith(deleted_prefix):
        del call.data["parent"]
    _validate_and_get_task_from_call(call)
|
|
|
|
|
|
def _update_cached_tags(company: str, project: str, fields: dict):
    """Push the task tags/system_tags found in fields into the org tag cache."""
    tags = fields.get("tags")
    system_tags = fields.get("system_tags")
    org_bll.update_tags(
        company, Tags.Task, project=project, tags=tags, system_tags=system_tags,
    )
|
|
|
|
|
|
def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Reset the org-level task tag cache for the given projects."""
    org_bll.reset_tags(company, Tags.Task, projects=projects)
|
|
|
|
|
|
@endpoint(
    "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
    """Create and save a new task; refreshes cached tags and the project update time."""
    new_task, prepared = _validate_and_get_task_from_call(call)

    with translate_errors_context():
        with TimingContext("mongo", "save_task"):
            new_task.save()
            _update_cached_tags(company_id, project=new_task.project, fields=prepared)
            update_project_time(new_task.project)

    call.result.data_model = IdResponse(id=new_task.id)
|
|
|
|
|
|
@endpoint("tasks.clone", request_data_model=CloneRequest)
def clone_task(call: APICall, company_id, request: CloneRequest):
    """Clone a task with optional overrides; "new_project" is returned only when set."""
    cloned, new_project = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        name=request.new_task_name,
        comment=request.new_task_comment,
        parent=request.new_task_parent,
        project=request.new_task_project,
        tags=request.new_task_tags,
        system_tags=request.new_task_system_tags,
        hyperparams=request.new_task_hyperparams,
        configuration=request.new_task_configuration,
        container=request.new_task_container,
        execution_overrides=request.execution_overrides,
        input_models=request.new_task_input_models,
        validate_references=request.validate_references,
        new_project_name=request.new_project_name,
    )
    result = {"id": cloned.id}
    if new_project:
        result["new_project"] = new_project
    call.result.data = result
|
|
|
|
|
|
def prepare_update_fields(call: APICall, task, call_data):
    """
    Parse the user-settable task fields from call_data and prepare them for saving.

    :param call: the API call (used for tag conformance and compatibility handling)
    :param task: the task being updated (not used here; kept for interface compatibility)
    :param call_data: raw update data for a single task
    :return: (prepared fields dict, deep copy of Task.user_set_allowed())
    """
    valid_fields = deepcopy(Task.user_set_allowed())
    update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
    update_fields["output__error"] = None
    # Extend a copy: adding to the module-level task_fields would leak
    # "output__error" into every subsequent create/update call
    t_fields = task_fields | {"output__error"}
    fields = parse_from_call(call_data, update_fields, t_fields)
    return prepare_for_save(call, fields), valid_fields
|
|
|
|
|
|
@endpoint(
    "tasks.update", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def update(call: APICall, company_id, req_model: UpdateRequest):
    """
    Update user-settable fields of a task.

    Unlike most endpoints here, the UpdateResponse is returned rather than
    assigned to call.result.data_model.
    """
    task_id = req_model.task

    with translate_errors_context():
        task = Task.get_for_writing(
            id=task_id, company=company_id, _only=["id", "project"]
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        partial_update_dict, _ = prepare_update_fields(call, task, call.data)
        if not partial_update_dict:
            return UpdateResponse(updated=0)

        updated_count, updated_fields = Task.safe_update(
            company_id=company_id,
            id=task_id,
            partial_update_dict=partial_update_dict,
            injected_update=dict(last_change=datetime.utcnow()),
        )
        if updated_count:
            old_project = task.project
            new_project = updated_fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(
                    company_id, project=old_project, fields=updated_fields
                )
            else:
                # Moving between projects invalidates both projects' tag caches
                _reset_cached_tags(company_id, projects=[new_project, old_project])
            update_project_time(updated_fields.get("project"))

        unprepare_from_saved(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
|
|
|
|
|
|
@endpoint(
    "tasks.set_requirements",
    request_data_model=SetRequirementsRequest,
    response_data_model=UpdateResponse,
)
def set_requirements(call: APICall, company_id, req_model: SetRequirementsRequest):
    """Replace script.requirements of a task; fails if the task has no script."""
    new_requirements = req_model.requirements
    with translate_errors_context():
        target = TaskBLL.get_task_with_access(
            req_model.task,
            company_id=company_id,
            only=("status", "script"),
            requires_write_access=True,
        )
        if not target.script:
            raise errors.bad_request.MissingTaskFields(
                "Task has no script field", task=target.id
            )
        changed = update_task(
            target, update_cmds=dict(script__requirements=new_requirements)
        )
        call.result.data_model = UpdateResponse(updated=changed)
        if changed:
            call.result.data_model.fields = {"script.requirements": new_requirements}
|
|
|
|
|
|
@endpoint("tasks.update_batch")
def update_batch(call: APICall, company_id, _):
    """
    Apply per-task updates from a batched call in a single bulk write.

    :raises BatchContainsNoItems: no batched data was supplied
    :raises InvalidTaskId: some referenced tasks were not found for writing
    """
    batch = call.batched_data
    if batch is None:
        raise errors.bad_request.BatchContainsNoItems()

    with translate_errors_context():
        items = {entry["task"]: entry for entry in batch}
        writable = {
            found.id: found
            for found in Task.get_many_for_writing(
                company=company_id, query=Q(id__in=list(items))
            )
        }

        if len(writable) < len(items):
            missing = tuple(set(items).difference(writable))
            raise errors.bad_request.InvalidTaskId(ids=missing)

        now = datetime.utcnow()

        bulk_ops = []
        touched_projects = set()
        for task_id, data in items.items():
            current = writable[task_id]
            fields, _ = prepare_update_fields(call, current, data)
            safe_update = Task.get_safe_update_dict(fields)
            if not safe_update:
                continue
            safe_update.update(last_change=now)
            bulk_ops.append(
                UpdateOne(
                    {"_id": task_id, "company": company_id}, {"$set": safe_update}
                )
            )

            target_project = safe_update.get("project", current.project)
            if target_project != current.project:
                touched_projects.update({target_project, current.project})
            elif not {"tags", "system_tags"}.isdisjoint(safe_update):
                touched_projects.add(current.project)

        updated = 0
        if bulk_ops:
            updated = Task._get_collection().bulk_write(bulk_ops).modified_count

        if updated and touched_projects:
            projects = list(touched_projects)
            _reset_cached_tags(company_id, projects=projects)
            update_project_time(project_ids=projects)

        call.result.data = {"updated": updated}
|
|
|
|
|
|
@endpoint(
    "tasks.edit", request_data_model=UpdateRequest, response_data_model=UpdateResponse
)
def edit(call: APICall, company_id, req_model: UpdateRequest):
    """
    Edit a task's creation-time fields (the create_fields set plus "status").

    Only tasks in 'created' status may be edited unless force is set.
    Dicts passed for embedded-document fields are merged over the task's
    current values, the resulting task is validated, and the fields are then
    written with task.update(). Sets call.result.data_model to an UpdateResponse.

    :raises InvalidTaskId: the task was not found for writing
    :raises InvalidTaskStatus: the task is not 'created' and force is not set
    """
    task_id = req_model.task
    force = req_model.force

    with translate_errors_context():
        task = Task.get_for_writing(id=task_id, company=company_id)
        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if not force and task.status != TaskStatus.created:
            raise errors.bad_request.InvalidTaskStatus(
                expected=TaskStatus.created, status=task.status
            )

        # Editing accepts everything creation does, plus the status field
        edit_fields = create_fields.copy()
        edit_fields.update(dict(status=None))

        with translate_errors_context(
            field_does_not_exist_cls=errors.bad_request.ValidationError
        ), TimingContext("code", "parse_and_validate"):
            fields = prepare_create_fields(
                call, valid_fields=edit_fields, output=task.output, previous_task=task
            )

            # When a dict is given for a field currently holding an
            # EmbeddedDocument, merge it over the stored values so sub-keys the
            # caller omitted are preserved (only existing keys are reassigned,
            # so iterating while writing is safe)
            for key in fields:
                field = getattr(task, key, None)
                value = fields[key]
                if (
                    field
                    and isinstance(value, dict)
                    and isinstance(field, EmbeddedDocument)
                ):
                    d = field.to_mongo(use_db_field=False).to_dict()
                    d.update(value)
                    fields[key] = d

            # Build a task from the merged fields only to validate it; the
            # actual write is the task.update() call below
            task_bll.validate(task_bll.create(call, fields))

            # make sure field names do not end in mongoengine comparison operators
            # (a key equal to an operator name gets a trailing "__")
            fixed_fields = {
                (k if k not in COMPARISON_OPERATORS else "%s__" % k): v
                for k, v in fields.items()
            }
            if fixed_fields:
                now = datetime.utcnow()
                last_change = dict(last_change=now)
                # last_update is bumped too when any edited field is not user-settable
                if not set(fields).issubset(Task.user_set_allowed()):
                    last_change.update(last_update=now)
                fields.update(**last_change)
                fixed_fields.update(**last_change)
                updated = task.update(upsert=False, **fixed_fields)
                if updated:
                    new_project = fixed_fields.get("project", task.project)
                    if new_project != task.project:
                        _reset_cached_tags(company_id, projects=[new_project, task.project])
                    else:
                        _update_cached_tags(
                            company_id, project=task.project, fields=fixed_fields
                        )
                    update_project_time(fields.get("project"))
                unprepare_from_saved(call, fields)
                call.result.data_model = UpdateResponse(updated=updated, fields=fields)
            else:
                call.result.data_model = UpdateResponse(updated=0)
|
|
|
|
|
|
@endpoint(
    "tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
    """Return the hyperparameters of each requested task, tagged with its task id."""
    with translate_errors_context():
        per_task = HyperParams.get_params(company_id, task_ids=request.tasks)

    entries = []
    for task_id, data in per_task.items():
        entries.append({"task": task_id, **data})
    call.result.data = {"params": entries}
|
|
|
|
|
|
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
    """Add/replace task hyperparameters; reports the HyperParams.edit_params result."""
    with translate_errors_context():
        outcome = HyperParams.edit_params(
            company_id,
            task_id=request.task,
            hyperparams=request.hyperparams,
            replace_hyperparams=request.replace_hyperparams,
            force=request.force,
        )
        call.result.data = {"updated": outcome}
|
|
|
|
|
|
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
    """Remove task hyperparameters; reports the HyperParams.delete_params result."""
    with translate_errors_context():
        outcome = HyperParams.delete_params(
            company_id,
            task_id=request.task,
            hyperparams=request.hyperparams,
            force=request.force,
        )
        call.result.data = {"deleted": outcome}
|
|
|
|
|
|
@endpoint(
    "tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
    """Return the (optionally name-filtered) configurations of each requested task."""
    with translate_errors_context():
        per_task = HyperParams.get_configurations(
            company_id, task_ids=request.tasks, names=request.names
        )

    entries = []
    for task_id, data in per_task.items():
        entries.append({"task": task_id, **data})
    call.result.data = {"configurations": entries}
|
|
|
|
|
|
@endpoint(
    "tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
    call: APICall, company_id, request: GetConfigurationNamesRequest
):
    """Return the configuration names of each requested task."""
    with translate_errors_context():
        per_task = HyperParams.get_configuration_names(
            company_id, task_ids=request.tasks, skip_empty=request.skip_empty
        )

    entries = []
    for task_id, data in per_task.items():
        entries.append({"task": task_id, **data})
    call.result.data = {"configurations": entries}
|
|
|
|
|
|
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
    """Add/replace task configuration items; reports the edit result."""
    with translate_errors_context():
        outcome = HyperParams.edit_configuration(
            company_id,
            task_id=request.task,
            configuration=request.configuration,
            replace_configuration=request.replace_configuration,
            force=request.force,
        )
        call.result.data = {"updated": outcome}
|
|
|
|
|
|
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
    call: APICall, company_id, request: DeleteConfigurationRequest
):
    """Remove task configuration items; reports the delete result."""
    with translate_errors_context():
        outcome = HyperParams.delete_configuration(
            company_id,
            task_id=request.task,
            configuration=request.configuration,
            force=request.force,
        )
        call.result.data = {"deleted": outcome}
|
|
|
|
|
|
@endpoint(
    "tasks.enqueue",
    request_data_model=EnqueueRequest,
    response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, request: EnqueueRequest):
    """Enqueue a single task; enqueue_task's extra fields are merged into the response."""
    enqueued, extra = enqueue_task(
        task_id=request.task,
        company_id=company_id,
        queue_id=request.queue,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    call.result.data_model = EnqueueResponse(queued=enqueued, **extra)
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class EnqueueRes:
    """Batch accumulator for enqueue_many: sums the queued counts."""

    queued: int = 0

    def __add__(self, other: Tuple[int, dict]):
        # other is the (queued, response-fields) pair returned by enqueue_task
        count, _response = other
        return EnqueueRes(queued=count + self.queued)
|
|
|
|
|
|
@endpoint(
    "tasks.enqueue_many",
    request_data_model=EnqueueManyRequest,
    response_data_model=EnqueueManyResponse,
)
def enqueue_many(call: APICall, company_id, request: EnqueueManyRequest):
    """Enqueue several tasks into one queue; batch failures are returned too."""
    enqueue_one = partial(
        enqueue_task,
        company_id=company_id,
        queue_id=request.queue,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    totals, failures = run_batch_operation(
        func=enqueue_one, ids=request.ids, init_res=EnqueueRes(),
    )
    call.result.data_model = EnqueueManyResponse(queued=totals.queued, failures=failures)
|
|
|
|
|
|
@endpoint(
    "tasks.dequeue",
    request_data_model=UpdateRequest,
    response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, request: UpdateRequest):
    """Dequeue a task via TaskBLL.dequeue_and_change_status; always reports dequeued=1."""
    target = TaskBLL.get_task_with_access(
        request.task,
        company_id=company_id,
        only=("id", "execution", "status", "project"),
        requires_write_access=True,
    )
    outcome = TaskBLL.dequeue_and_change_status(
        target,
        company_id,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    response = DequeueResponse(**outcome)
    response.dequeued = 1
    call.result.data_model = response
|
|
|
|
|
|
@endpoint(
    "tasks.reset", request_data_model=ResetRequest, response_data_model=ResetResponse
)
def reset(call: APICall, company_id, request: ResetRequest):
    """Reset a task; the response carries the updates plus the cleanup results."""
    dequeue_info, cleanup, updates = reset_task(
        task_id=request.task,
        company_id=company_id,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
        clear_all=request.clear_all,
    )

    response = ResetResponse(**updates, dequeued=dequeue_info)
    # artifacts are left out of the response since they are not serializable
    response.fields.pop("execution.artifacts", None)

    for attr_name, attr_value in attr.asdict(cleanup).items():
        setattr(response, attr_name, attr_value)

    call.result.data_model = response
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class ResetRes:
    """Batch accumulator for reset_many: counts resets/dequeues and merges cleanups."""

    reset: int = 0
    dequeued: int = 0
    cleanup_res: CleanupResult = None

    def __add__(self, other: Tuple[dict, CleanupResult, dict]):
        dequeue_info, task_cleanup, _ = other
        removed = dequeue_info.get("removed", 0) if dequeue_info else 0
        if self.cleanup_res:
            merged_cleanup = self.cleanup_res + task_cleanup
        else:
            merged_cleanup = task_cleanup
        return ResetRes(
            reset=self.reset + 1,
            dequeued=self.dequeued + removed,
            cleanup_res=merged_cleanup,
        )
|
|
|
|
|
|
@endpoint("tasks.reset_many", request_data_model=ResetManyRequest)
def reset_many(call: APICall, company_id, request: ResetManyRequest):
    """Reset several tasks; returns counts, merged cleanup info and batch failures."""
    reset_one = partial(
        reset_task,
        company_id=company_id,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
        clear_all=request.clear_all,
    )
    totals, failures = run_batch_operation(
        func=reset_one, ids=request.ids, init_res=ResetRes(),
    )

    cleanup = {}
    if totals.cleanup_res:
        cleanup = dict(
            deleted_models=totals.cleanup_res.deleted_models,
            urls=attr.asdict(totals.cleanup_res.urls),
        )
    call.result.data = dict(
        reset=totals.reset, dequeued=totals.dequeued, **cleanup, failures=failures,
    )
|
|
|
|
|
|
@endpoint(
    "tasks.archive",
    request_data_model=ArchiveRequest,
    response_data_model=ArchiveResponse,
)
def archive(call: APICall, company_id, request: ArchiveRequest):
    """Archive the requested tasks (all must exist) and report how many were archived."""
    existing = TaskBLL.assert_exists(
        company_id,
        task_ids=request.tasks,
        only=("id", "execution", "status", "project", "system_tags"),
    )
    archived = sum(
        archive_task(
            company_id=company_id,
            task=task,
            status_message=request.status_message,
            status_reason=request.status_reason,
        )
        for task in existing
    )
    call.result.data_model = ArchiveResponse(archived=archived)
|
|
|
|
|
|
@endpoint(
    "tasks.archive_many",
    request_data_model=ArchiveManyRequest,
    response_data_model=ArchiveManyResponse,
)
def archive_many(call: APICall, company_id, request: ArchiveManyRequest):
    """Archive several tasks by id; returns the summed count and batch failures."""
    archive_one = partial(
        archive_task,
        company_id=company_id,
        status_message=request.status_message,
        status_reason=request.status_reason,
    )
    archived, failures = run_batch_operation(
        func=archive_one, ids=request.ids, init_res=0,
    )
    call.result.data_model = ArchiveManyResponse(archived=archived, failures=failures)
|
|
|
|
|
|
@endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, request: DeleteRequest):
    """Delete (or move to trash) one task and report the cleanup results."""
    deleted, removed_task, cleanup = delete_task(
        task_id=request.task,
        company_id=company_id,
        move_to_trash=request.move_to_trash,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
    )
    if deleted:
        # Tasks without a project have no project tag cache to reset
        affected = [removed_task.project] if removed_task.project else []
        _reset_cached_tags(company_id, projects=affected)
    call.result.data = dict(deleted=bool(deleted), **attr.asdict(cleanup))
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class DeleteRes:
    """Batch accumulator for delete_many: counts deletions, collects projects, merges cleanups."""

    deleted: int = 0
    # factory=set gives every instance its own set instead of one shared default object
    projects: Set = attr.ib(factory=set)
    cleanup_res: CleanupResult = None

    def __add__(self, other: Tuple[int, Task, CleanupResult]):
        del_count, task, other_res = other
        # Skip tasks without a project, matching the single-task delete endpoint,
        # so None never reaches the tag-cache reset
        task_projects = {task.project} if task.project else set()

        return DeleteRes(
            deleted=self.deleted + del_count,
            projects=self.projects | task_projects,
            cleanup_res=self.cleanup_res + other_res if self.cleanup_res else other_res,
        )
|
|
|
|
|
|
@endpoint("tasks.delete_many", request_data_model=DeleteManyRequest)
def delete_many(call: APICall, company_id, request: DeleteManyRequest):
    """
    Delete a batch of tasks given by ``request.ids``.

    delete_task, bound to the request options, is handed to
    run_batch_operation with an empty DeleteRes as the initial result. When
    the accumulated result reports deletions, the cached tags of the
    collected projects are reset. The response holds the deleted count, the
    cleanup result fields (if any cleanup result was collected) and the
    failures.
    """
    delete_one = partial(
        delete_task,
        company_id=company_id,
        move_to_trash=request.move_to_trash,
        force=request.force,
        return_file_urls=request.return_file_urls,
        delete_output_models=request.delete_output_models,
    )
    totals, failed = run_batch_operation(
        func=delete_one, ids=request.ids, init_res=DeleteRes(),
    )

    if totals.deleted:
        _reset_cached_tags(company_id, projects=list(totals.projects))

    cleanup_fields = (
        {} if not totals.cleanup_res else attr.asdict(totals.cleanup_res)
    )
    call.result.data = dict(
        deleted=totals.deleted, **cleanup_fields, failures=failed
    )
|
|
|
|
|
|
@endpoint(
    "tasks.publish",
    request_data_model=PublishRequest,
    response_data_model=PublishResponse,
)
def publish(call: APICall, company_id, request: PublishRequest):
    """
    Publish a single task.

    When ``request.publish_model`` is set, ModelBLL.publish_model is passed
    to publish_task as the model publishing callback; otherwise None is
    passed. The mapping returned by publish_task is unpacked into the
    PublishResponse.
    """
    model_publisher = None
    if request.publish_model:
        model_publisher = ModelBLL.publish_model

    updates = publish_task(
        task_id=request.task,
        company_id=company_id,
        force=request.force,
        publish_model_func=model_publisher,
        status_reason=request.status_reason,
        status_message=request.status_message,
    )
    call.result.data_model = PublishResponse(**updates)
|
|
|
|
|
|
@attr.s(auto_attribs=True)
class PublishRes:
    """
    Accumulated result of tasks.publish_many: the number of published tasks.

    Adding any per-task result yields a new PublishRes whose counter is one
    higher. The content of the added result is ignored.
    """

    published: int = 0

    def __add__(self, other: dict):
        incremented = self.published + 1
        return PublishRes(published=incremented)
|
|
|
|
|
|
@endpoint("tasks.publish_many", request_data_model=PublishManyRequest)
def publish_many(call: APICall, company_id, request: PublishManyRequest):
    """
    Publish a batch of tasks given by ``request.ids``.

    publish_task, bound to the request options, is handed to
    run_batch_operation with an empty PublishRes as the initial result.
    ModelBLL.publish_model is used as the model publishing callback only
    when ``request.publish_model`` is set. The response holds the published
    count and the failures.
    """
    model_publisher = ModelBLL.publish_model if request.publish_model else None
    publish_one = partial(
        publish_task,
        company_id=company_id,
        force=request.force,
        publish_model_func=model_publisher,
        status_reason=request.status_reason,
        status_message=request.status_message,
    )
    totals, failed = run_batch_operation(
        func=publish_one, ids=request.ids, init_res=PublishRes(),
    )
    call.result.data = dict(published=totals.published, failures=failed)
|
|
|
|
|
|
@endpoint(
    "tasks.completed",
    min_version="2.2",
    request_data_model=UpdateRequest,
    response_data_model=UpdateResponse,
)
# The request is parsed as UpdateRequest (request_data_model above), so the
# parameter is annotated with that type rather than PublishRequest
def completed(call: APICall, company_id, request: UpdateRequest):
    """
    Mark a task as completed.

    The status change is delegated to set_task_status_from_call with the
    ``completed`` status and the current UTC time as the completion time.
    The values it returns are unpacked into the UpdateResponse.
    """
    status_updates = set_task_status_from_call(
        request,
        company_id,
        new_status=TaskStatus.completed,
        completed=datetime.utcnow(),
    )
    call.result.data_model = UpdateResponse(**status_updates)
|
|
|
|
|
|
@endpoint("tasks.ping", request_data_model=PingRequest)
def ping(_, company_id, request: PingRequest):
    """
    Set the last update time of the requested task to the current UTC time.
    """
    now = datetime.utcnow()
    TaskBLL.set_last_update(
        task_ids=[request.task], company_id=company_id, last_update=now,
    )
|
|
|
|
|
|
@endpoint(
    "tasks.add_or_update_artifacts",
    min_version="2.10",
    request_data_model=AddOrUpdateArtifactsRequest,
)
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    """
    Add artifacts to a task or update existing ones.

    Delegates to Artifacts.add_or_update_artifacts within
    translate_errors_context. Its return value is reported under
    ``updated``.
    """
    with translate_errors_context():
        result = Artifacts.add_or_update_artifacts(
            company_id=company_id,
            task_id=request.task,
            artifacts=request.artifacts,
            force=request.force,
        )
        call.result.data = {"updated": result}
|
|
|
|
|
|
@endpoint(
    "tasks.delete_artifacts",
    min_version="2.10",
    request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
    """
    Delete the given artifacts from a task.

    Delegates to Artifacts.delete_artifacts within translate_errors_context.
    Its return value is reported under ``deleted``.
    """
    with translate_errors_context():
        result = Artifacts.delete_artifacts(
            company_id=company_id,
            task_id=request.task,
            artifact_ids=request.artifacts,
            force=request.force,
        )
        call.result.data = {"deleted": result}
|
|
|
|
|
|
@endpoint(
    "tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """
    Make the requested tasks public via Task.set_public (``enabled=True``).

    InvalidTaskId is passed to Task.set_public as the error class,
    presumably used for ids it cannot apply the change to.
    """
    with translate_errors_context():
        result = Task.set_public(
            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
        )
        call.result.data = result
|
|
|
|
|
|
@endpoint(
    "tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
# Named make_private: reusing the name make_public here rebound the
# module-level make_public to this (private) handler. Routing is done by the
# endpoint string, so the rename does not affect request dispatch.
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Make the requested tasks private via Task.set_public (``enabled=False``).

    InvalidTaskId is passed to Task.set_public as the error class,
    presumably used for ids it cannot apply the change to.
    """
    with translate_errors_context():
        call.result.data = Task.set_public(
            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
        )
|
|
|
|
|
|
@endpoint("tasks.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
    """
    Move tasks to a project given either by id or by name.

    Raises MissingRequiredFields when neither ``project`` nor
    ``project_name`` is supplied. The tasks' current (non-empty) projects
    are read before the move. After the move, those projects and the
    project id returned by project_bll.move_under_project get their cached
    tags reset and their update time refreshed. Returns that project id.
    """
    if not request.project and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    source_projects = {
        t.project
        for t in Task.objects(id__in=request.ids).only("project")
        if t.project
    }
    target_project = project_bll.move_under_project(
        entity_cls=Task,
        user=call.identity.user,
        company=company_id,
        ids=request.ids,
        project=request.project,
        project_name=request.project_name,
    )

    touched = list(source_projects | {target_project})
    _reset_cached_tags(company_id, projects=touched)
    update_project_time(touched)

    return {"project_id": target_project}
|
|
|
|
|
|
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """
    Add a model entry to a task, or replace the entry with the same name.

    The entry goes into the task's ``models.<request.type>`` list. An
    existing entry with the same name is replaced in place by a positional
    (``__S``) set. If no entry was replaced, the new entry is pushed as part
    of the TaskBLL.update_statistics call, which also receives
    ``request.iteration`` as last_iteration_max. Returns the result of
    update_statistics under ``updated``.
    """
    get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    field = f"models__{request.type}"
    entry = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())
    match_existing = {"id": request.task, f"{field}__name": request.name}
    replaced = Task.objects(**match_existing).update_one(
        **{f"set__{field}__S": entry}
    )

    # Append only when there was no same-named entry to overwrite
    extra_updates = {} if replaced else {f"push__{field}": entry}
    updated = TaskBLL.update_statistics(
        task_id=request.task,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **extra_updates,
    )

    return {"updated": updated}
|
|
|
|
|
|
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """
    Remove model entries from a task by name, per model type.

    For every ModelItemType option, the names of the requested models of
    that type are pulled from the task's matching models list. Types with
    no requested names are skipped. ``last_change`` is set to the current
    UTC time. Returns the result of task.update under ``updated``.
    """
    task = get_task_for_update(company_id=company_id, task_id=request.task, force=True)

    commands = {}
    for model_type in get_options(ModelItemType):
        names = [m.name for m in request.models if m.type == model_type]
        if names:
            commands[f"pull__models__{model_type}__name__in"] = names

    updated = task.update(last_change=datetime.utcnow(), **commands)
    return {"updated": updated}
|