mirror of
https://github.com/clearml/clearml-server
synced 2025-04-08 06:54:08 +00:00
Optimize task artifacts
This commit is contained in:
parent
89f81bfe5a
commit
171969c5ea
@ -1 +1 @@
|
||||
__version__ = "2.9.0"
|
||||
__version__ = "2.10.0"
|
||||
|
@ -7,7 +7,7 @@ from jsonmodels.validators import Enum, Length
|
||||
|
||||
from apiserver.apimodels import DictField, ListField
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.database.model.task.task import TaskType
|
||||
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
|
||||
@ -20,7 +20,9 @@ class ArtifactTypeData(models.Base):
|
||||
class Artifact(models.Base):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(validators=Enum("input", "output"), default="output")
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = IntField()
|
||||
@ -112,12 +114,18 @@ class CloneRequest(TaskRequest):
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
artifacts = ListField([Artifact], required=True)
|
||||
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsResponse(models.Base):
|
||||
added = ListField([str])
|
||||
updated = ListField([str])
|
||||
class ArtifactId(models.Base):
|
||||
key = StringField(required=True)
|
||||
mode = StringField(
|
||||
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
|
||||
)
|
||||
|
||||
|
||||
class DeleteArtifactsRequest(TaskRequest):
|
||||
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
|
||||
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
|
85
apiserver/bll/task/artifacts.py
Normal file
85
apiserver/bll/task/artifacts.py
Normal file
@ -0,0 +1,85 @@
|
||||
from datetime import datetime
|
||||
from hashlib import md5
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
|
||||
from apiserver.bll.task.utils import get_task_for_update
|
||||
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
|
||||
|
||||
def get_artifact_id(artifact: dict) -> str:
    """
    Calculate id from 'key' and 'mode' fields.
    Return a hash of the id so that it will not contain mongo illegal characters
    (the raw key may include dots/dollar signs which are invalid as map keys).
    """
    # md5 is used here only for key normalization, not for security
    key_hash: str = md5(artifact["key"].encode()).hexdigest()
    mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
    return f"{key_hash}_{mode}"
|
||||
|
||||
|
||||
def artifacts_prepare_for_save(fields: dict):
    """
    Convert the artifacts list under fields['execution'] into a mapping
    keyed by artifact id so it can be stored as a mongo map field.
    No-op when the artifacts field is absent.
    """
    path = ("execution", "artifacts")
    artifacts = nested_get(fields, path)
    if artifacts is None:
        return

    by_id = {get_artifact_id(artifact): artifact for artifact in artifacts}
    nested_set(fields, path, value=by_id)
|
||||
|
||||
|
||||
def artifacts_unprepare_from_saved(fields):
    """
    Convert the artifacts mapping under fields['execution'] back into a list
    sorted by (key, mode), which is the shape exposed through the API.
    No-op when the artifacts field is absent.
    """
    path = ("execution", "artifacts")
    artifacts = nested_get(fields, path)
    if artifacts is None:
        return

    ordered = sorted(artifacts.values(), key=itemgetter("key", "mode"))
    nested_set(fields, path, value=ordered)
|
||||
|
||||
|
||||
class Artifacts:
    """Bulk upsert and delete of artifacts stored in the task's execution map field."""

    @classmethod
    def add_or_update_artifacts(
        cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact],
    ) -> int:
        """
        Insert or replace the given artifacts on the task and bump last_update.
        Returns the result of the mongoengine update call.
        """
        with TimingContext("mongo", "update_artifacts"):
            # allow_all_statuses: artifacts may be updated on running tasks too
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, allow_all_statuses=True
            )

            by_id = {}
            for api_artifact in artifacts:
                data = api_artifact.to_struct()
                by_id[get_artifact_id(data)] = Artifact(**data)

            update_cmds = {
                f"set__execution__artifacts__{id_}": artifact
                for id_, artifact in by_id.items()
            }
            return task.update(**update_cmds, last_update=datetime.utcnow())

    @classmethod
    def delete_artifacts(
        cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId]
    ) -> int:
        """
        Remove the given artifacts from the task and bump last_update.
        Returns the result of the mongoengine update call.
        """
        with TimingContext("mongo", "delete_artifacts"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, allow_all_statuses=True
            )

            # de-duplicate: unset commands are keyed by artifact id
            ids = {
                get_artifact_id(artifact_id.to_struct())
                for artifact_id in artifact_ids
            }
            delete_cmds = {f"unset__execution__artifacts__{id_}": 1 for id_ in ids}

            return task.update(**delete_cmds, last_update=datetime.utcnow())
