Improve Task.artifacts for safer Task multi node use cases

This commit is contained in:
allegroai 2021-04-25 10:37:14 +03:00
parent d3cdb0ac42
commit 1637c1a270
2 changed files with 46 additions and 8 deletions

View File

@ -1188,23 +1188,55 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._get_task_property("execution.docker_cmd", raise_on_error=False, log_on_error=False)
def set_artifacts(self, artifacts_list=None):
# type: (Sequence[tasks.Artifact]) -> ()
# type: (Sequence[tasks.Artifact]) -> Optional[List[tasks.Artifact]]
"""
List of artifacts (tasks.Artifact) to update the task
:param list artifacts_list: list of artifacts (type tasks.Artifact)
:return: List of current Task's Artifacts or None if error.
"""
if not Session.check_min_api_version('2.3'):
return False
return None
if not (isinstance(artifacts_list, (list, tuple))
and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
raise ValueError('Expected artifacts to [tasks.Artifacts]')
raise ValueError('Expected artifacts as List[tasks.Artifact]')
with self._edit_lock:
self.reload()
execution = self.data.execution
keys = [a.key for a in artifacts_list]
execution.artifacts = [a for a in execution.artifacts or [] if a.key not in keys] + artifacts_list
self._edit(execution=execution)
return execution.artifacts or []
def _add_artifacts(self, artifacts_list):
# type: (Sequence[tasks.Artifact]) -> Optional[List[tasks.Artifact]]
"""
List of artifacts (tasks.Artifact) to add to the the task
If an artifact by the same name already exists it will overwrite the existing artifact.
:param list artifacts_list: list of artifacts (type tasks.Artifact)
:return: List of current Task's Artifacts
"""
if not Session.check_min_api_version('2.3'):
return None
if not (isinstance(artifacts_list, (list, tuple))
and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
raise ValueError('Expected artifacts as List[tasks.Artifact]')
with self._edit_lock:
if Session.check_min_api_version("2.13") and not self._offline_mode:
req = tasks.AddOrUpdateArtifactsRequest(task=self.task_id, artifacts=artifacts_list, force=True)
res = self.send(req, raise_on_errors=False)
if not res or not res.response or not res.response.updated:
return None
self.reload()
else:
self.reload()
execution = self.data.execution
keys = [a.key for a in artifacts_list]
execution.artifacts = [a for a in execution.artifacts or [] if a.key not in keys] + artifacts_list
self._edit(execution=execution)
return self.data.execution.artifacts or []
def _set_model_design(self, design=None):
# type: (str) -> ()

View File

@ -564,9 +564,7 @@ class Artifacts(object):
display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)
# update task artifacts
with self._task_edit_lock:
self._task_artifact_list.append(artifact)
self._task.set_artifacts(self._task_artifact_list)
self._add_artifact(artifact)
return True
@ -619,6 +617,15 @@ class Artifacts(object):
# create summary
self._summary = self._get_statistics()
def _add_artifact(self, artifact):
if not self._task:
raise ValueError("Task object not set")
with self._task_edit_lock:
if artifact not in self._task_artifact_list:
self._task_artifact_list.append(artifact)
# noinspection PyProtectedMember
self._task._add_artifacts(self._task_artifact_list)
def _upload_data_audit_artifacts(self, name):
# type: (str) -> ()
logger = self._task.get_logger()
@ -681,7 +688,6 @@ class Artifacts(object):
with self._task_edit_lock:
if not artifact:
artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
self._task_artifact_list.append(artifact)
artifact_type_data = tasks.ArtifactTypeData()
artifact_type_data.data_hash = current_sha2
@ -696,7 +702,7 @@ class Artifacts(object):
artifact.timestamp = int(time())
artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None
self._task.set_artifacts(self._task_artifact_list)
self._add_artifact(artifact)
def _upload_local_file(
self, local_file, name, delete_after_upload=False, override_filename=None, override_filename_ext=None,