From 1e4756aa1d4b2335903dbe5c2bde9881db98025c Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Dec 2019 17:57:26 +0200 Subject: [PATCH] Add support for atomic add/update of task artifacts --- server/apierrors/__init__.py | 1 + server/bll/task/task_bll.py | 107 ++++++++++++++++++++-- server/config/default/services/tasks.conf | 5 + server/database/model/task/task.py | 3 +- server/schema/services/tasks.conf | 36 ++++++++ server/service_repo/service_repo.py | 2 +- server/services/tasks.py | 18 ++++ 7 files changed, 164 insertions(+), 8 deletions(-) diff --git a/server/apierrors/__init__.py b/server/apierrors/__init__.py index 64af157..8064181 100644 --- a/server/apierrors/__init__.py +++ b/server/apierrors/__init__.py @@ -121,6 +121,7 @@ _error_codes = { 100: ('data_error', 'general data error'), 101: ('inconsistent_data', 'inconsistent data encountered in document'), 102: ('database_unavailable', 'database is temporarily unavailable'), + 110: ('update_failed', 'update failed'), # Index-related issues 201: ('missing_index', 'missing internal index'), diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index cc4790e..921946b 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -1,15 +1,19 @@ import re from collections import OrderedDict from datetime import datetime, timedelta +from operator import attrgetter +from random import random from time import sleep -from typing import Collection, Sequence, Tuple, Any +from typing import Collection, Sequence, Tuple, Any, Optional, List +import pymongo.results import six from mongoengine import Q from six import string_types import es_factory from apierrors import errors +from apimodels.tasks import Artifact as ApiArtifact from config import config from database.errors import translate_errors_context from database.model.model import Model @@ -28,6 +32,9 @@ from utilities.threads_manager import ThreadsManager from .utils import ChangeStatusRequest, validate_status_change +log = config.logger(__file__) + + class TaskBLL(object): threads = ThreadsManager("TaskBLL") @@ -373,7 +380,7 @@ class TaskBLL(object): :return: updated task fields """ - task = TaskBLL.get_task_with_access( + task = cls.get_task_with_access( task_id, company_id=company_id, only=( @@ -411,6 +418,97 @@ class TaskBLL(object): force=force, ).execute() + @classmethod + def add_or_update_artifacts( + cls, task_id: str, company_id: str, artifacts: List[ApiArtifact] + ) -> Tuple[List[str], List[str]]: + key = attrgetter("key", "mode") + + if not artifacts: + return [], [] + + with translate_errors_context(), TimingContext("mongo", "update_artifacts"): + artifacts: List[Artifact] = [ + Artifact(**artifact.to_struct()) for artifact in artifacts + ] + + attempts = int(config.get("services.tasks.artifacts.update_attempts", 10)) + + for retry in range(attempts): + task = cls.get_task_with_access( + task_id, company_id=company_id, requires_write_access=True + ) + + current = list(map(key, task.execution.artifacts)) + updated = [a for a in artifacts if key(a) in current] + added = [a for a in artifacts if a not in updated] + + filter = {"_id": task_id, "company": company_id} + update = {} + array_filters = None + if current: + filter["execution.artifacts"] = { + "$size": len(current), + "$all": [ + *( + {"$elemMatch": {"key": key, "mode": mode}} + for key, mode in current + ) + ], + } + else: + filter["$or"] = [ + {"execution.artifacts": {"$exists": False}}, + {"execution.artifacts": {"$size": 0}}, + ] + + if added: + update["$push"] = { + "execution.artifacts": {"$each": [a.to_mongo() for a in added]} + } + if updated: + update["$set"] = { + f"execution.artifacts.$[artifact{index}]": artifact.to_mongo() + for index, artifact in enumerate(updated) + } + array_filters = [ + { + f"artifact{index}.key": artifact.key, + f"artifact{index}.mode": artifact.mode, + } + for index, artifact in enumerate(updated) + ] + + if not update: + return [], [] + + result: pymongo.results.UpdateResult = Task._get_collection().update_one( + filter=filter, + update=update, + array_filters=array_filters, + upsert=False, + ) + + if result.matched_count >= 1: + break + + wait_msec = random() * int( + config.get("services.tasks.artifacts.update_retry_msec", 500) + ) + + log.warning( + f"Failed to update artifacts for task {task_id} (updated by another party)," + f" retrying {retry+1}/{attempts} in {wait_msec}ms" + ) + + sleep(wait_msec / 1000) + else: + raise errors.server_error.UpdateFailed( + "task artifacts updated by another party" + ) + + return [a.key for a in added], [a.key for a in updated] + @classmethod @threads.register("non_responsive_tasks_watchdog", daemon=True) def start_non_responsive_tasks_watchdog(cls): @@ -502,10 +600,7 @@ class TaskBLL(object): ] with translate_errors_context(): - result = next( - Task.aggregate(*pipeline), - None, - ) + result = next(Task.aggregate(*pipeline), None) total = 0 remaining = 0 diff --git a/server/config/default/services/tasks.conf b/server/config/default/services/tasks.conf index 9d25c1d..2e25c83 100644 --- a/server/config/default/services/tasks.conf +++ b/server/config/default/services/tasks.conf @@ -5,3 +5,8 @@ non_responsive_tasks_watchdog { # Watchdog will sleep for this number of seconds after each cycle watch_interval_sec: 900 } + +artifacts { + update_attempts: 10 + update_retry_msec: 500 +} \ No newline at end of file diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index d8854db..5af3440 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -18,6 +18,7 @@ from database.fields import ( SafeSortedListField, ) from database.model import AttributedDocument +from database.model.base import ProperDictMixin from database.model.model_labels import ModelLabels from database.model.project import Project from database.utils import get_options @@ -78,7 +79,7 @@ class Artifact(EmbeddedDocument): display_data = SafeSortedListField(ListField(UnionField((int, float, str)))) -class Execution(EmbeddedDocument): +class Execution(EmbeddedDocument, ProperDictMixin): test_split = IntField(default=0) parameters = SafeDictField(default=dict) model = StringField(reference_field="Model") diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index 71f3c72..2b1682a 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -1304,4 +1304,40 @@ ping { additionalProperties: false } } +} + +add_or_update_artifacts { + "2.6" { + description: """ Update an existing artifact (search by key/mode) or add a new one """ + request { + type: object + required: [ task, artifacts ] + properties { + task { + description: "Task ID" + type: string + } + artifacts { + description: "Artifacts to add or update" + type: array + items { "$ref": "#/definitions/artifact" } + } + } + } + response { + type: object + properties { + added { + description: "Keys of artifacts added" + type: array + items { type: string } + } + updated { + description: "Keys of artifacts updated" + type: array + items { type: string } + } + } + } + } } \ No newline at end of file diff --git a/server/service_repo/service_repo.py b/server/service_repo/service_repo.py index 593ec8e..78aabe4 100644 --- a/server/service_repo/service_repo.py +++ b/server/service_repo/service_repo.py @@ -34,7 +34,7 @@ class ServiceRepo(object): """If the check is set, parsing will fail for endpoint request with the version that is grater than the current maximum """ - _max_version = PartialVersion("2.4") + _max_version = PartialVersion("2.6") """ Maximum version number (the highest min_version value across all endpoints) """ _endpoint_exp = ( diff --git a/server/services/tasks.py b/server/services/tasks.py index ddaa21f..f0418be 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -27,6 +27,9 @@ from apimodels.tasks import ( EnqueueRequest, EnqueueResponse, DequeueResponse, + CloneRequest, + AddOrUpdateArtifactsRequest, + AddOrUpdateArtifactsResponse, ) from bll.event import EventBLL from bll.queue import QueueBLL @@ -837,3 +840,18 @@ def ping(_, company_id, request: PingRequest): TaskBLL.set_last_update( task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow() ) + + +@endpoint( + "tasks.add_or_update_artifacts", + min_version="2.6", + request_data_model=AddOrUpdateArtifactsRequest, + response_data_model=AddOrUpdateArtifactsResponse, +) +def add_or_update_artifacts( + call: APICall, company_id, request: AddOrUpdateArtifactsRequest +): + added, updated = TaskBLL.add_or_update_artifacts( + task_id=request.task, company_id=company_id, artifacts=request.artifacts + ) + call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)