diff --git a/apiserver/api_version.py b/apiserver/api_version.py
index 43ce13d..1c62222 100644
--- a/apiserver/api_version.py
+++ b/apiserver/api_version.py
@@ -1 +1 @@
-__version__ = "2.9.0"
+__version__ = "2.10.0"
diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py
index 99b3935..d6894f4 100644
--- a/apiserver/apimodels/tasks.py
+++ b/apiserver/apimodels/tasks.py
@@ -7,7 +7,7 @@ from jsonmodels.validators import Enum, Length
 
 from apiserver.apimodels import DictField, ListField
 from apiserver.apimodels.base import UpdateResponse
-from apiserver.database.model.task.task import TaskType
+from apiserver.database.model.task.task import TaskType, ArtifactModes, DEFAULT_ARTIFACT_MODE
 from apiserver.database.utils import get_options
 
 
@@ -20,7 +20,9 @@ class ArtifactTypeData(models.Base):
 class Artifact(models.Base):
     key = StringField(required=True)
     type = StringField(required=True)
-    mode = StringField(validators=Enum("input", "output"), default="output")
+    mode = StringField(
+        validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
+    )
     uri = StringField()
     hash = StringField()
     content_size = IntField()
@@ -112,12 +114,18 @@ class CloneRequest(TaskRequest):
 
 
 class AddOrUpdateArtifactsRequest(TaskRequest):
-    artifacts = ListField([Artifact], required=True)
+    artifacts = ListField([Artifact], validators=Length(minimum_value=1))
 
 
-class AddOrUpdateArtifactsResponse(models.Base):
-    added = ListField([str])
-    updated = ListField([str])
+class ArtifactId(models.Base):
+    key = StringField(required=True)
+    mode = StringField(
+        validators=Enum(*get_options(ArtifactModes)), default=DEFAULT_ARTIFACT_MODE
+    )
+
+
+class DeleteArtifactsRequest(TaskRequest):
+    artifacts = ListField([ArtifactId], validators=Length(minimum_value=1))
 
 
 class ResetRequest(UpdateRequest):
diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py
new file mode 100644
index 0000000..7197754
--- /dev/null
+++ b/apiserver/bll/task/artifacts.py
@@ -0,0 +1,85 @@
+from datetime import datetime
+from hashlib import md5
+from operator import itemgetter
+from typing import Sequence
+
+from apiserver.apimodels.tasks import Artifact as ApiArtifact, ArtifactId
+from apiserver.bll.task.utils import get_task_for_update
+from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
+from apiserver.timing_context import TimingContext
+from apiserver.utilities.dicts import nested_get, nested_set
+
+
+def get_artifact_id(artifact: dict):
+    """
+    Calculate the artifact id from the 'key' and 'mode' fields.
+    A hash of the key is used so that the id does not contain mongo illegal characters
+    """
+    key_hash: str = md5(artifact["key"].encode()).hexdigest()
+    mode: str = artifact.get("mode", DEFAULT_ARTIFACT_MODE)
+    return f"{key_hash}_{mode}"
+
+
+def artifacts_prepare_for_save(fields: dict):
+    artifacts_field = ("execution", "artifacts")
+    artifacts = nested_get(fields, artifacts_field)
+    if artifacts is None:
+        return
+
+    nested_set(
+        fields, artifacts_field, value={get_artifact_id(a): a for a in artifacts}
+    )
+
+
+def artifacts_unprepare_from_saved(fields):
+    artifacts_field = ("execution", "artifacts")
+    artifacts = nested_get(fields, artifacts_field)
+    if artifacts is None:
+        return
+
+    nested_set(
+        fields,
+        artifacts_field,
+        value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
+    )
+
+
+class Artifacts:
+    @classmethod
+    def add_or_update_artifacts(
+        cls, company_id: str, task_id: str, artifacts: Sequence[ApiArtifact],
+    ) -> int:
+        with TimingContext("mongo", "update_artifacts"):
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=True
+            )
+
+            artifacts = {
+                get_artifact_id(a): Artifact(**a)
+                for a in (api_artifact.to_struct() for api_artifact in artifacts)
+            }
+
+            update_cmds = {
+                f"set__execution__artifacts__{name}": value
+                for name, value in artifacts.items()
+            }
+            return task.update(**update_cmds, last_update=datetime.utcnow())
+
+    @classmethod
+    def delete_artifacts(
+        cls, company_id: str, task_id: str, artifact_ids: Sequence[ArtifactId]
+    ) -> int:
+        with TimingContext("mongo", "delete_artifacts"):
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=True
+            )
+
+            artifact_ids = [
+                get_artifact_id(a)
+                for a in (artifact_id.to_struct() for artifact_id in artifact_ids)
+            ]
+            delete_cmds = {
+                f"unset__execution__artifacts__{id_}": 1 for id_ in set(artifact_ids)
+            }
+
+            return task.update(**delete_cmds, last_update=datetime.utcnow())
diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py
index b8e5fc2..bd96f1b 100644
--- a/apiserver/bll/task/hyperparams.py
+++ b/apiserver/bll/task/hyperparams.py
@@ -13,8 +13,10 @@ from apiserver.apimodels.tasks import (
     Configuration,
 )
 from apiserver.bll.task import TaskBLL
+from apiserver.bll.task.utils import get_task_for_update
 from apiserver.config import config
-from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
+from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
+from apiserver.timing_context import TimingContext
 from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 
 log = config.logger(__file__)
@@ -58,32 +60,33 @@ class HyperParams:
     @classmethod
     def delete_params(
-        cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
+        cls, company_id: str, task_id: str, hyperparams: Sequence[HyperParamKey]
     ) -> int:
-        properties_only = cls._normalize_params(hyperparams)
-        task = cls._get_task_for_update(
-            company=company_id, id=task_id, allow_all_statuses=properties_only
-        )
+        with TimingContext("mongo", "delete_hyperparams"):
+            properties_only = cls._normalize_params(hyperparams)
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
+            )
 
-        with_param, without_param = iterutils.partition(
-            hyperparams, key=lambda p: bool(p.name)
-        )
-        sections_to_delete = {p.section for p in without_param}
-        delete_cmds = {
-            f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
-            for section in sections_to_delete
-        }
+            with_param, without_param = iterutils.partition(
+                hyperparams, key=lambda p: bool(p.name)
+            )
+            sections_to_delete = {p.section for p in without_param}
+            delete_cmds = {
+                f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
+                for section in sections_to_delete
+            }
 
-        for item in with_param:
-            section = ParameterKeyEscaper.escape(item.section)
-            if item.section in sections_to_delete:
-                raise errors.bad_request.FieldsConflict(
-                    "Cannot delete section field if the whole section was scheduled for deletion"
-                )
-            name = ParameterKeyEscaper.escape(item.name)
-            delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
+            for item in with_param:
+                section = ParameterKeyEscaper.escape(item.section)
+                if item.section in sections_to_delete:
+                    raise errors.bad_request.FieldsConflict(
+                        "Cannot delete section field if the whole section was scheduled for deletion"
+                    )
+                name = ParameterKeyEscaper.escape(item.name)
+                delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
 
-        return task.update(**delete_cmds, last_update=datetime.utcnow())
+            return task.update(**delete_cmds, last_update=datetime.utcnow())
 
     @classmethod
     def edit_params(
@@ -93,24 +96,25 @@ class HyperParams:
         hyperparams: Sequence[HyperParamItem],
         replace_hyperparams: str,
     ) -> int:
-        properties_only = cls._normalize_params(hyperparams)
-        task = cls._get_task_for_update(
-            company=company_id, id=task_id, allow_all_statuses=properties_only
-        )
+        with TimingContext("mongo", "edit_hyperparams"):
+            properties_only = cls._normalize_params(hyperparams)
+            task = get_task_for_update(
+                company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
+            )
 
-        update_cmds = dict()
-        hyperparams = cls._db_dicts_from_list(hyperparams)
-        if replace_hyperparams == ReplaceHyperparams.all:
-            update_cmds["set__hyperparams"] = hyperparams
-        elif replace_hyperparams == ReplaceHyperparams.section:
-            for section, value in hyperparams.items():
-                update_cmds[f"set__hyperparams__{section}"] = value
-        else:
-            for section, section_params in hyperparams.items():
-                for name, value in section_params.items():
-                    update_cmds[f"set__hyperparams__{section}__{name}"] = value
+            update_cmds = dict()
+            hyperparams = cls._db_dicts_from_list(hyperparams)
+            if replace_hyperparams == ReplaceHyperparams.all:
+                update_cmds["set__hyperparams"] = hyperparams
+            elif replace_hyperparams == ReplaceHyperparams.section:
+                for section, value in hyperparams.items():
+                    update_cmds[f"set__hyperparams__{section}"] = value
+            else:
+                for section, section_params in hyperparams.items():
+                    for name, value in section_params.items():
+                        update_cmds[f"set__hyperparams__{section}__{name}"] = value
 
-        return task.update(**update_cmds, last_update=datetime.utcnow())
+            return task.update(**update_cmds, last_update=datetime.utcnow())
 
     @classmethod
     def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
@@ -152,28 +156,29 @@ def get_configuration_names(
         cls, company_id: str, task_ids: Sequence[str]
     ) -> Dict[str, list]:
-        pipeline = [
-            {
-                "$match": {
-                    "company": {"$in": [None, "", company_id]},
-                    "_id": {"$in": task_ids},
-                }
-            },
-            {"$project": {"items": {"$objectToArray": "$configuration"}}},
-            {"$unwind": "$items"},
-            {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
-        ]
-
-        tasks = Task.aggregate(pipeline)
-
-        return {
-            task["_id"]: {
-                "names": sorted(
-                    ParameterKeyEscaper.unescape(name) for name in task["names"]
-                )
-            }
-            for task in tasks
-        }
+        with TimingContext("mongo", "get_configuration_names"):
+            pipeline = [
+                {
+                    "$match": {
+                        "company": {"$in": [None, "", company_id]},
+                        "_id": {"$in": task_ids},
+                    }
+                },
+                {"$project": {"items": {"$objectToArray": "$configuration"}}},
+                {"$unwind": "$items"},
+                {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
+            ]
+
+            tasks = Task.aggregate(pipeline)
+
+            return {
+                task["_id"]: {
+                    "names": sorted(
+                        ParameterKeyEscaper.unescape(name) for name in task["names"]
+                    )
+                }
+                for task in tasks
+            }
 
     @classmethod
     def edit_configuration(
@@ -183,47 +188,32 @@
         configuration: Sequence[Configuration],
         replace_configuration: bool,
     ) -> int:
-        task = cls._get_task_for_update(company=company_id, id=task_id)
+        with TimingContext("mongo", "edit_configuration"):
+            task = get_task_for_update(company_id=company_id, task_id=task_id)
 
-        update_cmds = dict()
-        configuration = {
-            ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
-            for c in configuration
-        }
-        if replace_configuration:
-            update_cmds["set__configuration"] = configuration
-        else:
-            for name, value in configuration.items():
-                update_cmds[f"set__configuration__{name}"] = value
+            update_cmds = dict()
+            configuration = {
+                ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
+                for c in configuration
+            }
+            if replace_configuration:
+                update_cmds["set__configuration"] = configuration
+            else:
+                for name, value in configuration.items():
+                    update_cmds[f"set__configuration__{name}"] = value
 
-        return task.update(**update_cmds, last_update=datetime.utcnow())
+            return task.update(**update_cmds, last_update=datetime.utcnow())
 
     @classmethod
     def delete_configuration(
-        cls, company_id: str, task_id: str, configuration=Sequence[str]
+        cls, company_id: str, task_id: str, configuration: Sequence[str]
    ) -> int:
-        task = cls._get_task_for_update(company=company_id, id=task_id)
+        with TimingContext("mongo", "delete_configuration"):
+            task = get_task_for_update(company_id=company_id, task_id=task_id)
 
-        delete_cmds = {
-            f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
-            for name in set(configuration)
-        }
+            delete_cmds = {
+                f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
+                for name in set(configuration)
+            }
 
-        return task.update(**delete_cmds, last_update=datetime.utcnow())
-
-    @staticmethod
-    def _get_task_for_update(
-        company: str, id: str, allow_all_statuses: bool = False
-    ) -> Task:
-        task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
-        if not task:
-            raise errors.bad_request.InvalidTaskId(id=id)
-
-        if allow_all_statuses:
-            return task
-
-        if task.status != TaskStatus.created:
-            raise errors.bad_request.InvalidTaskStatus(
-                expected=TaskStatus.created, status=task.status
-            )
-        return task
+            return task.update(**delete_cmds, last_update=datetime.utcnow())
diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py
index c244bc1..f8853e7 100644
--- a/apiserver/bll/task/task_bll.py
+++ b/apiserver/bll/task/task_bll.py
@@ -1,20 +1,14 @@
 from collections import OrderedDict
 from datetime import datetime
-from operator import attrgetter
-from random import random
-from time import sleep
-from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
+from typing import Collection, Sequence, Tuple, Any, Optional, Dict
 
 import dpath
-import pymongo.results
 import six
 from mongoengine import Q
 from six import string_types
 
 import apiserver.database.utils as dbutils
-from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
-from apiserver.apimodels.tasks import Artifact as ApiArtifact
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.config import config
 from apiserver.database.errors import translate_errors_context
@@ -28,14 +22,14 @@ from apiserver.database.model.task.task import (
     TaskStatusMessage,
     TaskSystemTags,
     ArtifactModes,
-    Artifact,
     external_task_types,
 )
 from apiserver.database.utils import get_company_or_none_constraint, id as create_id
+from apiserver.es_factory import es_factory
 from apiserver.service_repo import APICall
 from apiserver.timing_context import TimingContext
-from apiserver.utilities.dicts import deep_merge
 from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
+from .artifacts import artifacts_prepare_for_save
 from .param_utils import params_prepare_for_save
 from .utils import ChangeStatusRequest, validate_status_change
 
@@ -181,9 +175,6 @@ class TaskBLL(object):
         execution_overrides: Optional[dict] = None,
         validate_references: bool = False,
     ) -> Task:
-        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
-        execution_dict = task.execution.to_proper_dict() if task.execution else {}
-        execution_model_overriden = False
         params_dict = {
             field: value
             for field, value in (
@@ -192,21 +183,32 @@
             )
             if value is not None
         }
+
+        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
+
+        execution_dict = task.execution.to_proper_dict() if task.execution else {}
+        execution_model_overriden = False
         if execution_overrides:
+            execution_model_overriden = execution_overrides.get("model") is not None
+            artifacts_prepare_for_save({"execution": execution_overrides})
+
             params_dict["execution"] = {}
             for legacy_param in ("parameters", "configuration"):
                 legacy_value = execution_overrides.pop(legacy_param, None)
                 if legacy_value is not None:
                     params_dict["execution"] = legacy_value
-            execution_dict = deep_merge(execution_dict, execution_overrides)
-            execution_model_overriden = execution_overrides.get("model") is not None
+
+            execution_dict.update(execution_overrides)
+
         params_prepare_for_save(params_dict, previous_task=task)
 
         artifacts = execution_dict.get("artifacts")
         if artifacts:
-            execution_dict["artifacts"] = [
-                a for a in artifacts if a.get("mode") != ArtifactModes.output
-            ]
+            execution_dict["artifacts"] = {
+                k: a
+                for k, a in artifacts.items()
+                if a.get("mode") != ArtifactModes.output
+            }
 
         now = datetime.utcnow()
         with translate_errors_context():
@@ -543,97 +545,6 @@ class TaskBLL(object):
             force=force,
         ).execute()
 
-    @classmethod
-    def add_or_update_artifacts(
-        cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
-    ) -> Tuple[List[str], List[str]]:
-        key = attrgetter("key", "mode")
-
-        if not artifacts:
-            return [], []
-
-        with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
-            artifacts: List[Artifact] = [
-                Artifact(**artifact.to_struct()) for artifact in artifacts
-            ]
-
-            attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
-
-            for retry in range(attempts):
-                task = cls.get_task_with_access(
-                    task_id, company_id=company_id, requires_write_access=True
-                )
-
-                current = list(map(key, task.execution.artifacts))
-                updated = [a for a in artifacts if key(a) in current]
-                added = [a for a in artifacts if a not in updated]
-
-                filter = {"_id": task_id, "company": company_id}
-                update = {}
-                array_filters = None
-                if current:
-                    filter["execution.artifacts"] = {
-                        "$size": len(current),
-                        "$all": [
-                            *(
-                                {"$elemMatch": {"key": key, "mode": mode}}
-                                for key, mode in current
-                            )
-                        ],
-                    }
-                else:
-                    filter["$or"] = [
-                        {"execution.artifacts": {"$exists": False}},
-                        {"execution.artifacts": {"$size": 0}},
-                    ]
-
-                if added:
-                    update["$push"] = {
-                        "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
-                    }
-                if updated:
-                    update["$set"] = {
-                        f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
-                        for index, artifact in enumerate(updated)
-                    }
-                    array_filters = [
-                        {
-                            f"artifact{index}.key": artifact.key,
-                            f"artifact{index}.mode": artifact.mode,
-                        }
-                        for index, artifact in enumerate(updated)
-                    ]
-
-                if not update:
-                    return [], []
-
-                result: pymongo.results.UpdateResult = Task._get_collection().update_one(
-                    filter=filter,
-                    update=update,
-                    array_filters=array_filters,
-                    upsert=False,
-                )
-
-                if result.matched_count >= 1:
-                    break
-
-                wait_msec = random() * int(
-                    config.get("services.tasks.artifacts.update_retry_msec", 500)
-                )
-
-                log.warning(
-                    f"Failed to update artifacts for task {task_id} (updated by another party),"
-                    f" retrying {retry+1}/{attempts} in {wait_msec}ms"
-                )
-
-                sleep(wait_msec / 1000)
-            else:
-                raise errors.server_error.UpdateFailed(
-                    "task artifacts updated by another party"
-                )
-
-            return [a.key for a in added], [a.key for a in updated]
-
     @staticmethod
     def get_aggregated_project_parameters(
         company_id,
diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py
index fb3464f..6e5cf31 100644
--- a/apiserver/bll/task/utils.py
+++ b/apiserver/bll/task/utils.py
@@ -171,3 +171,23 @@ def split_by(
         [item for cond, item in applied if cond],
         [item for cond, item in applied if not cond],
     )
+
+
+def get_task_for_update(
+    company_id: str, task_id: str, allow_all_statuses: bool = False
+) -> Task:
+    """
+    Load only the task id and status, and return the task only if it is updatable
+    (status == 'created'), unless allow_all_statuses is set
+    """
+    task = Task.get_for_writing(company=company_id, id=task_id, _only=("id", "status"))
+    if not task:
+        raise errors.bad_request.InvalidTaskId(id=task_id)
+
+    if allow_all_statuses:
+        return task
+
+    if task.status != TaskStatus.created:
+        raise errors.bad_request.InvalidTaskStatus(
+            expected=TaskStatus.created, status=task.status
+        )
+    return task
diff --git a/apiserver/config/default/services/tasks.conf b/apiserver/config/default/services/tasks.conf
index 3650344..1ecc314 100644
--- a/apiserver/config/default/services/tasks.conf
+++ b/apiserver/config/default/services/tasks.conf
@@ -8,9 +8,4 @@ non_responsive_tasks_watchdog {
     watch_interval_sec: 900
 }
 
-artifacts {
-    update_attempts: 10
-    update_retry_msec: 500
-}
-
 multi_task_histogram_limit: 100
diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py
index 63f66d5..0ec3142 100644
--- a/apiserver/database/model/task/task.py
+++ b/apiserver/database/model/task/task.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
 from mongoengine import (
     StringField,
     EmbeddedDocumentField,
@@ -14,7 +16,6 @@ from apiserver.database.fields import (
     SafeMapField,
     SafeDictField,
     UnionField,
-    EmbeddedDocumentSortedListField,
     SafeSortedListField,
 )
 from apiserver.database.model import AttributedDocument
@@ -72,10 +73,13 @@ class ArtifactModes:
     output = "output"
 
 
+DEFAULT_ARTIFACT_MODE = ArtifactModes.output
+
+
 class Artifact(EmbeddedDocument):
     key = StringField(required=True)
     type = StringField(required=True)
-    mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
+    mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
     uri = StringField()
     hash = StringField()
     content_size = LongField()
@@ -107,7 +111,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
     model_desc = SafeMapField(StringField(default=""))
     model_labels = ModelLabels()
     framework = StringField()
-    artifacts = EmbeddedDocumentSortedListField(Artifact)
+    artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
     docker_cmd = StringField()
     queue = StringField()
     """ Queue ID where task was queued """
diff --git a/apiserver/mongo/migrations/0.17.0.py b/apiserver/mongo/migrations/0.17.0.py
new file mode 100644
index 0000000..28d977d
--- /dev/null
+++ b/apiserver/mongo/migrations/0.17.0.py
@@ -0,0 +1,19 @@
+from pymongo.database import Database, Collection
+
+from apiserver.bll.task.artifacts import get_artifact_id
+from apiserver.utilities.dicts import nested_get
+
+
+def migrate_backend(db: Database):
+    collection: Collection = db["task"]
+    artifacts_field = "execution.artifacts"
+    query = {artifacts_field: {"$type": 4}}
+    for doc in collection.find(filter=query, projection=(artifacts_field,)):
+        artifacts = nested_get(doc, artifacts_field.split("."))
+        if not isinstance(artifacts, list):
+            continue
+
+        new_artifacts = {get_artifact_id(a): a for a in artifacts}
+        collection.update_one(
+            {"_id": doc["_id"]}, {"$set": {artifacts_field: new_artifacts}}
+        )
diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf
index d0ce951..9ae47c7 100644
--- a/apiserver/schema/services/tasks.conf
+++ b/apiserver/schema/services/tasks.conf
@@ -137,6 +137,14 @@ _definitions {
             }
         }
     }
+    artifact_mode_enum {
+        type: string
+        enum: [
+            input
+            output
+        ]
+        default: output
+    }
     artifact {
         type: object
         required: [key, type]
@@ -151,12 +159,7 @@ _definitions {
             }
             mode {
                 description: "System defined input/output indication"
-                type: string
-                enum: [
-                    input
-                    output
-                ]
-                default: output
+                "$ref": "#/definitions/artifact_mode_enum"
             }
             uri {
                 description: "Raw data location"
@@ -190,6 +193,20 @@ _definitions {
             }
         }
    }
+    artifact_id {
+        type: object
+        required: [key]
+        properties {
+            key {
+                description: "Entry key"
+                type: string
+            }
+            mode {
+                description: "System defined input/output indication"
+                "$ref": "#/definitions/artifact_mode_enum"
+            }
+        }
+    }
     execution {
         type: object
         properties {
@@ -1552,11 +1569,11 @@ ping {
 }
 add_or_update_artifacts {
-    "2.6" {
-        description: """ Update an existing artifact (search by key/mode) or add a new one """
+    "2.10" {
+        description: """Update existing artifacts (search by key/mode) and add new ones"""
         request {
             type: object
-            required: [ task, artifacts ]
+            required: [task, artifacts]
             properties {
                 task {
                     description: "Task ID"
                     type: string
                 }
                 artifacts {
                     description: "Artifacts to add or update"
                     type: array
-                    items { "$ref": "#/definitions/artifact" }
+                    items {"$ref": "#/definitions/artifact"}
                 }
             }
         }
         response {
             type: object
             properties {
-                added {
-                    description: "Keys of artifacts added"
-                    type: array
-                    items { type: string }
-                }
                 updated {
-                    description: "Keys of artifacts updated"
+                    description: "Number of tasks updated (0 or 1)"
+                    type: integer
+                }
+            }
+        }
+    }
+}
+delete_artifacts {
+    "2.10" {
+        description: """Delete existing artifacts (search by key/mode)"""
+        request {
+            type: object
+            required: [task, artifacts]
+            properties {
+                task {
+                    description: "Task ID"
+                    type: string
+                }
+                artifacts {
+                    description: "Artifacts to delete"
                     type: array
-                    items { type: string }
+                    items {"$ref": "#/definitions/artifact_id"}
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                deleted {
+                    description: "Number of tasks updated (0 or 1)"
+                    type: integer
                 }
             }
         }
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index c3f1c8a..b311f91 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -29,7 +29,6 @@ from apiserver.apimodels.tasks import (
     DequeueResponse,
     CloneRequest,
     AddOrUpdateArtifactsRequest,
-    AddOrUpdateArtifactsResponse,
     GetTypesRequest,
     ResetRequest,
     GetHyperParamsRequest,
@@ -39,6 +38,7 @@ from apiserver.apimodels.tasks import (
     EditConfigurationRequest,
     DeleteConfigurationRequest,
     GetConfigurationNamesRequest,
+    DeleteArtifactsRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.organization import OrgBLL, Tags
@@ -49,6 +49,11 @@ from apiserver.bll.task import (
     update_project_time,
     split_by,
 )
+from apiserver.bll.task.artifacts import (
+    artifacts_prepare_for_save,
+    artifacts_unprepare_from_saved,
+    Artifacts,
+)
 from apiserver.bll.task.hyperparams import HyperParams
 from apiserver.bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
 from apiserver.bll.task.param_utils import (
@@ -279,6 +284,7 @@ create_fields = {
 def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
     conform_tag_fields(call, fields, validate=True)
     params_prepare_for_save(fields, previous_task=previous_task)
+    artifacts_prepare_for_save(fields)
 
     # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
     for field in task_script_fields:
@@ -301,10 +307,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
     conform_output_tags(call, tasks_data)
 
     for data in tasks_data:
+        need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
         params_unprepare_from_saved(
-            fields=data,
-            copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
+            fields=data, copy_to_legacy=need_legacy_params,
         )
+        artifacts_unprepare_from_saved(fields=data)
 
 
 def prepare_create_fields(
@@ -846,10 +853,15 @@ def reset(call: APICall, company_id, request: ResetRequest):
             set__execution=Execution(), unset__script=1,
         )
     else:
-        updates.update(
-            unset__execution__queue=1,
-            __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
-        )
+        updates.update(unset__execution__queue=1)
+        if task.execution and task.execution.artifacts:
+            updates.update(
+                set__execution__artifacts={
+                    key: artifact
+                    for key, artifact in task.execution.artifacts.items()
+                    if artifact.mode == "input"
+                }
+            )
 
     res = ResetResponse(
         **ChangeStatusRequest(
@@ -1090,17 +1102,34 @@ def ping(_, company_id, request: PingRequest):
 
 
 @endpoint(
     "tasks.add_or_update_artifacts",
-    min_version="2.6",
+    min_version="2.10",
     request_data_model=AddOrUpdateArtifactsRequest,
-    response_data_model=AddOrUpdateArtifactsResponse,
 )
 def add_or_update_artifacts(
     call: APICall, company_id, request: AddOrUpdateArtifactsRequest
 ):
-    added, updated = TaskBLL.add_or_update_artifacts(
-        task_id=request.task, company_id=company_id, artifacts=request.artifacts
-    )
-    call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
+    with translate_errors_context():
+        call.result.data = {
+            "updated": Artifacts.add_or_update_artifacts(
+                company_id=company_id, task_id=request.task, artifacts=request.artifacts
+            )
+        }
+
+
+@endpoint(
+    "tasks.delete_artifacts",
+    min_version="2.10",
+    request_data_model=DeleteArtifactsRequest,
+)
+def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
+    with translate_errors_context():
+        call.result.data = {
+            "deleted": Artifacts.delete_artifacts(
+                company_id=company_id,
+                task_id=request.task,
+                artifact_ids=request.artifacts,
+            )
+        }
 
 
 @endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py
new file mode 100644
index 0000000..7dc0e7d
--- /dev/null
+++ b/apiserver/tests/automated/test_task_artifacts.py
@@ -0,0 +1,100 @@
+from operator import itemgetter
+from typing import Sequence
+
+from apiserver.tests.automated import TestService
+
+
+class TestTasksArtifacts(TestService):
+    def setUp(self, **kwargs):
+        super().setUp(version="2.10")
+
+    def new_task(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            type="testing",
+            name="test artifacts",
+            input=dict(view=dict()),
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("tasks", **kwargs)
+
+    def test_artifacts_set_get(self):
+        artifacts = [
+            dict(key="a", type="str", uri="test1"),
+            dict(key="b", type="int", uri="test2"),
+        ]
+
+        # test create/get and get_all
+        task = self.new_task(execution={"artifacts": artifacts})
+        res = self.api.tasks.get_by_id(task=task).task
+        self._assertTaskArtifacts(artifacts, res)
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test edit
+        artifacts = [
+            dict(key="bb", type="str", uri="test1", mode="output"),
+            dict(key="aa", type="int", uri="test2", mode="input"),
+        ]
+        self.api.tasks.edit(task=task, execution={"artifacts": artifacts})
+        res = self.api.tasks.get_by_id(task=task).task
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test clone
+        task2 = self.api.tasks.clone(task=task).id
+        res = self.api.tasks.get_by_id(task=task2).task
+        self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
+
+        new_artifacts = [
+            dict(key="x", type="str", uri="x_test"),
+            dict(key="y", type="int", uri="y_test"),
+            dict(key="z", type="int", uri="y_test"),
+        ]
+        new_task = self.api.tasks.clone(
+            task=task, execution_overrides={"artifacts": new_artifacts}
+        ).id
+        res = self.api.tasks.get_by_id(task=new_task).task
+        self._assertTaskArtifacts(new_artifacts, res)
+
+    def test_artifacts_edit_delete(self):
+        artifacts = [
+            dict(key="a", type="str", uri="test1"),
+            dict(key="b", type="int", uri="test2"),
+        ]
+        task = self.new_task(execution={"artifacts": artifacts})
+
+        # test add_or_update
+        edit = [
+            dict(key="a", type="str", uri="hello"),
+            dict(key="c", type="int", uri="world"),
+        ]
+        res = self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
+        artifacts = self._update_source(artifacts, edit)
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts, res)
+
+        # test delete
+        self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
+        res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
+        self._assertTaskArtifacts(artifacts[:-1], res)
+
+    def _update_source(self, source: Sequence[dict], update: Sequence[dict]):
+        dict1 = {s["key"]: s for s in source}
+        dict2 = {u["key"]: u for u in update}
+        res = {
+            k: v if k not in dict2 else dict2[k]
+            for k, v in dict1.items()
+        }
+        res.update({k: v for k, v in dict2.items() if k not in res})
+        return list(res.values())
+
+    def _assertTaskArtifacts(self, artifacts: Sequence[dict], task):
+        task_artifacts: dict = task.execution.artifacts
+        self.assertEqual(len(artifacts), len(task_artifacts))
+
+        for expected, actual in zip(
+            sorted(artifacts, key=itemgetter("key", "type")), task_artifacts
+        ):
+            self.assertEqual(
+                expected, {k: v for k, v in actual.items() if k in expected}
+            )
diff --git a/apiserver/utilities/dicts.py b/apiserver/utilities/dicts.py
index 3790063..b6c936b 100644
--- a/apiserver/utilities/dicts.py
+++ b/apiserver/utilities/dicts.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Tuple, Any
+from typing import Sequence, Tuple, Any, Union, Callable, Optional
 
 
 def flatten_nested_items(
@@ -33,3 +33,52 @@ def deep_merge(source: dict, override: dict) -> dict:
             source[key] = value
 
     return source
+
+
+def nested_get(
+    dictionary: dict,
+    path: Union[Sequence[str], str],
+    default: Optional[Union[Any, Callable]] = None,
+) -> Any:
+    if isinstance(path, str):
+        path = [path]
+
+    node = dictionary
+    for key in path:
+        if key not in node:
+            if callable(default):
+                return default()
+            return default
+        node = node.get(key)
+
+    return node
+
+
+def nested_delete(dictionary: dict, path: Union[Sequence[str], str]) -> bool:
+    """
+    Return 'True' if the element was deleted
+    """
+    if isinstance(path, str):
+        path = [path]
+
+    *parent_path, last_key = path
+    parent = nested_get(dictionary, parent_path)
+    if not parent or last_key not in parent:
+        return False
+
+    del parent[last_key]
+    return True
+
+
+def nested_set(dictionary: dict, path: Union[Sequence[str], str], value: Any):
+    if isinstance(path, str):
+        path = [path]
+
+    *parent_path, last_key = path
+    node = dictionary
+    for key in parent_path:
+        if key not in node:
+            node[key] = {}
+        node = node.get(key)
+
+    node[last_key] = value
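
Reviewer note (illustration only, not part of the diff): a minimal, self-contained sketch of how the new artifact storage scheme behaves, assuming the helpers introduced above; the sample key, uri and variable names below are made up for the example.

from hashlib import md5

DEFAULT_ARTIFACT_MODE = "output"  # mirrors ArtifactModes.output in the model

def get_artifact_id(artifact: dict) -> str:
    # Same scheme as apiserver/bll/task/artifacts.py: md5 of the key plus the mode,
    # so the resulting map key is safe to use as a mongo field name.
    key_hash = md5(artifact["key"].encode()).hexdigest()
    return f"{key_hash}_{artifact.get('mode', DEFAULT_ARTIFACT_MODE)}"

artifact = {"key": "train_stats", "type": "str", "uri": "s3://bucket/stats.json"}
fields = {"execution": {"artifacts": [artifact]}}

# prepare_for_save: list -> mapping keyed by md5(key)_mode (the shape stored in mongo)
fields["execution"]["artifacts"] = {
    get_artifact_id(a): a for a in fields["execution"]["artifacts"]
}

# unprepare_from_saved: mapping -> list sorted by (key, mode) (the shape returned by the API)
restored = sorted(
    fields["execution"]["artifacts"].values(),
    key=lambda a: (a["key"], a.get("mode", DEFAULT_ARTIFACT_MODE)),
)
print(restored[0]["key"])  # -> "train_stats"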