Optimize task artifacts

allegroai 2021-01-05 16:40:35 +02:00
parent 89f81bfe5a
commit 171969c5ea
13 changed files with 499 additions and 249 deletions

View File

@ -1 +1 @@
__version__ = "2.9.0"
__version__ = "2.10.0"

View File

@ -7,7 +7,7 @@ from jsonmodels.validators import Enum, Length
from apiserver.apimodels import DictField, ListField
from apiserver.apimodels.base import UpdateResponse
from apiserver.database.model.task.task import TaskType
from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
from apiserver.database.utils import get_options
@ -20,7 +20,9 @@ class ArtifactTypeData(models.Base):
class Artifact(models.Base):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(validators=Enum("input", "output"), default="output")
mode = StringField(
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
)
uri = StringField()
hash = StringField()
content_size = IntField()
@ -112,12 +114,18 @@ class CloneRequest(TaskRequest):
class AddOrUpdateArtifactsRequest(TaskRequest):
artifacts = ListField([Artifact], required=True)
artifacts = ListField([Artifact], validators=Length(minimum_value=1))
class AddOrUpdateArtifactsResponse(models.Base):
added = ListField([str])
updated = ListField([str])
class ArtifactId(models.Base):
key = StringField(required=True)
mode = StringField(
validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
)
class DeleteArtifactsRequest(TaskRequest):
artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
class ResetRequest(UpdateRequest):

View File

@ -0,0 +1,85 @@
from datetime import datetime
from hashlib import md5
from operator import itemgetter
from typing import Sequence
from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
from apiserver.bll.task.utils import get_task_for_update
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
def get_artifact_id(artifact: dict):
"""
Calculate the artifact id from its 'key' and 'mode' fields.
The key is hashed so that the id will not contain characters that are illegal in mongo field names.
"""
key_hash: str = md5(artifact["key"].encode()).hexdigest()
mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
return f"{key_hash}_{mode}"
def artifacts_prepare_for_save(fields: dict):
artifacts_field = ("execution", "artifacts")
artifacts = nested_get(fields, artifacts_field)
if artifacts is None:
return
nested_set(
fields, artifacts_field, value={get_artifact_id(a): a for a in artifacts}
)
def artifacts_unprepare_from_saved(fields):
artifacts_field = ("execution", "artifacts")
artifacts = nested_get(fields, artifacts_field)
if artifacts is None:
return
nested_set(
fields,
artifacts_field,
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
)
class Artifacts:
@classmethod
def add_or_update_artifacts(
cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact],
) -> int:
with TimingContext("mongo", "update_artifacts"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True
)
artifacts = {
get_artifact_id(a): Artifact(**a)
for a in (api_artifact.to_struct() for api_artifact in artifacts)
}
update_cmds = {
f"set__execution__artifacts__{name}": value
for name, value in artifacts.items()
}
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_artifacts(
cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId]
) -> int:
with TimingContext("mongo", "delete_artifacts"):
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=True
)
artifact_ids = [
get_artifact_id(a)
for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
]
delete_cmds = {
f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
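
As a standalone illustration of the new helpers, here is a minimal round trip through artifacts_prepare_for_save and artifacts_unprepare_from_saved on plain dicts (a sketch only, assuming the module is importable as apiserver.bll.task.artifacts as added in this commit):

from apiserver.bll.task.artifacts import (
    artifacts_prepare_for_save,
    artifacts_unprepare_from_saved,
    get_artifact_id,
)

fields = {
    "execution": {
        "artifacts": [
            {"key": "model", "type": "file", "mode": "output"},
            {"key": "dataset", "type": "url", "mode": "input"},
        ]
    }
}

# list -> dict keyed by "<md5(key)>_<mode>", the shape now stored in mongo
artifacts_prepare_for_save(fields)
assert set(fields["execution"]["artifacts"]) == {
    get_artifact_id({"key": "model", "mode": "output"}),
    get_artifact_id({"key": "dataset", "mode": "input"}),
}

# dict -> list sorted by (key, mode), the shape returned to API callers
artifacts_unprepare_from_saved(fields)
assert [a["key"] for a in fields["execution"]["artifacts"]] == ["dataset", "model"]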

View File

@ -13,8 +13,10 @@ from apiserver.apimodels.tasks import (
Configuration,
)
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update
from apiserver.config import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
log = config.logger(__file__)
@ -58,32 +60,33 @@ class HyperParams:
@classmethod
def delete_params(
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return task.update(**delete_cmds, last_update=datetime.utcnow())
return task.update(**delete_cmds, last_update=datetime.utcnow())
@classmethod
def edit_params(
@ -93,24 +96,25 @@ class HyperParams:
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@ -152,28 +156,29 @@ class HyperParams:
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, list]:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
for task in tasks
}
for task in tasks
}
@classmethod
def edit_configuration(
@ -183,47 +188,32 @@ class HyperParams:
configuration: Sequence[Configuration],
replace_configuration: bool,
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
with TimingContext("mongo", "edit_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration=Sequence[str]
cls, company_id: str, task_id: str, configuration: Sequence[str]
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
with TimingContext("mongo", "delete_configuration"):
task = get_task_for_update(company_id=company_id, task_id=task_id)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
@staticmethod
def _get_task_for_update(
company: str, id: str, allow_all_statuses: bool = False
) -> Task:
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=id)
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)
return task
return task.update(**delete_cmds, last_update=datetime.utcnow())

View File

@ -1,20 +1,14 @@
from collections import OrderedDict
from datetime import datetime
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import apiserver.database.utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.apimodels.tasks import Artifact as ApiArtifact
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.config import config
from apiserver.database.errors import translate_errors_context
@ -28,14 +22,14 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
Artifact,
external_task_types,
)
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.service_repo import APICall
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import deep_merge
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change
@ -181,9 +175,6 @@ class TaskBLL(object):
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
params_dict = {
field: value
for field, value in (
@ -192,21 +183,32 @@ class TaskBLL(object):
)
if value is not None
}
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
for legacy_param in ("parameters", "configuration"):
legacy_value = execution_overrides.pop(legacy_param, None)
if legacy_value is not None:
params_dict["execution"] = legacy_value
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
execution_dict.update(execution_overrides)
params_prepare_for_save(params_dict, previous_task=task)
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
a for a in artifacts if a.get("mode") != ArtifactModes.output
]
execution_dict["artifacts"] = {
k: a
for k, a in artifacts.items()
if a.get("mode") != ArtifactModes.output
}
now = datetime.utcnow()
with translate_errors_context():
@ -543,97 +545,6 @@ class TaskBLL(object):
force=force,
).execute()
@classmethod
def add_or_update_artifacts(
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
key = attrgetter("key", "mode")
if not artifacts:
return [], []
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
artifacts: List[Artifact] = [
Artifact(**artifact.to_struct()) for artifact in artifacts
]
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
for retry in range(attempts):
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
current = list(map(key, task.execution.artifacts))
updated = [a for a in artifacts if key(a) in current]
added = [a for a in artifacts if a not in updated]
filter = {"_id": task_id, "company": company_id}
update = {}
array_filters = None
if current:
filter["execution.artifacts"] = {
"$size": len(current),
"$all": [
*(
{"$elemMatch": {"key": key, "mode": mode}}
for key, mode in current
)
],
}
else:
filter["$or"] = [
{"execution.artifacts": {"$exists": False}},
{"execution.artifacts": {"$size": 0}},
]
if added:
update["$push"] = {
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
}
if updated:
update["$set"] = {
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
for index, artifact in enumerate(updated)
}
array_filters = [
{
f"artifact{index}.key": artifact.key,
f"artifact{index}.mode": artifact.mode,
}
for index, artifact in enumerate(updated)
]
if not update:
return [], []
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
filter=filter,
update=update,
array_filters=array_filters,
upsert=False,
)
if result.matched_count >= 1:
break
wait_msec = random() * int(
config.get("services.tasks.artifacts.update_retry_msec", 500)
)
log.warning(
f"Failed to update artifacts for task {task_id} (updated by another party),"
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
)
sleep(wait_msec / 1000)
else:
raise errors.server_error.UpdateFailed(
"task artifacts updated by another party"
)
return [a.key for a in added], [a.key for a in updated]
@staticmethod
def get_aggregated_project_parameters(
company_id,

View File

@ -171,3 +171,23 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False
) -> Task:
"""
Load the task ('id' and 'status' fields only) and return it only if it is updatable (status == 'created'), unless allow_all_statuses is set.
"""
task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)
return task
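
This helper is what the new artifact and hyperparam code paths call before issuing their update commands. A minimal usage sketch (the company and task ids below are hypothetical):

from datetime import datetime

from apiserver.bll.task.utils import get_task_for_update

# Raises InvalidTaskId if the task does not exist and InvalidTaskStatus if it is
# not in the 'created' state (pass allow_all_statuses=True to skip that check).
task = get_task_for_update(
    company_id="some-company-id", task_id="some-task-id", allow_all_statuses=True
)
# Only 'id' and 'status' are loaded, but update() can still issue arbitrary
# set__/unset__ commands; it returns the number of modified documents.
modified = task.update(last_update=datetime.utcnow())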

View File

@ -8,9 +8,4 @@ non_responsive_tasks_watchdog {
watch_interval_sec: 900
}
artifacts {
update_attempts: 10
update_retry_msec: 500
}
multi_task_histogram_limit: 100

View File

@ -1,3 +1,5 @@
from typing import Dict
from mongoengine import (
StringField,
EmbeddedDocumentField,
@ -14,7 +16,6 @@ from apiserver.database.fields import (
SafeMapField,
SafeDictField,
UnionField,
EmbeddedDocumentSortedListField,
SafeSortedListField,
)
from apiserver.database.model import AttributedDocument
@ -72,10 +73,13 @@ class ArtifactModes:
output = "output"
DEFAULT_ARTIFACT_MODE = ArtifactModes.output
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
uri = StringField()
hash = StringField()
content_size = LongField()
@ -107,7 +111,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
artifacts = EmbeddedDocumentSortedListField(Artifact)
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField()
queue = StringField()
""" Queue ID where task was queued """

View File

@ -0,0 +1,19 @@
from pymongo.database import Database, Collection
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.utilities.dicts import nested_get
def migrate_backend(db: Database):
collection: Collection = db["task"]
artifacts_field = "execution.artifacts"
query = {artifacts_field: {"$type": 4}}
for doc in collection.find(filter=query, projection=(artifacts_field,)):
artifacts = nested_get(doc, artifacts_field.split("."))
if not isinstance(artifacts, list):
continue
new_artifacts = {get_artifact_id(a): a for a in artifacts}
collection.update_one(
{"_id": doc["_id"]}, {"$set": {artifacts_field: new_artifacts}}
)

View File

@ -137,6 +137,14 @@ _definitions {
}
}
}
artifact_mode_enum {
type: string
enum: [
input
output
]
default: output
}
artifact {
type: object
required: [key, type]
@ -151,12 +159,7 @@ _definitions {
}
mode {
description: "System defined input/output indication"
type: string
enum: [
input
output
]
default: output
"$ref": "#/definitions/artifact_mode_enum"
}
uri {
description: "Raw data location"
@ -190,6 +193,20 @@ _definitions {
}
}
}
artifact_id {
type: object
required: [key]
properties {
key {
description: "Entry key"
type: string
}
mode {
description: "System defined input/output indication"
"$ref": "#/definitions/artifact_mode_enum"
}
}
}
execution {
type: object
properties {
@ -1552,11 +1569,11 @@ ping {
}
add_or_update_artifacts {
"2.6" {
description: """ Update an existing artifact (search by key/mode) or add a new one """
"2.10" {
description: """Update existing artifacts (search by key/mode) and add new ones"""
request {
type: object
required: [ task, artifacts ]
required: [task, artifacts]
properties {
task {
description: "Task ID"
@ -1565,22 +1582,45 @@ add_or_update_artifacts {
artifacts {
description: "Artifacts to add or update"
type: array
items { "$ref": "#/definitions/artifact" }
items {"$ref": "#/definitions/artifact"}
}
}
}
response {
type: object
properties {
added {
description: "Keys of artifacts added"
type: array
items { type: string }
}
updated {
description: "Keys of artifacts updated"
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_artifacts {
"2.10" {
description: """Delete existing artifacts (search by key/mode)"""
request {
type: object
required: [task, artifacts]
properties {
task {
description: "Task ID"
type: string
}
artifacts {
description: "Artifacts to delete"
type: array
items { type: string }
items {"$ref": "#/definitions/artifact_id"}
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
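
For reference, a sketch of the request bodies the two 2.10 endpoints accept and the responses they return, written as Python dicts (the task id and artifact values are hypothetical):

add_or_update_request = {
    "task": "some-task-id",
    "artifacts": [
        # 'mode' defaults to "output" when omitted; 'key' and 'type' are required
        {"key": "dataset", "type": "url", "mode": "input", "uri": "s3://bucket/data"},
        {"key": "model", "type": "file", "uri": "s3://bucket/model.bin"},
    ],
}
# expected response: {"updated": 1}

delete_request = {
    "task": "some-task-id",
    "artifacts": [
        # artifact_id objects: only 'key' is required, 'mode' defaults to "output"
        {"key": "dataset", "mode": "input"},
        {"key": "model"},
    ],
}
# expected response: {"deleted": 1}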

View File

@ -29,7 +29,6 @@ from apiserver.apimodels.tasks import (
DequeueResponse,
CloneRequest,
AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
GetHyperParamsRequest,
@ -39,6 +38,7 @@ from apiserver.apimodels.tasks import (
EditConfigurationRequest,
DeleteConfigurationRequest,
GetConfigurationNamesRequest,
DeleteArtifactsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags
@ -49,6 +49,11 @@ from apiserver.bll.task import (
update_project_time,
split_by,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
artifacts_unprepare_from_saved,
Artifacts,
)
from apiserver.bll.task.hyperparams import HyperParams
from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from apiserver.bll.task.param_utils import (
@ -279,6 +284,7 @@ create_fields = {
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
conform_tag_fields(call, fields, validate=True)
params_prepare_for_save(fields, previous_task=previous_task)
artifacts_prepare_for_save(fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@ -301,10 +307,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
conform_output_tags(call, tasks_data)
for data in tasks_data:
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
params_unprepare_from_saved(
fields=data,
copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
fields=data, copy_to_legacy=need_legacy_params,
)
artifacts_unprepare_from_saved(fields=data)
def prepare_create_fields(
@ -846,10 +853,15 @@ def reset(call: APICall, company_id, request: ResetRequest):
set__execution=Execution(), unset__script=1,
)
else:
updates.update(
unset__execution__queue=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
updates.update(unset__execution__queue=1)
if task.execution and task.execution.artifacts:
updates.update(
set__execution__artifacts={
key: artifact
for key, artifact in task.execution.artifacts.items()
if artifact.mode == "input"
}
)
res = ResetResponse(
**ChangeStatusRequest(
@ -1090,17 +1102,34 @@ def ping(_, company_id, request: PingRequest):
@endpoint(
"tasks.add_or_update_artifacts",
min_version="2.6",
min_version="2.10",
request_data_model=AddOrUpdateArtifactsRequest,
response_data_model=AddOrUpdateArtifactsResponse,
)
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
added, updated = TaskBLL.add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts
)
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
with translate_errors_context():
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id, task_id=request.task, artifacts=request.artifacts
)
}
@endpoint(
"tasks.delete_artifacts",
min_version="2.10",
request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
with translate_errors_context():
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
)
}
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)

View File

@ -0,0 +1,100 @@
from operator import itemgetter
from typing import Sequence
from apiserver.tests.automated import TestService
class TestTasksArtifacts(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.10")
def new_task(self, **kwargs) -> str:
self.update_missing(
kwargs,
type="testing",
name="test artifacts",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs)
def test_artifacts_set_get(self):
artifacts = [
dict(key="a", type="str", uri="test1"),
dict(key="b", type="int", uri="test2"),
]
# test create/get and get_all
task = self.new_task(execution={"artifacts": artifacts})
res = self.api.tasks.get_by_id(task=task).task
self._assertTaskArtifacts(artifacts, res)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
# test edit
artifacts = [
dict(key="bb", type="str", uri="test1", mode="output"),
dict(key="aa", type="int", uri="test2", mode="input"),
]
self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
res = self.api.tasks.get_by_id(task=task).task
self._assertTaskArtifacts(artifacts, res)
# test clone
task2 = self.api.tasks.clone(task=task).id
res = self.api.tasks.get_by_id(task=task2).task
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
new_artifacts = [
dict(key="x", type="str", uri="x_test"),
dict(key="y", type="int", uri="y_test"),
dict(key="z", type="int", uri="y_test"),
]
new_task = self.api.tasks.clone(
task=task, execution_overrides={"artifacts": new_artifacts}
).id
res = self.api.tasks.get_by_id(task=new_task).task
self._assertTaskArtifacts(new_artifacts, res)
def test_artifacts_edit_delete(self):
artifacts = [
dict(key="a", type="str", uri="test1"),
dict(key="b", type="int", uri="test2"),
]
task = self.new_task(execution={"artifacts": artifacts})
# test add_or_update
edit = [
dict(key="a", type="str", uri="hello"),
dict(key="c", type="int", uri="world"),
]
res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
artifacts = self._update_source(artifacts, edit)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
# test delete
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
dict1 = {s["key"]: s for s in source}
dict2 = {u["key"]: u for u in update}
res = {
k: v if k not in dict2 else dict2[k]
for k, v in dict1.items()
}
res.update({k: v for k, v in dict2.items() if k not in res})
return list(res.values())
def _assertTaskArtifacts(self, artifacts: Sequence[dict], task):
task_artifacts: dict = task.execution.artifacts
self.assertEqual(len(artifacts), len(task_artifacts))
for expected, actual in zip(
sorted(artifacts, key=itemgetter("key", "type")), task_artifacts
):
self.assertEqual(
expected, {k: v for k, v in actual.items() if k in expected}
)

View File

@ -1,4 +1,4 @@
from typing import Sequence, Tuple, Any
from typing import Sequence, Tuple, Any, Union, Callable, Optional
def flatten_nested_items(
@ -33,3 +33,52 @@ def deep_merge(source: dict, override: dict) -> dict:
source[key] = value
return source
def nested_get(
dictionary: dict,
path: Union[Sequence[str], str],
default: Optional[Union[Any, Callable]] = None,
) -> Any:
if isinstance(path, str):
path = [path]
node = dictionary
for key in path:
if key not in node:
if callable(default):
return default()
return default
node = node.get(key)
return node
def nested_delete(dictionary: dict, path: Union[Sequence[str], str]) -> bool:
"""
Return 'True' if the element was deleted
"""
if isinstance(path, str):
path = [path]
*parent_path, last_key = path
parent = nested_get(dictionary, parent_path)
if not parent or last_key not in parent:
return False
del parent[last_key]
return True
def nested_set(dictionary: dict, path: Union[Sequence[str], str], value: Any):
if isinstance(path, str):
path = [path]
*parent_path, last_key = path
node = dictionary
for key in parent_path:
if key not in node:
node[key] = {}
node = node.get(key)
node[last_key] = value
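
A short usage sketch for the three helpers above, on plain dicts (assuming they are imported from apiserver.utilities.dicts):

from apiserver.utilities.dicts import nested_delete, nested_get, nested_set

doc = {}
nested_set(doc, ("execution", "artifacts"), {"id1": {"key": "a"}})
assert nested_get(doc, ("execution", "artifacts")) == {"id1": {"key": "a"}}
assert nested_get(doc, ("execution", "missing")) is None               # plain default
assert nested_get(doc, ("execution", "missing"), default=dict) == {}   # callable default
assert nested_delete(doc, ("execution", "artifacts")) is True
assert nested_delete(doc, ("execution", "artifacts")) is False         # nothing left to delete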