Mirror of https://github.com/clearml/clearml-server (synced 2025-06-23 08:45:30 +00:00)

Commit 171969c5ea: Optimize task artifacts
Parent commit: 89f81bfe5a
@@ -1 +1 @@
-__version__ = "2.9.0"
+__version__ = "2.10.0"
@@ -7,7 +7,7 @@ from jsonmodels.validators import Enum, Length
 from apiserver.apimodels import DictField, ListField
 from apiserver.apimodels.base import UpdateResponse
-from apiserver.database.model.task.task import TaskType
+from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
 from apiserver.database.utils import get_options


@@ -20,7 +20,9 @@ class ArtifactTypeData(models.Base):
 class Artifact(models.Base):
     key = StringField(required=True)
     type = StringField(required=True)
-    mode = StringField(validators=Enum("input", "output"), default="output")
+    mode = StringField(
+        validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
+    )
     uri = StringField()
     hash = StringField()
     content_size = IntField()
@@ -112,12 +114,18 @@ class CloneRequest(TaskRequest):


 class AddOrUpdateArtifactsRequest(TaskRequest):
-    artifacts = ListField([Artifact], required=True)
+    artifacts = ListField([Artifact], validators=Length(minimum_value=1))


-class AddOrUpdateArtifactsResponse(models.Base):
-    added = ListField([str])
-    updated = ListField([str])
+class ArtifactId(models.Base):
+    key = StringField(required=True)
+    mode = StringField(
+        validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
+    )
+
+
+class DeleteArtifactsRequest(TaskRequest):
+    artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))


 class ResetRequest(UpdateRequest):
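As a quick illustration of the reworked request models above (a sketch, not part of the commit; it assumes TaskRequest exposes the task id as a `task` field):

from apiserver.apimodels.tasks import (
    AddOrUpdateArtifactsRequest,
    Artifact,
    ArtifactId,
    DeleteArtifactsRequest,
)

# mode must be one of ArtifactModes and defaults to DEFAULT_ARTIFACT_MODE ("output")
req = AddOrUpdateArtifactsRequest(
    task="some-task-id",
    artifacts=[Artifact(key="weights", type="file", uri="s3://bucket/weights.bin")],
)
req.validate()  # Length(minimum_value=1) rejects an empty artifacts list

# deleting only needs the (key, mode) pair that identifies an artifact
DeleteArtifactsRequest(task="some-task-id", artifacts=[ArtifactId(key="weights")]).validate()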
apiserver/bll/task/artifacts.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+from datetime import datetime
+from hashlib import md5
+from operator import itemgetter
+from typing import Sequence
+
+from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
+from apiserver.bll.task.utils import get_task_for_update
+from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
+from apiserver.timing_context import TimingContext
+from apiserver.utilities.dicts import nested_get, nested_set
+
+
+def get_artifact_id(artifact: dict):
+    """
+    Calculate the id from the 'key' and 'mode' fields.
+    The key is hashed so that the id will not contain characters that are illegal in MongoDB field names.
+    """
+    key_hash: str = md5(artifact["key"].encode()).hexdigest()
+    mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
+    return f"{key_hash}_{mode}"
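A quick sanity check of the id scheme (a sketch, not part of the commit): the id is the md5 hex digest of the key followed by the mode, so it consists only of hex digits, an underscore, and the mode string, all safe in a MongoDB map key (no dots or dollar signs):

from hashlib import md5

artifact = {"key": "model/weights", "mode": "output"}
expected = md5(b"model/weights").hexdigest() + "_output"
assert get_artifact_id(artifact) == expected
# a missing mode falls back to DEFAULT_ARTIFACT_MODE ("output")
assert get_artifact_id({"key": "model/weights"}) == expected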
+
+
+def artifacts_prepare_for_save(fields: dict):
+    artifacts_field = ("execution", "artifacts")
+    artifacts = nested_get(fields, artifacts_field)
+    if artifacts is None:
+        return
+
+    nested_set(
+        fields, artifacts_field, value={get_artifact_id(a): a for a in artifacts}
+    )
+
+
+def artifacts_unprepare_from_saved(fields):
+    artifacts_field = ("execution", "artifacts")
+    artifacts = nested_get(fields, artifacts_field)
+    if artifacts is None:
+        return
+
+    nested_set(
+        fields,
+        artifacts_field,
+        value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
+    )
+
+
+class Artifacts:
+    @classmethod
+    def add_or_update_artifacts(
+        cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact],
+    ) -> int:
+        with TimingContext("mongo", "update_artifacts"):
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=True
+            )
+
+            artifacts = {
+                get_artifact_id(a): Artifact(**a)
+                for a in (api_artifact.to_struct() for api_artifact in artifacts)
+            }
+
+            update_cmds = {
+                f"set__execution__artifacts__{name}": value
+                for name, value in artifacts.items()
+            }
+            return task.update(**update_cmds, last_update=datetime.utcnow())
+
+    @classmethod
+    def delete_artifacts(
+        cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId]
+    ) -> int:
+        with TimingContext("mongo", "delete_artifacts"):
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=True
+            )
+
+            artifact_ids = [
+                get_artifact_id(a)
+                for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
+            ]
+            delete_cmds = {
+                f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
+            }
+
+            return task.update(**delete_cmds, last_update=datetime.utcnow())
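The mongoengine keyword syntax used above compiles to plain $set/$unset operations on individual map entries, roughly as follows (a sketch with hypothetical values, assuming a loaded `task` document):

artifact = {"key": "weights", "type": "file", "mode": "output"}
set_field = f"set__execution__artifacts__{get_artifact_id(artifact)}"
task.update(**{set_field: Artifact(**artifact)}, last_update=datetime.utcnow())
# raw equivalent: {"$set": {"execution.artifacts.<md5-of-key>_output": {...}}}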
@@ -13,8 +13,10 @@ from apiserver.apimodels.tasks import (
     Configuration,
 )
 from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.utils import get_task_for_update
 from apiserver.config import config
-from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
+from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
+from apiserver.timing_context import TimingContext
 from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper

 log = config.logger(__file__)
@@ -58,32 +60,33 @@ class HyperParams:

     @classmethod
     def delete_params(
-        cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
+        cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
     ) -> int:
-        properties_only = cls._normalize_params(hyperparams)
-        task = cls._get_task_for_update(
-            company=company_id, id=task_id, allow_all_statuses=properties_only
-        )
+        with TimingContext("mongo", "delete_hyperparams"):
+            properties_only = cls._normalize_params(hyperparams)
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
+            )

-        with_param, without_param = iterutils.partition(
-            hyperparams, key=lambda p: bool(p.name)
-        )
-        sections_to_delete = {p.section for p in without_param}
-        delete_cmds = {
-            f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
-            for section in sections_to_delete
-        }
+            with_param, without_param = iterutils.partition(
+                hyperparams, key=lambda p: bool(p.name)
+            )
+            sections_to_delete = {p.section for p in without_param}
+            delete_cmds = {
+                f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
+                for section in sections_to_delete
+            }

-        for item in with_param:
-            section = ParameterKeyEscaper.escape(item.section)
-            if item.section in sections_to_delete:
-                raise errors.bad_request.FieldsConflict(
-                    "Cannot delete section field if the whole section was scheduled for deletion"
-                )
-            name = ParameterKeyEscaper.escape(item.name)
-            delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
+            for item in with_param:
+                section = ParameterKeyEscaper.escape(item.section)
+                if item.section in sections_to_delete:
+                    raise errors.bad_request.FieldsConflict(
+                        "Cannot delete section field if the whole section was scheduled for deletion"
+                    )
+                name = ParameterKeyEscaper.escape(item.name)
+                delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1

-        return task.update(**delete_cmds, last_update=datetime.utcnow())
+            return task.update(**delete_cmds, last_update=datetime.utcnow())

     @classmethod
     def edit_params(
@@ -93,24 +96,25 @@ class HyperParams:
         hyperparams: Sequence[HyperParamItem],
         replace_hyperparams: str,
     ) -> int:
-        properties_only = cls._normalize_params(hyperparams)
-        task = cls._get_task_for_update(
-            company=company_id, id=task_id, allow_all_statuses=properties_only
-        )
+        with TimingContext("mongo", "edit_hyperparams"):
+            properties_only = cls._normalize_params(hyperparams)
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
+            )

-        update_cmds = dict()
-        hyperparams = cls._db_dicts_from_list(hyperparams)
-        if replace_hyperparams == ReplaceHyperparams.all:
-            update_cmds["set__hyperparams"] = hyperparams
-        elif replace_hyperparams == ReplaceHyperparams.section:
-            for section, value in hyperparams.items():
-                update_cmds[f"set__hyperparams__{section}"] = value
-        else:
-            for section, section_params in hyperparams.items():
-                for name, value in section_params.items():
-                    update_cmds[f"set__hyperparams__{section}__{name}"] = value
+            update_cmds = dict()
+            hyperparams = cls._db_dicts_from_list(hyperparams)
+            if replace_hyperparams == ReplaceHyperparams.all:
+                update_cmds["set__hyperparams"] = hyperparams
+            elif replace_hyperparams == ReplaceHyperparams.section:
+                for section, value in hyperparams.items():
+                    update_cmds[f"set__hyperparams__{section}"] = value
+            else:
+                for section, section_params in hyperparams.items():
+                    for name, value in section_params.items():
+                        update_cmds[f"set__hyperparams__{section}__{name}"] = value

-        return task.update(**update_cmds, last_update=datetime.utcnow())
+            return task.update(**update_cmds, last_update=datetime.utcnow())

     @classmethod
     def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -152,28 +156,29 @@ class HyperParams:
     def get_configuration_names(
         cls, company_id: str, task_ids: Sequence[str]
     ) -> Dict[str, list]:
-        pipeline = [
-            {
-                "$match": {
-                    "company": {"$in": [None, "", company_id]},
-                    "_id": {"$in": task_ids},
-                }
-            },
-            {"$project": {"items": {"$objectToArray": "$configuration"}}},
-            {"$unwind": "$items"},
-            {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
-        ]
-
-        tasks = Task.aggregate(pipeline)
-
-        return {
-            task["_id"]: {
-                "names": sorted(
-                    ParameterKeyEscaper.unescape(name) for name in task["names"]
-                )
-            }
-            for task in tasks
-        }
+        with TimingContext("mongo", "get_configuration_names"):
+            pipeline = [
+                {
+                    "$match": {
+                        "company": {"$in": [None, "", company_id]},
+                        "_id": {"$in": task_ids},
+                    }
+                },
+                {"$project": {"items": {"$objectToArray": "$configuration"}}},
+                {"$unwind": "$items"},
+                {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
+            ]
+
+            tasks = Task.aggregate(pipeline)
+
+            return {
+                task["_id"]: {
+                    "names": sorted(
+                        ParameterKeyEscaper.unescape(name) for name in task["names"]
+                    )
+                }
+                for task in tasks
+            }

     @classmethod
     def edit_configuration(
@@ -183,47 +188,32 @@ class HyperParams:
         configuration: Sequence[Configuration],
         replace_configuration: bool,
     ) -> int:
-        task = cls._get_task_for_update(company=company_id, id=task_id)
-        update_cmds = dict()
-        configuration = {
-            ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
-            for c in configuration
-        }
-        if replace_configuration:
-            update_cmds["set__configuration"] = configuration
-        else:
-            for name, value in configuration.items():
-                update_cmds[f"set__configuration__{name}"] = value
+        with TimingContext("mongo", "edit_configuration"):
+            task = get_task_for_update(company_id=company_id, task_id=task_id)

-        return task.update(**update_cmds, last_update=datetime.utcnow())
+            update_cmds = dict()
+            configuration = {
+                ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
+                for c in configuration
+            }
+            if replace_configuration:
+                update_cmds["set__configuration"] = configuration
+            else:
+                for name, value in configuration.items():
+                    update_cmds[f"set__configuration__{name}"] = value
+
+            return task.update(**update_cmds, last_update=datetime.utcnow())

     @classmethod
     def delete_configuration(
-        cls, company_id: str, task_id: str, configuration=Sequence[str]
+        cls, company_id: str, task_id: str, configuration: Sequence[str]
     ) -> int:
-        task = cls._get_task_for_update(company=company_id, id=task_id)
-        delete_cmds = {
-            f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
-            for name in set(configuration)
-        }
+        with TimingContext("mongo", "delete_configuration"):
+            task = get_task_for_update(company_id=company_id, task_id=task_id)

-        return task.update(**delete_cmds, last_update=datetime.utcnow())
-
-    @staticmethod
-    def _get_task_for_update(
-        company: str, id: str, allow_all_statuses: bool = False
-    ) -> Task:
-        task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
-        if not task:
-            raise errors.bad_request.InvalidTaskId(id=id)
-
-        if allow_all_statuses:
-            return task
-
-        if task.status != TaskStatus.created:
-            raise errors.bad_request.InvalidTaskStatus(
-                expected=TaskStatus.created, status=task.status
-            )
-        return task
+            delete_cmds = {
+                f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
+                for name in set(configuration)
+            }
+
+            return task.update(**delete_cmds, last_update=datetime.utcnow())
@@ -1,20 +1,14 @@
 from collections import OrderedDict
 from datetime import datetime
-from operator import attrgetter
-from random import random
-from time import sleep
-from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
+from typing import Collection, Sequence, Tuple, Any, Optional, Dict

 import dpath
-import pymongo.results
 import six
 from mongoengine import Q
 from six import string_types

 import apiserver.database.utils as dbutils
-from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
-from apiserver.apimodels.tasks import Artifact as ApiArtifact
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.config import config
 from apiserver.database.errors import translate_errors_context
@@ -28,14 +22,14 @@ from apiserver.database.model.task.task import (
     TaskStatusMessage,
     TaskSystemTags,
     ArtifactModes,
-    Artifact,
     external_task_types,
 )
 from apiserver.database.utils import get_company_or_none_constraint, id as create_id
+from apiserver.es_factory import es_factory
 from apiserver.service_repo import APICall
 from apiserver.timing_context import TimingContext
-from apiserver.utilities.dicts import deep_merge
 from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
+from .artifacts import artifacts_prepare_for_save
 from .param_utils import params_prepare_for_save
 from .utils import ChangeStatusRequest, validate_status_change
@@ -181,9 +175,6 @@ class TaskBLL(object):
         execution_overrides: Optional[dict] = None,
         validate_references: bool = False,
     ) -> Task:
-        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
-        execution_dict = task.execution.to_proper_dict() if task.execution else {}
-        execution_model_overriden = False
         params_dict = {
             field: value
             for field, value in (
@@ -192,21 +183,32 @@ class TaskBLL(object):
             )
             if value is not None
         }
+
+        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
+
+        execution_dict = task.execution.to_proper_dict() if task.execution else {}
+        execution_model_overriden = False
         if execution_overrides:
+            execution_model_overriden = execution_overrides.get("model") is not None
+            artifacts_prepare_for_save({"execution": execution_overrides})
+
             params_dict["execution"] = {}
             for legacy_param in ("parameters", "configuration"):
                 legacy_value = execution_overrides.pop(legacy_param, None)
                 if legacy_value is not None:
                     params_dict["execution"] = legacy_value
-            execution_dict = deep_merge(execution_dict, execution_overrides)
-            execution_model_overriden = execution_overrides.get("model") is not None
+
+            execution_dict.update(execution_overrides)

         params_prepare_for_save(params_dict, previous_task=task)

         artifacts = execution_dict.get("artifacts")
         if artifacts:
-            execution_dict["artifacts"] = [
-                a for a in artifacts if a.get("mode") != ArtifactModes.output
-            ]
+            execution_dict["artifacts"] = {
+                k: a
+                for k, a in artifacts.items()
+                if a.get("mode") != ArtifactModes.output
+            }
         now = datetime.utcnow()

         with translate_errors_context():
@@ -543,97 +545,6 @@ class TaskBLL(object):
             force=force,
         ).execute()

-    @classmethod
-    def add_or_update_artifacts(
-        cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
-    ) -> Tuple[List[str], List[str]]:
-        key = attrgetter("key", "mode")
-
-        if not artifacts:
-            return [], []
-
-        with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
-            artifacts: List[Artifact] = [
-                Artifact(**artifact.to_struct()) for artifact in artifacts
-            ]
-
-            attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
-
-            for retry in range(attempts):
-                task = cls.get_task_with_access(
-                    task_id, company_id=company_id, requires_write_access=True
-                )
-
-                current = list(map(key, task.execution.artifacts))
-                updated = [a for a in artifacts if key(a) in current]
-                added = [a for a in artifacts if a not in updated]
-
-                filter = {"_id": task_id, "company": company_id}
-                update = {}
-                array_filters = None
-                if current:
-                    filter["execution.artifacts"] = {
-                        "$size": len(current),
-                        "$all": [
-                            *(
-                                {"$elemMatch": {"key": key, "mode": mode}}
-                                for key, mode in current
-                            )
-                        ],
-                    }
-                else:
-                    filter["$or"] = [
-                        {"execution.artifacts": {"$exists": False}},
-                        {"execution.artifacts": {"$size": 0}},
-                    ]
-
-                if added:
-                    update["$push"] = {
-                        "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
-                    }
-                if updated:
-                    update["$set"] = {
-                        f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
-                        for index, artifact in enumerate(updated)
-                    }
-                    array_filters = [
-                        {
-                            f"artifact{index}.key": artifact.key,
-                            f"artifact{index}.mode": artifact.mode,
-                        }
-                        for index, artifact in enumerate(updated)
-                    ]
-
-                if not update:
-                    return [], []
-
-                result: pymongo.results.UpdateResult = Task._get_collection().update_one(
-                    filter=filter,
-                    update=update,
-                    array_filters=array_filters,
-                    upsert=False,
-                )
-
-                if result.matched_count >= 1:
-                    break
-
-                wait_msec = random() * int(
-                    config.get("services.tasks.artifacts.update_retry_msec", 500)
-                )
-
-                log.warning(
-                    f"Failed to update artifacts for task {task_id} (updated by another party),"
-                    f" retrying {retry+1}/{attempts} in {wait_msec}ms"
-                )
-
-                sleep(wait_msec / 1000)
-            else:
-                raise errors.server_error.UpdateFailed(
-                    "task artifacts updated by another party"
-                )
-
-            return [a.key for a in added], [a.key for a in updated]
-
     @staticmethod
     def get_aggregated_project_parameters(
         company_id,
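An aside on why the retry loop above could be dropped (an inference from this diff, not a statement in the commit): the old implementation had to read the current artifacts list, pin it down with $size/$all match filters, and retry on contention, because concurrent writers could reorder or resize the list. With artifacts stored in a map keyed by get_artifact_id, each write addresses its own subdocument path, so a single unconditional update suffices, roughly:

# one atomic $set per artifact id; no read of the current state, no retry loop
# (artifact_id and artifact_doc are hypothetical values)
Task._get_collection().update_one(
    {"_id": task_id, "company": company_id},
    {"$set": {f"execution.artifacts.{artifact_id}": artifact_doc}},
)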
@@ -171,3 +171,23 @@ def split_by(
         [item for cond, item in applied if cond],
         [item for cond, item in applied if not cond],
     )
+
+
+def get_task_for_update(
+    company_id: str, task_id: str, allow_all_statuses: bool = False
+) -> Task:
+    """
+    Load only the task id and status, and return the task only if it is updatable (status == 'created')
+    """
+    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
+    if not task:
+        raise errors.bad_request.InvalidTaskId(id=task_id)
+
+    if allow_all_statuses:
+        return task
+
+    if task.status != TaskStatus.created:
+        raise errors.bad_request.InvalidTaskStatus(
+            expected=TaskStatus.created, status=task.status
+        )
+    return task
@@ -8,9 +8,4 @@ non_responsive_tasks_watchdog {
     watch_interval_sec: 900
 }

-artifacts {
-    update_attempts: 10
-    update_retry_msec: 500
-}
-
 multi_task_histogram_limit: 100
@@ -1,3 +1,5 @@
+from typing import Dict
+
 from mongoengine import (
     StringField,
     EmbeddedDocumentField,
@@ -14,7 +16,6 @@ from apiserver.database.fields import (
     SafeMapField,
     SafeDictField,
     UnionField,
-    EmbeddedDocumentSortedListField,
     SafeSortedListField,
 )
 from apiserver.database.model import AttributedDocument
@@ -72,10 +73,13 @@ class ArtifactModes:
     output = "output"


+DEFAULT_ARTIFACT_MODE = ArtifactModes.output
+
+
 class Artifact(EmbeddedDocument):
     key = StringField(required=True)
     type = StringField(required=True)
-    mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
+    mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
     uri = StringField()
     hash = StringField()
     content_size = LongField()
@@ -107,7 +111,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
     model_desc = SafeMapField(StringField(default=""))
     model_labels = ModelLabels()
     framework = StringField()
-    artifacts = EmbeddedDocumentSortedListField(Artifact)
+    artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
     docker_cmd = StringField()
     queue = StringField()
     """ Queue ID where task was queued """
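The stored shape of execution.artifacts therefore changes from a sorted list of embedded documents to a map keyed by get_artifact_id, for example (hypothetical values):

# before: "artifacts": [{"key": "weights", "mode": "output", ...}, ...]
# after:  "artifacts": {"<md5-of-key>_output": {"key": "weights", "mode": "output", ...}, ...}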
apiserver/mongo/migrations/0.17.0.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from pymongo.database import Database, Collection
+
+from apiserver.bll.task.artifacts import get_artifact_id
+from apiserver.utilities.dicts import nested_get
+
+
+def migrate_backend(db: Database):
+    collection: Collection = db["task"]
+    artifacts_field = "execution.artifacts"
+    query = {artifacts_field: {"$type": 4}}
+    for doc in collection.find(filter=query, projection=(artifacts_field,)):
+        artifacts = nested_get(doc, artifacts_field.split("."))
+        if not isinstance(artifacts, list):
+            continue
+
+        new_artifacts = {get_artifact_id(a): a for a in artifacts}
+        collection.update_one(
+            {"_id": doc["_id"]}, {"$set": {artifacts_field: new_artifacts}}
+        )
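A note on the query above: {"$type": 4} matches documents whose execution.artifacts field is a BSON array (type 4), i.e. tasks still carrying the pre-2.10 list layout; already-migrated documents hold a BSON object (type 3) and are skipped, which makes the migration safe to re-run.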
@@ -137,6 +137,14 @@ _definitions {
             }
         }
     }
+    artifact_mode_enum {
+        type: string
+        enum: [
+            input
+            output
+        ]
+        default: output
+    }
     artifact {
         type: object
         required: [key, type]
@@ -151,12 +159,7 @@ _definitions {
             }
             mode {
                 description: "System defined input/output indication"
-                type: string
-                enum: [
-                    input
-                    output
-                ]
-                default: output
+                "$ref": "#/definitions/artifact_mode_enum"
             }
             uri {
                 description: "Raw data location"
@@ -190,6 +193,20 @@ _definitions {
             }
         }
     }
+    artifact_id {
+        type: object
+        required: [key]
+        properties {
+            key {
+                description: "Entry key"
+                type: string
+            }
+            mode {
+                description: "System defined input/output indication"
+                "$ref": "#/definitions/artifact_mode_enum"
+            }
+        }
+    }
     execution {
         type: object
         properties {
@@ -1552,11 +1569,11 @@ ping {
     }
 }

 add_or_update_artifacts {
-    "2.6" {
-        description: """ Update an existing artifact (search by key/mode) or add a new one """
+    "2.10" {
+        description: """Update existing artifacts (search by key/mode) and add new ones"""
         request {
             type: object
-            required: [ task, artifacts ]
+            required: [task, artifacts]
             properties {
                 task {
                     description: "Task ID"
@@ -1565,22 +1582,45 @@ add_or_update_artifacts {
                 artifacts {
                     description: "Artifacts to add or update"
                     type: array
-                    items { "$ref": "#/definitions/artifact" }
+                    items {"$ref": "#/definitions/artifact"}
                 }
             }
         }
         response {
             type: object
             properties {
-                added {
-                    description: "Keys of artifacts added"
-                    type: array
-                    items { type: string }
-                }
                 updated {
-                    description: "Keys of artifacts updated"
-                    type: array
-                    items { type: string }
+                    description: "Indicates if the task was updated successfully"
+                    type: integer
                 }
             }
         }
     }
 }
+delete_artifacts {
+    "2.10" {
+        description: """Delete existing artifacts (search by key/mode)"""
+        request {
+            type: object
+            required: [task, artifacts]
+            properties {
+                task {
+                    description: "Task ID"
+                    type: string
+                }
+                artifacts {
+                    description: "Artifacts to delete"
+                    type: array
+                    items {"$ref": "#/definitions/artifact_id"}
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                deleted {
+                    description: "Indicates if the task was updated successfully"
+                    type: integer
+                }
+            }
+        }
+    }
+}
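Wire-level usage of the two endpoints might look as follows (a hypothetical sketch: host, port, credentials, and artifact values are made up; the apiserver wraps results in a data envelope):

import requests

base = "http://localhost:8008"
auth = ("<access_key>", "<secret_key>")

r = requests.post(
    f"{base}/tasks.add_or_update_artifacts",
    json={
        "task": "<task-id>",
        "artifacts": [{"key": "weights", "type": "file", "uri": "s3://bucket/weights.bin"}],
    },
    auth=auth,
)
print(r.json()["data"])  # e.g. {"updated": 1}

r = requests.post(
    f"{base}/tasks.delete_artifacts",
    json={"task": "<task-id>", "artifacts": [{"key": "weights", "mode": "output"}]},
    auth=auth,
)
print(r.json()["data"])  # e.g. {"deleted": 1}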
@@ -29,7 +29,6 @@ from apiserver.apimodels.tasks import (
     DequeueResponse,
     CloneRequest,
     AddOrUpdateArtifactsRequest,
-    AddOrUpdateArtifactsResponse,
     GetTypesRequest,
     ResetRequest,
     GetHyperParamsRequest,
@@ -39,6 +38,7 @@ from apiserver.apimodels.tasks import (
     EditConfigurationRequest,
     DeleteConfigurationRequest,
     GetConfigurationNamesRequest,
+    DeleteArtifactsRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.organization import OrgBLL, Tags
@@ -49,6 +49,11 @@ from apiserver.bll.task import (
     update_project_time,
     split_by,
 )
+from apiserver.bll.task.artifacts import (
+    artifacts_prepare_for_save,
+    artifacts_unprepare_from_saved,
+    Artifacts,
+)
 from apiserver.bll.task.hyperparams import HyperParams
 from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
 from apiserver.bll.task.param_utils import (
@@ -279,6 +284,7 @@ create_fields = {
 def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
     conform_tag_fields(call, fields, validate=True)
     params_prepare_for_save(fields, previous_task=previous_task)
+    artifacts_prepare_for_save(fields)

     # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
     for field in task_script_fields:
@@ -301,10 +307,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
     conform_output_tags(call, tasks_data)

     for data in tasks_data:
+        need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
         params_unprepare_from_saved(
-            fields=data,
-            copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
+            fields=data, copy_to_legacy=need_legacy_params,
         )
+        artifacts_unprepare_from_saved(fields=data)


 def prepare_create_fields(
@@ -846,10 +853,15 @@ def reset(call: APICall, company_id, request: ResetRequest):
             set__execution=Execution(), unset__script=1,
         )
     else:
-        updates.update(
-            unset__execution__queue=1,
-            __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
-        )
+        updates.update(unset__execution__queue=1)
+        if task.execution and task.execution.artifacts:
+            updates.update(
+                set__execution__artifacts={
+                    key: artifact
+                    for key, artifact in task.execution.artifacts
+                    if artifact.get("mode") == "input"
+                }
+            )

     res = ResetResponse(
         **ChangeStatusRequest(
@@ -1090,17 +1102,34 @@ def ping(_, company_id, request: PingRequest):

 @endpoint(
     "tasks.add_or_update_artifacts",
-    min_version="2.6",
+    min_version="2.10",
     request_data_model=AddOrUpdateArtifactsRequest,
-    response_data_model=AddOrUpdateArtifactsResponse,
 )
 def add_or_update_artifacts(
     call: APICall, company_id, request: AddOrUpdateArtifactsRequest
 ):
-    added, updated = TaskBLL.add_or_update_artifacts(
-        task_id=request.task, company_id=company_id, artifacts=request.artifacts
-    )
-    call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
+    with translate_errors_context():
+        call.result.data = {
+            "updated": Artifacts.add_or_update_artifacts(
+                company_id=company_id, task_id=request.task, artifacts=request.artifacts
+            )
+        }
+
+
+@endpoint(
+    "tasks.delete_artifacts",
+    min_version="2.10",
+    request_data_model=DeleteArtifactsRequest,
+)
+def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
+    with translate_errors_context():
+        call.result.data = {
+            "deleted": Artifacts.delete_artifacts(
+                company_id=company_id,
+                task_id=request.task,
+                artifact_ids=request.artifacts,
+            )
+        }


 @endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
apiserver/tests/automated/test_task_artifacts.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+from operator import itemgetter
+from typing import Sequence
+
+from apiserver.tests.automated import TestService
+
+
+class TestTasksArtifacts(TestService):
+    def setUp(self, **kwargs):
+        super().setUp(version="2.10")
+
+    def new_task(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            type="testing",
+            name="test artifacts",
+            input=dict(view=dict()),
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("tasks", **kwargs)
+
+    def test_artifacts_set_get(self):
+        artifacts = [
+            dict(key="a", type="str", uri="test1"),
+            dict(key="b", type="int", uri="test2"),
+        ]
+
+        # test create/get and get_all
+        task = self.new_task(execution={"artifacts": artifacts})
+        res = self.api.tasks.get_by_id(task=task).task
+        self._assertTaskArtifacts(artifacts, res)
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test edit
+        artifacts = [
+            dict(key="bb", type="str", uri="test1", mode="output"),
+            dict(key="aa", type="int", uri="test2", mode="input"),
+        ]
+        self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
+        res = self.api.tasks.get_by_id(task=task).task
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test clone
+        task2 = self.api.tasks.clone(task=task).id
+        res = self.api.tasks.get_by_id(task=task2).task
+        self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
+
+        new_artifacts = [
+            dict(key="x", type="str", uri="x_test"),
+            dict(key="y", type="int", uri="y_test"),
+            dict(key="z", type="int", uri="y_test"),
+        ]
+        new_task = self.api.tasks.clone(
+            task=task, execution_overrides={"artifacts": new_artifacts}
+        ).id
+        res = self.api.tasks.get_by_id(task=new_task).task
+        self._assertTaskArtifacts(new_artifacts, res)
+
+    def test_artifacts_edit_delete(self):
+        artifacts = [
+            dict(key="a", type="str", uri="test1"),
+            dict(key="b", type="int", uri="test2"),
+        ]
+        task = self.new_task(execution={"artifacts": artifacts})
+
+        # test add_or_update
+        edit = [
+            dict(key="a", type="str", uri="hello"),
+            dict(key="c", type="int", uri="world"),
+        ]
+        res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
+        artifacts = self._update_source(artifacts, edit)
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test delete
+        self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
+
+    def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
+        dict1 = {s["key"]: s for s in source}
+        dict2 = {u["key"]: u for u in update}
+        res = {
+            k: v if k not in dict2 else dict2[k]
+            for k, v in dict1.items()
+        }
+        res.update({k: v for k, v in dict2.items() if k not in res})
+        return list(res.values())
+
+    def _assertTaskArtifacts(self, artifacts: Sequence[dict], task):
+        task_artifacts: dict = task.execution.artifacts
+        self.assertEqual(len(artifacts), len(task_artifacts))
+
+        for expected, actual in zip(
+            sorted(artifacts, key=itemgetter("key", "type")), task_artifacts
+        ):
+            self.assertEqual(
+                expected, {k: v for k, v in actual.items() if k in expected}
+            )
@@ -1,4 +1,4 @@
-from typing import Sequence, Tuple, Any
+from typing import Sequence, Tuple, Any, Union, Callable, Optional


 def flatten_nested_items(
@@ -33,3 +33,52 @@ def deep_merge(source: dict, override: dict) -> dict:
             source[key] = value

     return source
+
+
+def nested_get(
+    dictionary: dict,
+    path: Union[Sequence[str], str],
+    default: Optional[Union[Any, Callable]] = None,
+) -> Any:
+    if isinstance(path, str):
+        path = [path]
+
+    node = dictionary
+    for key in path:
+        if key not in node:
+            if callable(default):
+                return default()
+            return default
+        node = node.get(key)
+
+    return node
+
+
+def nested_delete(dictionary: dict, path: Union[Sequence[str], str]) -> bool:
+    """
+    Return 'True' if the element was deleted
+    """
+    if isinstance(path, str):
+        path = [path]
+
+    *parent_path, last_key = path
+    parent = nested_get(dictionary, parent_path)
+    if not parent or last_key not in parent:
+        return False
+
+    del parent[last_key]
+    return True
+
+
+def nested_set(dictionary: dict, path: Union[Sequence[str], str], value: Any):
+    if isinstance(path, str):
+        path = [path]
+
+    *parent_path, last_key = path
+    node = dictionary
+    for key in parent_path:
+        if key not in node:
+            node[key] = {}
+        node = node.get(key)
+
+    node[last_key] = value
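The three nested_* helpers compose as follows (a small usage sketch, not part of the commit):

d = {}
nested_set(d, ("execution", "artifacts"), value={"id1": {"key": "a"}})
assert nested_get(d, ("execution", "artifacts")) == {"id1": {"key": "a"}}
assert nested_get(d, "missing", default=dict) == {}  # a callable default is invoked
assert nested_delete(d, ("execution", "artifacts")) is True
assert nested_get(d, ("execution", "artifacts")) is None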