Add support for atomic add/update of task artifacts

This commit is contained in:
allegroai
2019-12-24 17:57:26 +02:00
parent 52529d3c55
commit 1e4756aa1d
7 changed files with 164 additions and 8 deletions

View File

@@ -1,15 +1,19 @@
import re
from collections import OrderedDict
from datetime import datetime, timedelta
from operator import attrgetter
from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any
from typing import Collection, Sequence, Tuple, Any, Optional, List
import pymongo.results
import six
from mongoengine import Q
from six import string_types
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@@ -28,6 +32,9 @@ from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
class TaskBLL(object):
threads = ThreadsManager("TaskBLL")
@@ -373,7 +380,7 @@ class TaskBLL(object):
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task = cls.get_task_with_access(
task_id,
company_id=company_id,
only=(
@@ -411,6 +418,97 @@ class TaskBLL(object):
force=force,
).execute()
@classmethod
def add_or_update_artifacts(
    cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
    """
    Add new artifacts to a task and replace existing ones in a single
    conditional update.

    An artifact is identified by its (key, mode) pair: artifacts whose pair
    is already stored on the task are replaced, all others are appended.
    The update is guarded by a filter describing the artifacts list exactly
    as it was read (same size, same (key, mode) pairs). If the list changed
    in the meantime the filter matches nothing and the whole
    read-compute-update cycle is retried, up to
    ``services.tasks.artifacts.update_attempts`` times (default 10) with a
    random delay of up to ``services.tasks.artifacts.update_retry_msec``
    milliseconds (default 500) between attempts.

    :param task_id: ID of the task to update
    :param company_id: company ID, used both for the access check and the
        update filter
    :param artifacts: API artifact objects to add or update
    :return: tuple of (keys of added artifacts, keys of updated artifacts)
    :raise errors.server_error.UpdateFailed: if no attempt managed to match
        the task's artifacts list
    """
    if not artifacts:
        return [], []

    identity = attrgetter("key", "mode")

    with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
        documents: List[Artifact] = [
            Artifact(**artifact.to_struct()) for artifact in artifacts
        ]
        attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))

        for retry in range(attempts):
            # Re-read the task on every attempt so the filter below reflects
            # the latest stored state
            task = cls.get_task_with_access(
                task_id, company_id=company_id, requires_write_access=True
            )

            # Tolerate a task without an execution section or artifacts list
            # (the empty-list filter below already anticipates a missing field)
            execution = task.execution
            stored = (execution.artifacts if execution is not None else None) or ()
            current = [identity(a) for a in stored]
            known = set(current)

            # Partition the request by identity: existing pairs are replaced,
            # unknown pairs are appended
            updated = [a for a in documents if identity(a) in known]
            added = [a for a in documents if identity(a) not in known]

            query = {"_id": task_id, "company": company_id}
            if current:
                # Match only if the stored list still has the same size and
                # contains every (key, mode) pair that was just read
                query["execution.artifacts"] = {
                    "$size": len(current),
                    "$all": [
                        {"$elemMatch": {"key": stored_key, "mode": stored_mode}}
                        for stored_key, stored_mode in current
                    ],
                }
            else:
                query["$or"] = [
                    {"execution.artifacts": {"$exists": False}},
                    {"execution.artifacts": {"$size": 0}},
                ]

            # NOTE(review): when both added and updated are non-empty, $push
            # and $set below target the same array path in one update;
            # MongoDB may reject this as a path conflict - confirm.
            update = {}
            array_filters = None
            if added:
                update["$push"] = {
                    "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
                }
            if updated:
                # One array-filter identifier per replaced artifact
                update["$set"] = {
                    f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
                    for index, artifact in enumerate(updated)
                }
                array_filters = [
                    {
                        f"artifact{index}.key": artifact.key,
                        f"artifact{index}.mode": artifact.mode,
                    }
                    for index, artifact in enumerate(updated)
                ]

            result: pymongo.results.UpdateResult = Task._get_collection().update_one(
                filter=query,
                update=update,
                array_filters=array_filters,
                upsert=False,
            )
            if result.matched_count >= 1:
                break

            if retry + 1 >= attempts:
                # Last attempt failed: skip the pointless wait and let the
                # loop end normally so the else clause raises
                continue

            wait_msec = random() * int(
                config.get("services.tasks.artifacts.update_retry_msec", 500)
            )
            log.warning(
                f"Failed to update artifacts for task {task_id} (updated by another party),"
                f" retrying {retry+1}/{attempts} in {wait_msec}ms"
            )
            sleep(wait_msec / 1000)
        else:
            raise errors.server_error.UpdateFailed(
                "task artifacts updated by another party"
            )

    return [a.key for a in added], [a.key for a in updated]
@classmethod
@threads.register("non_responsive_tasks_watchdog", daemon=True)
def start_non_responsive_tasks_watchdog(cls):
@@ -502,10 +600,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(
Task.aggregate(*pipeline),
None,
)
result = next(Task.aggregate(*pipeline), None)
total = 0
remaining = 0