|
@ -13,8 +13,10 @@ from apiserver.apimodels.tasks import (
|
||||
Configuration,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import get_task_for_update
|
||||
from apiserver.config import config
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
|
||||
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -58,32 +60,33 @@ class HyperParams:
|
||||
|
||||
@classmethod
|
||||
def delete_params(
|
||||
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
|
||||
cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
with TimingContext("mongo", "delete_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
with_param, without_param = iterutils.partition(
|
||||
hyperparams, key=lambda p: bool(p.name)
|
||||
)
|
||||
sections_to_delete = {p.section for p in without_param}
|
||||
delete_cmds = {
|
||||
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
|
||||
for section in sections_to_delete
|
||||
}
|
||||
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
for item in with_param:
|
||||
section = ParameterKeyEscaper.escape(item.section)
|
||||
if item.section in sections_to_delete:
|
||||
raise errors.bad_request.FieldsConflict(
|
||||
"Cannot delete section field if the whole section was scheduled for deletion"
|
||||
)
|
||||
name = ParameterKeyEscaper.escape(item.name)
|
||||
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def edit_params(
|
||||
@ -93,24 +96,25 @@ class HyperParams:
|
||||
hyperparams: Sequence[HyperParamItem],
|
||||
replace_hyperparams: str,
|
||||
) -> int:
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = cls._get_task_for_update(
|
||||
company=company_id, id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
with TimingContext("mongo", "edit_hyperparams"):
|
||||
properties_only = cls._normalize_params(hyperparams)
|
||||
task = get_task_for_update(
|
||||
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
|
||||
)
|
||||
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
update_cmds = dict()
|
||||
hyperparams = cls._db_dicts_from_list(hyperparams)
|
||||
if replace_hyperparams == ReplaceHyperparams.all:
|
||||
update_cmds["set__hyperparams"] = hyperparams
|
||||
elif replace_hyperparams == ReplaceHyperparams.section:
|
||||
for section, value in hyperparams.items():
|
||||
update_cmds[f"set__hyperparams__{section}"] = value
|
||||
else:
|
||||
for section, section_params in hyperparams.items():
|
||||
for name, value in section_params.items():
|
||||
update_cmds[f"set__hyperparams__{section}__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
|
||||
@ -152,28 +156,29 @@ class HyperParams:
|
||||
def get_configuration_names(
|
||||
cls, company_id: str, task_ids: Sequence[str]
|
||||
) -> Dict[str, list]:
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
with TimingContext("mongo", "get_configuration_names"):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"_id": {"$in": task_ids},
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
}
|
||||
},
|
||||
{"$project": {"items": {"$objectToArray": "$configuration"}}},
|
||||
{"$unwind": "$items"},
|
||||
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
|
||||
]
|
||||
|
||||
tasks = Task.aggregate(pipeline)
|
||||
|
||||
return {
|
||||
task["_id"]: {
|
||||
"names": sorted(
|
||||
ParameterKeyEscaper.unescape(name) for name in task["names"]
|
||||
)
|
||||
for task in tasks
|
||||
}
|
||||
for task in tasks
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_configuration(
|
||||
@ -183,47 +188,32 @@ class HyperParams:
|
||||
configuration: Sequence[Configuration],
|
||||
replace_configuration: bool,
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
with TimingContext("mongo", "edit_configuration"):
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id)
|
||||
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
update_cmds = dict()
|
||||
configuration = {
|
||||
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
|
||||
for c in configuration
|
||||
}
|
||||
if replace_configuration:
|
||||
update_cmds["set__configuration"] = configuration
|
||||
else:
|
||||
for name, value in configuration.items():
|
||||
update_cmds[f"set__configuration__{name}"] = value
|
||||
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
return task.update(**update_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@classmethod
|
||||
def delete_configuration(
|
||||
cls, company_id: str, task_id: str, configuration=Sequence[str]
|
||||
cls, company_id: str, task_id: str, configuration: Sequence[str]
|
||||
) -> int:
|
||||
task = cls._get_task_for_update(company=company_id, id=task_id)
|
||||
with TimingContext("mongo", "delete_configuration"):
|
||||
task = get_task_for_update(company_id=company_id, task_id=task_id)
|
||||
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
delete_cmds = {
|
||||
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
|
||||
for name in set(configuration)
|
||||
}
|
||||
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
||||
@staticmethod
|
||||
def _get_task_for_update(
|
||||
company: str, id: str, allow_all_statuses: bool = False
|
||||
) -> Task:
|
||||
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=id)
|
||||
|
||||
if allow_all_statuses:
|
||||
return task
|
||||
|
||||
if task.status != TaskStatus.created:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
expected=TaskStatus.created, status=task.status
|
||||
)
|
||||
return task
|
||||
return task.update(**delete_cmds, last_update=datetime.utcnow())
|
||||
|
@ -1,20 +1,14 @@
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from random import random
|
||||
from time import sleep
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import pymongo.results
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from six import string_types
|
||||
|
||||
import apiserver.database.utils as dbutils
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import Artifact as ApiArtifact
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.config import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@ -28,14 +22,14 @@ from apiserver.database.model.task.task import (
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
external_task_types,
|
||||
)
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.dicts import deep_merge
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
@ -181,9 +175,6 @@ class TaskBLL(object):
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
params_dict = {
|
||||
field: value
|
||||
for field, value in (
|
||||
@ -192,21 +183,32 @@ class TaskBLL(object):
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
for legacy_param in ("parameters", "configuration"):
|
||||
legacy_value = execution_overrides.pop(legacy_param, None)
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
a for a in artifacts if a.get("mode") != ArtifactModes.output
|
||||
]
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
}
|
||||
now = datetime.utcnow()
|
||||
|
||||
with translate_errors_context():
|
||||
@ -543,97 +545,6 @@ class TaskBLL(object):
|
||||
force=force,
|
||||
).execute()
|
||||
|
||||
@classmethod
|
||||
def add_or_update_artifacts(
|
||||
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
key = attrgetter("key", "mode")
|
||||
|
||||
if not artifacts:
|
||||
return [], []
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
|
||||
artifacts: List[Artifact] = [
|
||||
Artifact(**artifact.to_struct()) for artifact in artifacts
|
||||
]
|
||||
|
||||
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
|
||||
|
||||
for retry in range(attempts):
|
||||
task = cls.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
)
|
||||
|
||||
current = list(map(key, task.execution.artifacts))
|
||||
updated = [a for a in artifacts if key(a) in current]
|
||||
added = [a for a in artifacts if a not in updated]
|
||||
|
||||
filter = {"_id": task_id, "company": company_id}
|
||||
update = {}
|
||||
array_filters = None
|
||||
if current:
|
||||
filter["execution.artifacts"] = {
|
||||
"$size": len(current),
|
||||
"$all": [
|
||||
*(
|
||||
{"$elemMatch": {"key": key, "mode": mode}}
|
||||
for key, mode in current
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
filter["$or"] = [
|
||||
{"execution.artifacts": {"$exists": False}},
|
||||
{"execution.artifacts": {"$size": 0}},
|
||||
]
|
||||
|
||||
if added:
|
||||
update["$push"] = {
|
||||
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
|
||||
}
|
||||
if updated:
|
||||
update["$set"] = {
|
||||
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
|
||||
for index, artifact in enumerate(updated)
|
||||
}
|
||||
array_filters = [
|
||||
{
|
||||
f"artifact{index}.key": artifact.key,
|
||||
f"artifact{index}.mode": artifact.mode,
|
||||
}
|
||||
for index, artifact in enumerate(updated)
|
||||
]
|
||||
|
||||
if not update:
|
||||
return [], []
|
||||
|
||||
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
|
||||
filter=filter,
|
||||
update=update,
|
||||
array_filters=array_filters,
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if result.matched_count >= 1:
|
||||
break
|
||||
|
||||
wait_msec = random() * int(
|
||||
config.get("services.tasks.artifacts.update_retry_msec", 500)
|
||||
)
|
||||
|
||||
log.warning(
|
||||
f"Failed to update artifacts for task {task_id} (updated by another party),"
|
||||
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
|
||||
)
|
||||
|
||||
sleep(wait_msec / 1000)
|
||||
else:
|
||||
raise errors.server_error.UpdateFailed(
|
||||
"task artifacts updated by another party"
|
||||
)
|
||||
|
||||
return [a.key for a in added], [a.key for a in updated]
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
|
@ -171,3 +171,23 @@ def split_by(
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
)
|
||||
|
||||
|
||||
def get_task_for_update(
    company_id: str, task_id: str, allow_all_statuses: bool = False
) -> Task:
    """
    Load a task (id and status fields only) for a subsequent update.
    Raises InvalidTaskId when the task is not found; unless allow_all_statuses
    is set, raises InvalidTaskStatus when the task status is not 'created'.
    """
    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
    if not task:
        raise errors.bad_request.InvalidTaskId(id=task_id)

    if allow_all_statuses or task.status == TaskStatus.created:
        return task

    raise errors.bad_request.InvalidTaskStatus(
        expected=TaskStatus.created, status=task.status
    )
|
||||
|
@ -8,9 +8,4 @@ non_responsive_tasks_watchdog {
|
||||
watch_interval_sec: 900
|
||||
}
|
||||
|
||||
artifacts {
|
||||
update_attempts: 10
|
||||
update_retry_msec: 500
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
@ -14,7 +16,6 @@ from apiserver.database.fields import (
|
||||
SafeMapField,
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
@ -72,10 +73,13 @@ class ArtifactModes:
|
||||
output = "output"
|
||||
|
||||
|
||||
DEFAULT_ARTIFACT_MODE = ArtifactModes.output
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
@ -107,7 +111,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts = EmbeddedDocumentSortedListField(Artifact)
|
||||
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
""" Queue ID where task was queued """
|
||||
|
19
apiserver/mongo/migrations/0.17.0.py
Normal file
19
apiserver/mongo/migrations/0.17.0.py
Normal file
@ -0,0 +1,19 @@
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
from apiserver.bll.task.artifacts import get_artifact_id
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
    """
    One-time migration: convert legacy list-valued execution.artifacts
    into a mapping keyed by artifact id.
    """
    tasks: Collection = db["task"]
    field = "execution.artifacts"
    # mongo BSON type 4 == array, i.e. only documents still in the old format
    legacy_query = {field: {"$type": 4}}
    for doc in tasks.find(filter=legacy_query, projection=(field,)):
        current = nested_get(doc, field.split("."))
        if not isinstance(current, list):
            continue

        converted = {get_artifact_id(a): a for a in current}
        tasks.update_one({"_id": doc["_id"]}, {"$set": {field: converted}})
|
@ -137,6 +137,14 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_mode_enum {
|
||||
type: string
|
||||
enum: [
|
||||
input
|
||||
output
|
||||
]
|
||||
default: output
|
||||
}
|
||||
artifact {
|
||||
type: object
|
||||
required: [key, type]
|
||||
@ -151,12 +159,7 @@ _definitions {
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
type: string
|
||||
enum: [
|
||||
input
|
||||
output
|
||||
]
|
||||
default: output
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
uri {
|
||||
description: "Raw data location"
|
||||
@ -190,6 +193,20 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
artifact_id {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
key {
|
||||
description: "Entry key"
|
||||
type: string
|
||||
}
|
||||
mode {
|
||||
description: "System defined input/output indication"
|
||||
"$ref": "#/definitions/artifact_mode_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
@ -1552,11 +1569,11 @@ ping {
|
||||
}
|
||||
|
||||
add_or_update_artifacts {
|
||||
"2.6" {
|
||||
description: """ Update an existing artifact (search by key/mode) or add a new one """
|
||||
"2.10" {
|
||||
description: """Update existing artifacts (search by key/mode) and add new ones"""
|
||||
request {
|
||||
type: object
|
||||
required: [ task, artifacts ]
|
||||
required: [task, artifacts]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
@ -1565,22 +1582,45 @@ add_or_update_artifacts {
|
||||
artifacts {
|
||||
description: "Artifacts to add or update"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/artifact" }
|
||||
items {"$ref": "#/definitions/artifact"}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
added {
|
||||
description: "Keys of artifacts added"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
updated {
|
||||
description: "Keys of artifacts updated"
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_artifacts {
|
||||
"2.10" {
|
||||
description: """Delete existing artifacts (search by key/mode)"""
|
||||
request {
|
||||
type: object
|
||||
required: [task, artifacts]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
artifacts {
|
||||
description: "Artifacts to delete"
|
||||
type: array
|
||||
items { type: string }
|
||||
items {"$ref": "#/definitions/artifact_id"}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates if the task was updated successfully"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,6 @@ from apiserver.apimodels.tasks import (
|
||||
DequeueResponse,
|
||||
CloneRequest,
|
||||
AddOrUpdateArtifactsRequest,
|
||||
AddOrUpdateArtifactsResponse,
|
||||
GetTypesRequest,
|
||||
ResetRequest,
|
||||
GetHyperParamsRequest,
|
||||
@ -39,6 +38,7 @@ from apiserver.apimodels.tasks import (
|
||||
EditConfigurationRequest,
|
||||
DeleteConfigurationRequest,
|
||||
GetConfigurationNamesRequest,
|
||||
DeleteArtifactsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -49,6 +49,11 @@ from apiserver.bll.task import (
|
||||
update_project_time,
|
||||
split_by,
|
||||
)
|
||||
from apiserver.bll.task.artifacts import (
|
||||
artifacts_prepare_for_save,
|
||||
artifacts_unprepare_from_saved,
|
||||
Artifacts,
|
||||
)
|
||||
from apiserver.bll.task.hyperparams import HyperParams
|
||||
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
|
||||
from apiserver.bll.task.param_utils import (
|
||||
@ -279,6 +284,7 @@ create_fields = {
|
||||
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
params_prepare_for_save(fields, previous_task=previous_task)
|
||||
artifacts_prepare_for_save(fields)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_fields:
|
||||
@ -301,10 +307,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
||||
conform_output_tags(call, tasks_data)
|
||||
|
||||
for data in tasks_data:
|
||||
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
|
||||
params_unprepare_from_saved(
|
||||
fields=data,
|
||||
copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
|
||||
fields=data, copy_to_legacy=need_legacy_params,
|
||||
)
|
||||
artifacts_unprepare_from_saved(fields=data)
|
||||
|
||||
|
||||
def prepare_create_fields(
|
||||
@ -846,10 +853,15 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
set__execution=Execution(), unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(
|
||||
unset__execution__queue=1,
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
)
|
||||
updates.update(unset__execution__queue=1)
|
||||
if task.execution and task.execution.artifacts:
|
||||
updates.update(
|
||||
set__execution__artifacts={
|
||||
key: artifact
|
||||
for key, artifact in task.execution.artifacts
|
||||
if artifact.get("mode") == "input"
|
||||
}
|
||||
)
|
||||
|
||||
res = ResetResponse(
|
||||
**ChangeStatusRequest(
|
||||
@ -1090,17 +1102,34 @@ def ping(_, company_id, request: PingRequest):
|
||||
|
||||
@endpoint(
|
||||
"tasks.add_or_update_artifacts",
|
||||
min_version="2.6",
|
||||
min_version="2.10",
|
||||
request_data_model=AddOrUpdateArtifactsRequest,
|
||||
response_data_model=AddOrUpdateArtifactsResponse,
|
||||
)
|
||||
def add_or_update_artifacts(
|
||||
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
|
||||
):
|
||||
added, updated = TaskBLL.add_or_update_artifacts(
|
||||
task_id=request.task, company_id=company_id, artifacts=request.artifacts
|
||||
)
|
||||
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id, task_id=request.task, artifacts=request.artifacts
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.delete_artifacts",
|
||||
min_version="2.10",
|
||||
request_data_model=DeleteArtifactsRequest,
|
||||
)
|
||||
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
|
100
apiserver/tests/automated/test_task_artifacts.py
Normal file
100
apiserver/tests/automated/test_task_artifacts.py
Normal file
@ -0,0 +1,100 @@
|
||||
from operator import itemgetter
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTasksArtifacts(TestService):
    """Exercises task artifact creation, retrieval, editing and deletion."""

    def setUp(self, **kwargs):
        # Artifact CRUD endpoints were introduced in API version 2.10
        super().setUp(version="2.10")

    def new_task(self, **kwargs) -> str:
        """Create a temporary task (auto-removed on teardown) and return its id."""
        self.update_missing(
            kwargs,
            type="testing",
            name="test artifacts",
            input=dict(view=dict()),
            delete_params=dict(force=True),
        )
        return self.create_temp("tasks", **kwargs)

    def test_artifacts_set_get(self):
        initial = [
            {"key": "a", "type": "str", "uri": "test1"},
            {"key": "b", "type": "int", "uri": "test2"},
        ]

        # create a task with artifacts, then read them back two ways
        task = self.new_task(execution={"artifacts": initial})
        fetched = self.api.tasks.get_by_id(task=task).task
        self._assertTaskArtifacts(initial, fetched)
        fetched = self.api.tasks.get_all_ex(id=[task]).tasks[0]
        self._assertTaskArtifacts(initial, fetched)

        # replacing the artifacts through tasks.edit
        edited = [
            {"key": "bb", "type": "str", "uri": "test1", "mode": "output"},
            {"key": "aa", "type": "int", "uri": "test2", "mode": "input"},
        ]
        self.api.tasks.edit(task=task, execution={"artifacts": edited})
        fetched = self.api.tasks.get_by_id(task=task).task
        self._assertTaskArtifacts(edited, fetched)

        # cloning keeps only non-output artifacts
        cloned = self.api.tasks.clone(task=task).id
        fetched = self.api.tasks.get_by_id(task=cloned).task
        self._assertTaskArtifacts([a for a in edited if a["mode"] != "output"], fetched)

        # cloning with execution overrides replaces the artifacts wholesale
        override = [
            {"key": "x", "type": "str", "uri": "x_test"},
            {"key": "y", "type": "int", "uri": "y_test"},
            {"key": "z", "type": "int", "uri": "y_test"},
        ]
        overridden = self.api.tasks.clone(
            task=task, execution_overrides={"artifacts": override}
        ).id
        fetched = self.api.tasks.get_by_id(task=overridden).task
        self._assertTaskArtifacts(override, fetched)

    def test_artifacts_edit_delete(self):
        initial = [
            {"key": "a", "type": "str", "uri": "test1"},
            {"key": "b", "type": "int", "uri": "test2"},
        ]
        task = self.new_task(execution={"artifacts": initial})

        # add_or_update: "a" is replaced, "c" is appended
        patch = [
            {"key": "a", "type": "str", "uri": "hello"},
            {"key": "c", "type": "int", "uri": "world"},
        ]
        self.api.tasks.add_or_update_artifacts(task=task, artifacts=patch)
        expected = self._update_source(initial, patch)
        fetched = self.api.tasks.get_all_ex(id=[task]).tasks[0]
        self._assertTaskArtifacts(expected, fetched)

        # deleting the last artifact by key
        self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": expected[-1]["key"]}])
        fetched = self.api.tasks.get_all_ex(id=[task]).tasks[0]
        self._assertTaskArtifacts(expected[:-1], fetched)

    def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
        """Merge `update` into `source` by artifact key; new keys are appended."""
        merged = {s["key"]: s for s in source}
        # dict.update keeps existing key order and appends new keys, matching
        # the server's add-or-update semantics
        merged.update({u["key"]: u for u in update})
        return list(merged.values())

    def _assertTaskArtifacts(self, artifacts: Sequence[dict], task):
        """Assert that the task's artifacts match `artifacts` (ignoring server-added fields)."""
        actual_artifacts: dict = task.execution.artifacts
        self.assertEqual(len(artifacts), len(actual_artifacts))

        # the server returns artifacts sorted by (key, type)
        expected_sorted = sorted(artifacts, key=itemgetter("key", "type"))
        for expected, actual in zip(expected_sorted, actual_artifacts):
            subset = {k: v for k, v in actual.items() if k in expected}
            self.assertEqual(expected, subset)
|
@ -1,4 +1,4 @@
|
||||
from typing import Sequence, Tuple, Any
|
||||
from typing import Sequence, Tuple, Any, Union, Callable, Optional
|
||||
|
||||
|
||||
def flatten_nested_items(
|
||||
@ -33,3 +33,52 @@ def deep_merge(source: dict, override: dict) -> dict:
|
||||
source[key] = value
|
||||
|
||||
return source
|
||||
|
||||
|
||||
def nested_get(
    dictionary: dict,
    path: Union[Sequence[str], str],
    default: Optional[Union[Any, Callable]] = None,
) -> Any:
    """
    Return the value stored in `dictionary` under the nested `path`.

    :param dictionary: the dict to read from
    :param path: a single key or a sequence of keys to follow
    :param default: returned when the path cannot be resolved; if callable,
        it is invoked and its result returned (lazy default)
    :return: the value at the path, or the (possibly computed) default
    """
    if isinstance(path, str):
        path = [path]

    node = dictionary
    for key in path:
        # Treat a non-dict intermediate value the same as a missing key,
        # instead of raising TypeError on the `in` test of the next level
        if not isinstance(node, dict) or key not in node:
            if callable(default):
                return default()
            return default
        node = node[key]

    return node
||||
|
||||
|
||||
def nested_delete(dictionary: dict, path: Union[Sequence[str], str]) -> bool:
    """
    Delete the element at the nested `path` from `dictionary`.

    Return 'True' if the element was deleted
    """
    keys = [path] if isinstance(path, str) else list(path)
    *ancestors, leaf = keys

    # walk down to the parent container of the leaf key
    container = dictionary
    for key in ancestors:
        if key not in container:
            container = None
            break
        container = container.get(key)

    if not container or leaf not in container:
        return False

    del container[leaf]
    return True
||||
|
||||
|
||||
def nested_set(dictionary: dict, path: Union[Sequence[str], str], value: Any):
    """
    Store `value` in `dictionary` under the nested `path`, creating
    intermediate dicts along the way as needed.
    """
    keys = [path] if isinstance(path, str) else list(path)
    *ancestors, leaf = keys

    container = dictionary
    for key in ancestors:
        # setdefault creates a missing level and returns the existing one otherwise
        container = container.setdefault(key, {})

    container[leaf] = value
||||
|
Loading…
Reference in New Issue
Block a user