Add support for atomic add/update of task artifacts

allegroai 2019-12-24 17:57:26 +02:00
parent 52529d3c55
commit 1e4756aa1d
7 changed files with 164 additions and 8 deletions

View File

@@ -121,6 +121,7 @@ _error_codes = {
100: ('data_error', 'general data error'),
101: ('inconsistent_data', 'inconsistent data encountered in document'),
102: ('database_unavailable', 'database is temporarily unavailable'),
110: ('update_failed', 'update failed'),
# Index-related issues
201: ('missing_index', 'missing internal index'),

View File

@@ -1,15 +1,19 @@
import re
from collections import OrderedDict
from datetime import datetime, timedelta
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any
from typing import Collection, Sequence, Tuple, Any, Optional, List
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@@ -28,6 +32,9 @@ from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
class TaskBLL(object):
threads = ThreadsManager("TaskBLL")
@@ -373,7 +380,7 @@ class TaskBLL(object):
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
@@ -411,6 +418,97 @@ class TaskBLL(object):
force=force,
).execute()
@classmethod
def add_or_update_artifacts(
cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
key = attrgetter("key", "mode")
if not artifacts:
return [], []
with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
artifacts: List[Artifact] = [
Artifact(**artifact.to_struct()) for artifact in artifacts
]
attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
for retry in range(attempts):
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
current = list(map(key, task.execution.artifacts))
updated = [a for a in artifacts if key(a) in current]
added = [a for a in artifacts if a not in updated]
filter = {"_id": task_id, "company": company_id}
update = {}
array_filters = None
if current:
filter["execution.artifacts"] = {
"$size": len(current),
"$all": [
*(
{"$elemMatch": {"key": key, "mode": mode}}
for key, mode in current
)
],
}
else:
filter["$or"] = [
{"execution.artifacts": {"$exists": False}},
{"execution.artifacts": {"$size": 0}},
]
if added:
update["$push"] = {
"execution.artifacts": {"$each": [a.to_mongo() for a in added]}
}
if updated:
update["$set"] = {
f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
for index, artifact in enumerate(updated)
}
array_filters = [
{
f"artifact{index}.key": artifact.key,
f"artifact{index}.mode": artifact.mode,
}
for index, artifact in enumerate(updated)
]
if not update:
return [], []
result: pymongo.results.UpdateResult = Task._get_collection().update_one(
filter=filter,
update=update,
array_filters=array_filters,
upsert=False,
)
if result.matched_count >= 1:
break
wait_msec = random() * int(
config.get("services.tasks.artifacts.update_retry_msec", 500)
)
log.warning(
f"Failed to update artifacts for task {task_id} (updated by another party),"
f" retrying {retry+1}/{attempts} in {wait_msec}ms"
)
sleep(wait_msec / 1000)
else:
raise errors.server_error.UpdateFailed(
"task artifacts updated by another party"
)
return [a.key for a in added], [a.key for a in updated]
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start_non_responsive_tasks_watchdog(cls):
@@ -502,10 +600,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(
Task.aggregate(*pipeline),
None,
)
result = next(Task.aggregate(*pipeline), None)
total = 0
remaining = 0
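The atomicity here is optimistic concurrency control rather than a transaction: add_or_update_artifacts reads the task's artifact list, then issues a single conditional update_one whose filter only matches if the list still contains exactly the (key, mode) pairs that were read ($size plus one $elemMatch per pair). If a concurrent writer changed the list in the meantime, nothing matches, and the loop re-reads and retries after a random backoff, up to services.tasks.artifacts.update_attempts times. Below is a minimal standalone sketch of the same guard-and-retry idea against a plain pymongo collection; it simplifies the write side to one guarded $set of the whole array instead of the $push/arrayFilters combination above, and the connection string, collection name and artifact fields are illustrative assumptions, not the server's schema.

from operator import itemgetter
from random import random
from time import sleep

from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")  # assumption: a local MongoDB instance
tasks = client["demo"]["task"]                     # illustrative collection, not the server's


def add_or_update_artifacts(task_id, new_artifacts, attempts=10, retry_msec=500):
    """Optimistically add/update entries of an embedded `artifacts` array."""
    if not new_artifacts:
        return [], []
    ident = itemgetter("key", "mode")
    for _ in range(attempts):
        doc = tasks.find_one({"_id": task_id}, {"artifacts": 1}) or {}
        current = doc.get("artifacts", [])
        current_ids = [ident(a) for a in current]

        updated = [a for a in new_artifacts if ident(a) in current_ids]
        added = [a for a in new_artifacts if ident(a) not in current_ids]

        # Merge client-side: keep untouched entries, replace matching (key, mode)
        # pairs, and append the new ones at the end.
        replacements = {ident(a): a for a in updated}
        merged = [replacements.get(ident(a), a) for a in current] + added

        # Guard: the update only matches if the array still holds exactly the
        # (key, mode) pairs that were just read (same length, every pair present).
        flt = {"_id": task_id}
        if current_ids:
            flt["artifacts"] = {
                "$size": len(current_ids),
                "$all": [{"$elemMatch": {"key": k, "mode": m}} for k, m in current_ids],
            }
        else:
            flt["$or"] = [
                {"artifacts": {"$exists": False}},
                {"artifacts": {"$size": 0}},
            ]

        result = tasks.update_one(flt, {"$set": {"artifacts": merged}})
        if result.matched_count:
            # No concurrent writer touched the array; the guarded write went through.
            return [a["key"] for a in added], [a["key"] for a in updated]
        sleep(random() * retry_msec / 1000)  # random backoff before re-reading and retrying

    raise RuntimeError("task artifacts updated by another party")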

View File

@@ -5,3 +5,8 @@ non_responsive_tasks_watchdog {
# Watchdog will sleep for this number of seconds after each cycle
watch_interval_sec: 900
}
artifacts {
update_attempts: 10
update_retry_msec: 500
}
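These keys are read by the retry loop in TaskBLL.add_or_update_artifacts with the same values as hard-coded fallbacks, so behaviour is unchanged when the section is absent. For a rough sense of the worst case under constant contention:

# Back-of-the-envelope bound on time spent sleeping before update_failed (110) is raised
attempts = 10      # services.tasks.artifacts.update_attempts
retry_msec = 500   # services.tasks.artifacts.update_retry_msec
max_backoff_sec = attempts * retry_msec / 1000  # each sleep is uniform in [0, 0.5) s, so at most ~5 s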

View File

@@ -18,6 +18,7 @@ from database.fields import (
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.base import ProperDictMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
@@ -78,7 +79,7 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class Execution(EmbeddedDocument):
class Execution(EmbeddedDocument, ProperDictMixin):
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")

View File

@@ -1304,4 +1304,40 @@ ping {
additionalProperties: false
}
}
}
add_or_update_artifacts {
"2.6" {
description: """ Update an existing artifact (search by key/mode) or add a new one """
request {
type: object
required: [ task, artifacts ]
properties {
task {
description: "Task ID"
type: string
}
artifacts {
description: "Artifacts to add or update"
type: array
items { "$ref": "#/definitions/artifact" }
}
}
}
response {
type: object
properties {
added {
description: "Keys of artifacts added"
type: array
items { type: string }
}
updated {
description: "Keys of artifacts updated"
type: array
items { type: string }
}
}
}
}
}
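For illustration, a request/response pair that fits the schema above. The artifact items follow the shared #/definitions/artifact definition, which is not part of this diff; only the key and mode fields (the ones the BLL matches on) are shown, and all values are made up.

# Hypothetical tasks.add_or_update_artifacts payload and result (illustrative values only)
request = {
    "task": "4d9d9c26c4e04c5ab6ff4262d1448bec",
    "artifacts": [
        {"key": "training_stats", "mode": "output"},  # replaced if a (key, mode) match exists
        {"key": "source_dataset", "mode": "input"},   # otherwise appended as a new artifact
    ],
}
response = {
    "added": ["source_dataset"],    # keys of artifacts added
    "updated": ["training_stats"],  # keys of artifacts updated
}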

View File

@@ -34,7 +34,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.4")
_max_version = PartialVersion("2.6")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -27,6 +27,9 @@ from apimodels.tasks import (
EnqueueRequest,
EnqueueResponse,
DequeueResponse,
CloneRequest,
AddOrUpdateArtifactsRequest,
AddOrUpdateArtifactsResponse,
)
from bll.event import EventBLL
from bll.queue import QueueBLL
@@ -837,3 +840,18 @@ def ping(_, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
)
@endpoint(
"tasks.add_or_update_artifacts",
min_version="2.6",
request_data_model=AddOrUpdateArtifactsRequest,
response_data_model=AddOrUpdateArtifactsResponse,
)
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
added, updated = TaskBLL.add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts
)
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
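A sketch of calling the new endpoint over HTTP using the usual <service>.<action> path convention. The base URL, version prefix, auth header and response envelope below are assumptions for illustration, not details established by this commit.

import requests

resp = requests.post(
    "http://localhost:8008/v2.6/tasks.add_or_update_artifacts",  # assumed host/port and version prefix
    json={
        "task": "4d9d9c26c4e04c5ab6ff4262d1448bec",
        "artifacts": [{"key": "training_stats", "mode": "output"}],
    },
    headers={"Authorization": "Bearer <token>"},  # placeholder authentication
)
resp.raise_for_status()
payload = resp.json().get("data", {})  # assumed response envelope; adjust to the actual API wrapper
print(payload.get("added"), payload.get("updated"))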