mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add support for atomic add/update of task artifacts
This commit is contained in:
parent
52529d3c55
commit
1e4756aa1d
@ -121,6 +121,7 @@ _error_codes = {
|
|||||||
100: ('data_error', 'general data error'),
|
100: ('data_error', 'general data error'),
|
||||||
101: ('inconsistent_data', 'inconsistent data encountered in document'),
|
101: ('inconsistent_data', 'inconsistent data encountered in document'),
|
||||||
102: ('database_unavailable', 'database is temporarily unavailable'),
|
102: ('database_unavailable', 'database is temporarily unavailable'),
|
||||||
|
110: ('update_failed', 'update failed'),
|
||||||
|
|
||||||
# Index-related issues
|
# Index-related issues
|
||||||
201: ('missing_index', 'missing internal index'),
|
201: ('missing_index', 'missing internal index'),
|
||||||
|
@ -1,15 +1,19 @@
|
|||||||
import re
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from operator import attrgetter
|
||||||
|
from random import random
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Collection, Sequence, Tuple, Any
|
from typing import Collection, Sequence, Tuple, Any, Optional, List
|
||||||
|
|
||||||
|
import pymongo.results
|
||||||
import six
|
import six
|
||||||
from mongoengine import Q
|
from mongoengine import Q
|
||||||
from six import string_types
|
from six import string_types
|
||||||
|
|
||||||
import es_factory
|
import es_factory
|
||||||
from apierrors import errors
|
from apierrors import errors
|
||||||
|
from apimodels.tasks import Artifact as ApiArtifact
|
||||||
from config import config
|
from config import config
|
||||||
from database.errors import translate_errors_context
|
from database.errors import translate_errors_context
|
||||||
from database.model.model import Model
|
from database.model.model import Model
|
||||||
@ -28,6 +32,9 @@ from utilities.threads_manager import ThreadsManager
|
|||||||
from .utils import ChangeStatusRequest, validate_status_change
|
from .utils import ChangeStatusRequest, validate_status_change
|
||||||
|
|
||||||
|
|
||||||
|
log = config.logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class TaskBLL(object):
|
class TaskBLL(object):
|
||||||
threads = ThreadsManager("TaskBLL")
|
threads = ThreadsManager("TaskBLL")
|
||||||
|
|
||||||
@ -373,7 +380,7 @@ class TaskBLL(object):
|
|||||||
:return: updated task fields
|
:return: updated task fields
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task = TaskBLL.get_task_with_access(
|
task = cls.get_task_with_access(
|
||||||
task_id,
|
task_id,
|
||||||
company_id=company_id,
|
company_id=company_id,
|
||||||
only=(
|
only=(
|
||||||
@ -411,6 +418,97 @@ class TaskBLL(object):
|
|||||||
force=force,
|
force=force,
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
|
@classmethod
def add_or_update_artifacts(
    cls, task_id: str, company_id: str, artifacts: List[ApiArtifact]
) -> Tuple[List[str], List[str]]:
    """
    Add new artifacts to a task and overwrite existing ones in a single
    conditional update.

    An artifact is identified by its (key, mode) pair. Artifacts whose pair is
    already present on the task are replaced; all others are appended.

    The update filter requires the stored artifacts list to still have the same
    size and the same (key, mode) pairs that were just read from the task. If no
    document matches (the list was changed in the meantime), the task is read
    again and the update is retried after a random delay.

    :param task_id: ID of the task to update
    :param company_id: company ID used for access checks and for the update filter
    :param artifacts: API artifacts to add or update
    :return: tuple of (keys of added artifacts, keys of updated artifacts)
    :raise errors.server_error.UpdateFailed: if no attempt matched the task document
    """
    if not artifacts:
        return [], []

    identity = attrgetter("key", "mode")

    with translate_errors_context(), TimingContext("mongo", "update_artifacts"):
        # Convert API models to database embedded documents
        docs: List[Artifact] = [
            Artifact(**artifact.to_struct()) for artifact in artifacts
        ]

        attempts = int(config.get("services.tasks.artifacts.update_attempts", 10))
        # Upper bound for the random delay between attempts (loop invariant)
        retry_msec = int(
            config.get("services.tasks.artifacts.update_retry_msec", 500)
        )

        for retry in range(attempts):
            task = cls.get_task_with_access(
                task_id, company_id=company_id, requires_write_access=True
            )

            # Keep the list (its length feeds $size) and a set for membership tests
            current = [identity(a) for a in task.execution.artifacts]
            current_ids = set(current)
            updated = [a for a in docs if identity(a) in current_ids]
            added = [a for a in docs if identity(a) not in current_ids]

            # Match the task only if its artifacts list is still what we just read
            query = {"_id": task_id, "company": company_id}
            if current:
                query["execution.artifacts"] = {
                    "$size": len(current),
                    "$all": [
                        {"$elemMatch": {"key": art_key, "mode": art_mode}}
                        for art_key, art_mode in current
                    ],
                }
            else:
                query["$or"] = [
                    {"execution.artifacts": {"$exists": False}},
                    {"execution.artifacts": {"$size": 0}},
                ]

            # docs is never empty here, so at least one of the operators is set
            update = {}
            array_filters = None
            if added:
                update["$push"] = {
                    "execution.artifacts": {"$each": [a.to_mongo() for a in added]}
                }
            if updated:
                # One array filter identifier per updated artifact
                update["$set"] = {
                    f"execution.artifacts.$[artifact{index}]": artifact.to_mongo()
                    for index, artifact in enumerate(updated)
                }
                array_filters = [
                    {
                        f"artifact{index}.key": artifact.key,
                        f"artifact{index}.mode": artifact.mode,
                    }
                    for index, artifact in enumerate(updated)
                ]

            result: pymongo.results.UpdateResult = Task._get_collection().update_one(
                filter=query,
                update=update,
                array_filters=array_filters,
                upsert=False,
            )

            if result.matched_count >= 1:
                break

            # Only announce a retry and wait when another attempt will follow;
            # after the last attempt fall straight through to the for/else below
            if retry + 1 < attempts:
                wait_msec = random() * retry_msec

                log.warning(
                    f"Failed to update artifacts for task {task_id} (updated by another party),"
                    f" retrying {retry+1}/{attempts} in {wait_msec}ms"
                )

                sleep(wait_msec / 1000)
        else:
            # The loop ended without a successful (matched) update
            raise errors.server_error.UpdateFailed(
                "task artifacts updated by another party"
            )

        return [a.key for a in added], [a.key for a in updated]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
@threads.register("non_responsive_tasks_watchdog", daemon=True)
|
||||||
def start_non_responsive_tasks_watchdog(cls):
|
def start_non_responsive_tasks_watchdog(cls):
|
||||||
@ -502,10 +600,7 @@ class TaskBLL(object):
|
|||||||
]
|
]
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
result = next(
|
result = next(Task.aggregate(*pipeline), None)
|
||||||
Task.aggregate(*pipeline),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
total = 0
|
total = 0
|
||||||
remaining = 0
|
remaining = 0
|
||||||
|
@ -5,3 +5,8 @@ non_responsive_tasks_watchdog {
|
|||||||
# Watchdog will sleep for this number of seconds after each cycle
|
# Watchdog will sleep for this number of seconds after each cycle
|
||||||
watch_interval_sec: 900
|
watch_interval_sec: 900
|
||||||
}
|
}
|
||||||
|
|
||||||
|
artifacts {
|
||||||
|
update_attempts: 10
|
||||||
|
update_retry_msec: 500
|
||||||
|
}
|
@ -18,6 +18,7 @@ from database.fields import (
|
|||||||
SafeSortedListField,
|
SafeSortedListField,
|
||||||
)
|
)
|
||||||
from database.model import AttributedDocument
|
from database.model import AttributedDocument
|
||||||
|
from database.model.base import ProperDictMixin
|
||||||
from database.model.model_labels import ModelLabels
|
from database.model.model_labels import ModelLabels
|
||||||
from database.model.project import Project
|
from database.model.project import Project
|
||||||
from database.utils import get_options
|
from database.utils import get_options
|
||||||
@ -78,7 +79,7 @@ class Artifact(EmbeddedDocument):
|
|||||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||||
|
|
||||||
|
|
||||||
class Execution(EmbeddedDocument):
|
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||||
test_split = IntField(default=0)
|
test_split = IntField(default=0)
|
||||||
parameters = SafeDictField(default=dict)
|
parameters = SafeDictField(default=dict)
|
||||||
model = StringField(reference_field="Model")
|
model = StringField(reference_field="Model")
|
||||||
|
@ -1304,4 +1304,40 @@ ping {
|
|||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
add_or_update_artifacts {
|
||||||
|
"2.6" {
|
||||||
|
description: """ Update an existing artifact (search by key/mode) or add a new one """
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [ task, artifacts ]
|
||||||
|
properties {
|
||||||
|
task {
|
||||||
|
description: "Task ID"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
artifacts {
|
||||||
|
description: "Artifacts to add or update"
|
||||||
|
type: array
|
||||||
|
items { "$ref": "#/definitions/artifact" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
added {
|
||||||
|
description: "Keys of artifacts added"
|
||||||
|
type: array
|
||||||
|
items { type: string }
|
||||||
|
}
|
||||||
|
updated {
|
||||||
|
description: "Keys of artifacts updated"
|
||||||
|
type: array
|
||||||
|
items { type: string }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -34,7 +34,7 @@ class ServiceRepo(object):
|
|||||||
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
|
"""If the check is set, parsing will fail for endpoint request with the version that is greater than the current
|
||||||
maximum """
|
maximum """
|
||||||
|
|
||||||
_max_version = PartialVersion("2.4")
|
_max_version = PartialVersion("2.6")
|
||||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||||
|
|
||||||
_endpoint_exp = (
|
_endpoint_exp = (
|
||||||
|
@ -27,6 +27,9 @@ from apimodels.tasks import (
|
|||||||
EnqueueRequest,
|
EnqueueRequest,
|
||||||
EnqueueResponse,
|
EnqueueResponse,
|
||||||
DequeueResponse,
|
DequeueResponse,
|
||||||
|
CloneRequest,
|
||||||
|
AddOrUpdateArtifactsRequest,
|
||||||
|
AddOrUpdateArtifactsResponse,
|
||||||
)
|
)
|
||||||
from bll.event import EventBLL
|
from bll.event import EventBLL
|
||||||
from bll.queue import QueueBLL
|
from bll.queue import QueueBLL
|
||||||
@ -837,3 +840,18 @@ def ping(_, company_id, request: PingRequest):
|
|||||||
TaskBLL.set_last_update(
|
TaskBLL.set_last_update(
|
||||||
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
|
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint(
    "tasks.add_or_update_artifacts",
    min_version="2.6",
    request_data_model=AddOrUpdateArtifactsRequest,
    response_data_model=AddOrUpdateArtifactsResponse,
)
def add_or_update_artifacts(
    call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
    """
    Endpoint handler for tasks.add_or_update_artifacts.

    Passes the request's task ID and artifacts to TaskBLL.add_or_update_artifacts
    and stores the returned added/updated artifact keys in the call result.
    """
    added_keys, updated_keys = TaskBLL.add_or_update_artifacts(
        company_id=company_id,
        task_id=request.task,
        artifacts=request.artifacts,
    )
    response = AddOrUpdateArtifactsResponse(added=added_keys, updated=updated_keys)
    call.result.data_model = response
|
||||||
|
Loading…
Reference in New Issue
Block a user