Mirror of https://github.com/clearml/clearml, synced 2025-03-12 06:41:17 +00:00
Add full artifacts support
This commit is contained in:
parent 27ca36687a
commit d7bdc746b8
examples/artifacts_toy.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from time import sleep

import pandas as pd
import numpy as np
from PIL import Image

from trains import Task

task = Task.init('examples', 'artifacts toy')

df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
                   'num_wings': [2, 0, 0, 0],
                   'num_specimen_seen': [10, 2, 1, 8]},
                  index=['falcon', 'dog', 'spider', 'fish'])

# register Pandas object as artifact to watch
# (it will be monitored in the background and automatically synced and uploaded)
task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69})
# change the artifact object
df.sample(frac=0.5, replace=True, random_state=1)
# or access it from anywhere using the Task
Task.current_task().artifacts['train'].sample(frac=0.5, replace=True, random_state=1)

# add and upload local file artifact
task.upload_artifact('local file', artifact_object='samples/dancing.jpg')
# add and upload dictionary stored as JSON
task.upload_artifact('dictionary', df.to_dict())
# add and upload Numpy Object (stored as .npz file)
task.upload_artifact('Numpy Eye', np.eye(100, 100))
# add and upload Image (stored as .png file)
im = Image.open('samples/dancing.jpg')
task.upload_artifact('pillow_image', im)

# do something
sleep(1.)
print(df)

# we are done
print('Done')
@@ -186,10 +186,12 @@ class UploadEvent(MetricsEventAdapter):
        self._count = self._get_metric_count(metric, variant)
        if not image_file_history_size:
            image_file_history_size = self._image_file_history_size
        if image_file_history_size < 1:
            self._filename = '%s_%s_%08d' % (metric, variant, self._count)
        else:
            self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
        self._filename = kwargs.pop('override_filename', None)
        if not self._filename:
            if image_file_history_size < 1:
                self._filename = '%s_%s_%08d' % (metric, variant, self._count)
            else:
                self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
        self._upload_uri = upload_uri
        self._delete_after_upload = delete_after_upload

@@ -198,6 +200,9 @@ class UploadEvent(MetricsEventAdapter):
        image_format = self._format.lower() if self._image_data is not None else \
            '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
        self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))

        self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)

        super(UploadEvent, self).__init__(metric, variant, iter=iter, **kwargs)

    @classmethod
@@ -273,10 +278,12 @@ class UploadEvent(MetricsEventAdapter):
            delete_local_file=local_file if self._delete_after_upload else None,
        )

    def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
    def get_target_full_upload_uri(self, storage_uri, storage_key_prefix=None):
        e_storage_uri = self._upload_uri or storage_uri
        # if we have an entry (with or without a stream), we'll generate the URL and store it in the event
        filename = self._upload_filename
        if self._override_storage_key_prefix or not storage_key_prefix:
            storage_key_prefix = self._override_storage_key_prefix
        key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
        url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
        return key, url
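To make the change above concrete: a standalone sketch of how get_target_full_upload_uri composes the storage key and URL. The bucket, prefix, metric, variant and filename values here are made-up placeholders, not part of this commit.

    # Minimal sketch of the key/URL composition shown above (placeholder values).
    storage_uri = 's3://bucket/base'
    storage_key_prefix = 'examples/artifacts toy.1234abcd'   # e.g. the task's destination suffix
    metric, variant, filename = 'artifacts', 'train', 'train.csv.gz'

    key = '/'.join(x for x in (storage_key_prefix, metric, variant, filename.strip('/')) if x)
    url = '/'.join(x.strip('/') for x in (storage_uri, key))
    print(key)   # examples/artifacts toy.1234abcd/artifacts/train/train.csv.gz
    print(url)   # s3://bucket/base/examples/artifacts toy.1234abcd/artifacts/train/train.csv.gz
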
@@ -6,6 +6,7 @@ from tempfile import mkstemp
import six
from pathlib2 import Path

from ..backend_api import Session
from ..backend_api.services import models
from .base import IdObjectBase
from .util import make_message

@@ -185,18 +186,21 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
        design = self._wrap_design(design) if design else self.data.design
        name = name or self.data.name
        comment = comment or self.data.comment
        tags = tags or self.data.tags
        labels = labels or self.data.labels
        task = task_id or self.data.task
        project = project_id or self.data.project
        parent = parent_id or self.data.parent
        if tags:
            extra = {'system_tags': tags or self.data.system_tags} \
                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
        else:
            extra = {}

        self.send(models.EditRequest(
            model=self.id,
            uri=uri,
            name=name,
            comment=comment,
            tags=tags,
            labels=labels,
            design=design,
            task=task,
@@ -204,6 +208,27 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
            parent=parent,
            framework=framework or self.data.framework,
            iteration=iteration,
            **extra
        ))
        self.reload()

    def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
             uri=None, framework=None, iteration=None):
        if tags:
            extra = {'system_tags': tags or self.data.system_tags} \
                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
        else:
            extra = {}
        self.send(models.EditRequest(
            model=self.id,
            uri=uri,
            name=name,
            comment=comment,
            labels=labels,
            design=self._wrap_design(design) if design else None,
            framework=framework,
            iteration=iteration,
            **extra
        ))
        self.reload()

@@ -239,7 +264,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
                if cb:
                    cb(model_file)

            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, cb=callback)
            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
                                     cb=callback)
            return uri
        else:
            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
@@ -264,12 +290,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
        if self._data:
            name = name or self.data.name
            comment = comment or self.data.comment
            tags = tags or self.data.tags
            tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags)
            uri = (uri or self.data.uri) if not override_model_id else None

        if tags:
            extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
        else:
            extra = {}
        res = self.send(
            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
                                        override_model_id=override_model_id))
            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
                                        override_model_id=override_model_id, **extra))
        if self.id is None:
            # update the model id. in case it was just created, this will trigger a reload of the model object
            self.id = res.response.id

@@ -295,8 +325,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
        else:
            uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
            self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
            _ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
                                                      override_model_id=override_model_id, iteration=iteration))
            if tags:
                extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
            else:
                extra = {}
            _ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
                                                      override_model_id=override_model_id, iteration=iteration,
                                                      **extra))
            return uri

    def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):

@@ -337,7 +372,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

    @property
    def tags(self):
        return self.data.tags
        return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags

    @property
    def locked(self):

@@ -402,18 +437,20 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
        data = self.data
        assert isinstance(data, models.Model)
        parent = self.id if child else None
        extra = {'system_tags': tags or data.system_tags} \
            if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
        req = models.CreateRequest(
            uri=data.uri,
            name=name,
            labels=data.labels,
            comment=comment or data.comment,
            tags=tags or data.tags,
            framework=data.framework,
            design=data.design,
            ready=ready,
            project=data.project,
            parent=parent,
            task=task,
            **extra
        )
        res = self.send(req)
        return res.response.id
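Since the system_tags/tags switch above repeats in almost every hunk of this commit, here is a minimal standalone sketch of the compatibility pattern. tags_extra is a made-up helper name, and the boolean argument stands in for Session.check_min_api_version('2.3').

    # Sketch of the backward-compatibility pattern used above: servers with API >= 2.3
    # expect 'system_tags', older servers only know 'tags'. The chosen key is splatted
    # into the request as **extra.
    def tags_extra(tags, api_2_3_or_newer):
        if not tags:
            return {}
        return {'system_tags': tags} if api_2_3_or_newer else {'tags': tags}

    print(tags_extra(['development'], True))    # {'system_tags': ['development']}
    print(tags_extra(['development'], False))   # {'tags': ['development']}
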
@@ -6,6 +6,7 @@ from enum import Enum
from threading import RLock, Thread

import six
from six.moves.urllib.parse import quote

from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session

@@ -96,6 +97,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
            # this is an existing task, let's try to verify stuff
            self._validate()

        self._project_name = (self.project, project_name)

        if running_remotely() or DevWorker.report_stdout:
            log_to_backend = False
        self._log_to_backend = log_to_backend

@@ -223,14 +226,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        project_id = get_or_create_project(self, project_name, created_msg)

        tags = [self._development_tag] if not running_remotely() else []

        extra_properties = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
        req = tasks.CreateRequest(
            name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
            type=tasks.TaskTypeEnum(task_type.value),
            comment=created_msg,
            project=project_id,
            input={'view': {}},
            tags=tags,
            **extra_properties
        )
        res = self.send(req)

@@ -369,7 +372,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        return self._reporter

    def _get_output_destination_suffix(self, extra_path=None):
        return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
        return '/'.join(quote(x, safe='[]{}()$^,.; -_+-=') for x in
                        (self.get_project_name(), '%s.%s' % (self.name, self.data.id), extra_path) if x)

    def _reload(self):
        """ Reload the task object from the backend """

@@ -427,9 +431,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
    def update_output_model(self, model_uri, name=None, comment=None, tags=None):
        """
        Update the task's output model.
        Note that this method only updates the model's metadata using the API and does not upload any data. Use this
        method to update the output model when you have a local model URI (e.g. storing the weights file locally and
        providing a file://path/to/file URI)
        Note that this method only updates the model's metadata using the API and does not upload any data. Use this
        method to update the output model when you have a local model URI (e.g. storing the weights file locally and
        providing a file://path/to/file URI)

        :param model_uri: URI for the updated model weights file
        :type model_uri: str
        :param name: Optional updated model name
@@ -446,8 +451,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
            self, model_file, name=None, comment=None, tags=None, async_enable=False, cb=None, iteration=None):
        """
        Update the task's output model weights file. File is first uploaded to the preconfigured output destination (see
        task's output.destination property or call setup_upload()), than the model object associated with the task is
        updated using an API call with the URI of the uploaded file (and other values provided by additional arguments)
        task's output.destination property or call setup_upload()), than the model object associated with the task is
        updated using an API call with the URI of the uploaded file (and other values provided by additional arguments)

        :param model_file: Path to the updated model weights file
        :type model_file: str
        :param name: Optional updated model name

@@ -632,6 +638,21 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        execution.model_labels = enumeration
        self._edit(execution=execution)

    def set_artifacts(self, artifacts_list=None):
        """
        List of artifacts (tasks.Artifact) to update the task

        :param list artifacts_list: list of artifacts (type tasks.Artifact)
        """
        if not Session.check_min_api_version('2.3'):
            return False
        if not (isinstance(artifacts_list, (list, tuple))
                and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
            raise ValueError('Expected artifacts to [tasks.Artifacts]')
        execution = self.data.execution
        execution.artifacts = artifacts_list
        self._edit(execution=execution)

    def _set_model_design(self, design=None):
        execution = self.data.execution
        if design is not None:

@@ -677,7 +698,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        if self.project is None:
            return None

        if self._project_name and self._project_name[0] == self.project:
        if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
            return self._project_name[1]

        res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)

@@ -689,8 +710,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
    def get_tags(self):
        return self._get_task_property("tags")

    def set_system_tags(self, tags):
        assert isinstance(tags, (list, tuple))
        if Session.check_min_api_version('2.3'):
            self._set_task_property("system_tags", tags)
            self._edit(system_tags=self.data.system_tags)
        else:
            self._set_task_property("tags", tags)
            self._edit(tags=self.data.tags)

    def set_tags(self, tags):
        assert isinstance(tags, (list, tuple))
        if not Session.check_min_api_version('2.3'):
            # not supported
            return
        self._set_task_property("tags", tags)
        self._edit(tags=self.data.tags)
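A minimal sketch of feeding the new set_artifacts method; it assumes a configured trains server with API version 2.3 or newer, and the key/uri values are placeholders. The tasks.Artifact fields follow the usage shown elsewhere in this commit.

    # Hypothetical example: attach one artifact entry to a task via the new set_artifacts API.
    # Requires a configured trains server with API >= 2.3 (otherwise set_artifacts returns False).
    from trains import Task
    from trains.backend_api.services import tasks

    task = Task.init('examples', 'set_artifacts sketch')
    entry = tasks.Artifact(
        key='train',                          # artifact name
        type='data-audit-table',              # same type string the Artifacts manager uses
        uri='s3://bucket/path/train.csv.gz',  # placeholder URI
        timestamp=1234567890,
    )
    task.set_artifacts([entry])
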
@@ -25,7 +25,7 @@ def get_or_create_project(session, project_name, description=None):
    return res.response.id


def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True):
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):
    if not results:
        if not raise_on_error:
            return None

@@ -38,6 +38,13 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
    if len(results) > 1:
        log.warn('More than one {entity} found when searching for `{query}`'
                 ' (showing first {show_results} {entity}s follow)'.format(**locals()))
        if sort_by_date:
            # sort results based on timestamp and return the newest one
            if hasattr(results[0], 'last_update'):
                results = sorted(results, key=lambda x: int(x.last_update.strftime('%s')), reverse=True)
            elif hasattr(results[0], 'created'):
                results = sorted(results, key=lambda x: int(x.created.strftime('%s')), reverse=True)

        for obj in (o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
            log.warn('Found {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
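A small standalone sketch of the new sort_by_date branch, using stand-in result objects instead of real API responses; like the code above it relies on strftime('%s'), which is platform dependent.

    # Stand-in objects mimicking API results that carry a 'last_update' datetime;
    # with sort_by_date=True the newest result moves to the front, as in the code above.
    # Note: strftime('%s') works on Linux/macOS libc, mirroring the original implementation.
    from datetime import datetime

    class FakeResult(object):
        def __init__(self, name, last_update):
            self.name = name
            self.last_update = last_update

    results = [FakeResult('older', datetime(2019, 1, 1)),
               FakeResult('newer', datetime(2019, 6, 1))]
    results = sorted(results, key=lambda x: int(x.last_update.strftime('%s')), reverse=True)
    print([r.name for r in results])  # ['newer', 'older']
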
@@ -1,18 +1,31 @@
import hashlib
import json
import logging
import mimetypes
import os
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp
from threading import Thread, Event
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time

import numpy as np
import six
from pathlib2 import Path
from PIL import Image

from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
from ..debugging.log import LoggerRoot
from ..backend_api.services import tasks

try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import numpy as np
except ImportError:
    np = None


class Artifacts(object):

@@ -22,6 +35,7 @@ class Artifacts(object):
    _compression = 'gzip'
    # hashing constants
    _hash_block_size = 65536
    _pd_artifact_type = 'data-audit-table'

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """

@@ -33,7 +47,7 @@ class Artifacts(object):

        def __setitem__(self, key, value):
            # check that value is of type pandas
            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
            if pd and isinstance(value, pd.DataFrame):
                super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                if self._artifacts_manager:

@@ -72,6 +86,9 @@ class Artifacts(object):
        self._thread_pool = ThreadPool()
        self._summary = ''
        self._temp_folder = []
        self._task_artifact_list = []
        self._task_edit_lock = RLock()
        self._storage_prefix = None

    def register_artifact(self, name, artifact, metadata=None):
        # currently we support pandas.DataFrame (which we will upload as csv.gz)
@@ -86,6 +103,94 @@ class Artifacts(object):
        self._unregister_request.add(name)
        self.flush()

    def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
        if not Session.check_min_api_version('2.3'):
            LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
                                                 'please upgrade to the latest server version')
            return False

        if name in self._artifacts_dict:
            raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))

        artifact_type_data = tasks.ArtifactTypeData()
        use_filename_in_uri = True
        if np and isinstance(artifact_object, np.ndarray):
            artifact_type = 'numpy'
            artifact_type_data.content_type = 'application/numpy'
            artifact_type_data.preview = str(artifact_object.__repr__())
            fd, local_filename = mkstemp(suffix='.npz')
            os.close(fd)
            np.savez_compressed(local_filename, **{name: artifact_object})
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, Image.Image):
            artifact_type = 'image'
            artifact_type_data.content_type = 'image/png'
            desc = str(artifact_object.__repr__())
            artifact_type_data.preview = desc[1:desc.find(' at ')]
            fd, local_filename = mkstemp(suffix='.png')
            os.close(fd)
            artifact_object.save(local_filename)
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, dict):
            artifact_type = 'JSON'
            artifact_type_data.content_type = 'application/json'
            preview = json.dumps(artifact_object, sort_keys=True, indent=4)
            fd, local_filename = mkstemp(suffix='.json')
            os.write(fd, bytes(preview.encode()))
            os.close(fd)
            artifact_type_data.preview = preview
            delete_after_upload = True
            use_filename_in_uri = False
        elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
            if isinstance(artifact_object, Path):
                artifact_object = artifact_object.as_posix()
            artifact_type = 'custom'
            artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
            local_filename = artifact_object
        else:
            raise ValueError("Artifact type {} not supported".format(type(artifact_object)))

        # remove from existing list, if exists
        for artifact in self._task_artifact_list:
            if artifact.key == name:
                if artifact.type == self._pd_artifact_type:
                    raise ValueError("Artifact of name {} already registered, "
                                     "use register_artifact instead".format(name))

                self._task_artifact_list.remove(artifact)
                break

        # check that the file to upload exists
        local_filename = Path(local_filename).absolute()
        if not local_filename.exists() or not local_filename.is_file():
            LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
                local_filename.as_posix()))
            return False

        file_hash, _ = self.sha256sum(local_filename.as_posix())
        timestamp = int(time())
        file_size = local_filename.stat().st_size

        uri = self._upload_local_file(local_filename, name,
                                      delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)

        artifact = tasks.Artifact(key=name, type=artifact_type,
                                  uri=uri,
                                  content_size=file_size,
                                  hash=file_hash,
                                  timestamp=timestamp,
                                  type_data=artifact_type_data,
                                  display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)

        # update task artifacts
        with self._task_edit_lock:
            self._task_artifact_list.append(artifact)
            self._task.set_artifacts(self._task_artifact_list)

        return True

    def flush(self):
        # start the thread if it hasn't already:
        self._start()
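The dict branch above can be tried in isolation; a short sketch of the serialization step it performs before upload (the temp file path is whatever mkstemp returns).

    # Sketch of the dict branch above: the dictionary is serialized to an indented JSON
    # temp file and that file (not the in-memory object) is what gets uploaded.
    import json
    import os
    from tempfile import mkstemp

    artifact_object = {'num_legs': [2, 4, 8, 0]}
    preview = json.dumps(artifact_object, sort_keys=True, indent=4)
    fd, local_filename = mkstemp(suffix='.json')
    os.write(fd, bytes(preview.encode()))
    os.close(fd)
    print(local_filename)  # e.g. a /tmp/...json path, deleted after upload when delete_after_upload=True
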
@@ -119,19 +224,20 @@ class Artifacts(object):
        while not self._exit_flag:
            self._flush_event.wait(self._flush_frequency_sec)
            self._flush_event.clear()
            try:
                artifact_keys = list(self._artifacts_dict.keys())
                for name in artifact_keys:
                    self._upload_artifacts(name)
            except Exception as e:
                LoggerRoot.get_base_logger().warning(str(e))
            artifact_keys = list(self._artifacts_dict.keys())
            for name in artifact_keys:
                try:
                    self._upload_data_audit_artifacts(name)
                except Exception as e:
                    LoggerRoot.get_base_logger().warning(str(e))

        # create summary
        self._summary = self._get_statistics()

    def _upload_artifacts(self, name):
    def _upload_data_audit_artifacts(self, name):
        logger = self._task.get_logger()
        artifact = self._artifacts_dict.get(name)
        pd_artifact = self._artifacts_dict.get(name)
        pd_metadata = self._artifacts_dict.get_metadata(name)

        # remove from artifacts watch list
        if name in self._unregister_request:

@@ -141,15 +247,15 @@ class Artifacts(object):
                pass
            self._artifacts_dict.unregister_artifact(name)

        if artifact is None:
        if pd_artifact is None:
            return

        local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
        if local_csv.exists():
            # we are still uploading... get another temp folder
            local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
        artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
        current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
        pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
        current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
        if name in self._last_artifacts_upload:
            previous_sha2 = self._last_artifacts_upload[name]
            if previous_sha2 == current_sha2:
@@ -157,19 +263,75 @@ class Artifacts(object):
                local_csv.unlink()
                return
        self._last_artifacts_upload[name] = current_sha2
        # now upload and delete at the end.
        logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                       delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                       max_image_history=2)

    def _get_statistics(self):
        # If old trains-server, upload as debug image
        if not Session.check_min_api_version('2.3'):
            logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                           delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                           max_image_history=2)
            return

        # Find our artifact
        artifact = None
        for an_artifact in self._task_artifact_list:
            if an_artifact.key == name:
                artifact = an_artifact
                break

        file_size = local_csv.stat().st_size

        # upload file
        uri = self._upload_local_file(local_csv, name, delete_after_upload=True)

        # update task artifacts
        with self._task_edit_lock:
            if not artifact:
                artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
                self._task_artifact_list.append(artifact)
            artifact_type_data = tasks.ArtifactTypeData()

            artifact_type_data.data_hash = current_sha2
            artifact_type_data.content_type = "text/csv"
            artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})

            artifact.type_data = artifact_type_data
            artifact.uri = uri
            artifact.content_size = file_size
            artifact.hash = file_sha2
            artifact.timestamp = int(time())
            artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None

            self._task.set_artifacts(self._task_artifact_list)

    def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
        """
        Upload local file and return uri of the uploaded file (uploading in the background)
        """
        upload_uri = self._task.get_logger().get_default_upload_destination()
        if not isinstance(local_file, Path):
            local_file = Path(local_file)
        ev = UploadEvent(metric='artifacts', variant=name,
                         image_data=None, upload_uri=upload_uri,
                         local_image_path=local_file.as_posix(),
                         delete_after_upload=delete_after_upload,
                         override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
                         override_storage_key_prefix=self._get_storage_uri_prefix())
        _, uri = ev.get_target_full_upload_uri(upload_uri)

        # send for upload
        self._task.reporter._report(ev)

        return uri

    def _get_statistics(self, artifacts_dict=None):
        summary = ''
        artifacts_dict = artifacts_dict or self._artifacts_dict
        thread_pool = ThreadPool()

        try:
            # build hash row sets
            artifacts_summary = []
            for a_name, a_df in self._artifacts_dict.items():
            for a_name, a_df in artifacts_dict.items():
                if not pd or not isinstance(a_df, pd.DataFrame):
                    continue
@@ -206,16 +368,30 @@ class Artifacts(object):
            return new_temp
        return self._temp_folder[0]

    def _get_storage_uri_prefix(self):
        if not self._storage_prefix:
            self._storage_prefix = self._task._get_output_destination_suffix()
        return self._storage_prefix

    @staticmethod
    def sha256sum(filename, skip_header=0):
        # create sha2 of the file, notice we skip the header of the file (32 bytes)
        # because sometimes that is the only change
        h = hashlib.sha256()
        file_hash = hashlib.sha256()
        b = bytearray(Artifacts._hash_block_size)
        mv = memoryview(b)
        with open(filename, 'rb', buffering=0) as f:
            # skip header
            f.read(skip_header)
            for n in iter(lambda: f.readinto(mv), 0):
                h.update(mv[:n])
        return h.hexdigest()
        try:
            with open(filename, 'rb', buffering=0) as f:
                # skip header
                if skip_header:
                    file_hash.update(f.read(skip_header))
                for n in iter(lambda: f.readinto(mv), 0):
                    h.update(mv[:n])
                    if skip_header:
                        file_hash.update(mv[:n])
        except Exception as e:
            LoggerRoot.get_base_logger().warning(str(e))
            return None, None

        return h.hexdigest(), file_hash.hexdigest() if skip_header else None
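A short sketch of calling the reworked sha256sum; the import path and file names are assumptions (the diff above does not show file names). With skip_header it now returns two digests, the content hash that ignores the first skip_header bytes and the full-file hash; without it, the second value is None, and any I/O error yields (None, None).

    # Assumed module path for the Artifacts class patched above; adjust if it differs.
    from trains.binding.artifacts import Artifacts

    # Placeholder file paths; wrong paths make sha256sum log a warning and return (None, None).
    content_hash, full_hash = Artifacts.sha256sum('/tmp/train.csv.gz', skip_header=32)
    only_hash, second = Artifacts.sha256sum('/tmp/dancing.jpg')   # second is None when skip_header=0
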
@@ -109,7 +109,7 @@ class PatchedMatplotlib:

        # update api version
        from ..backend_api import Session
        PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
        PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')

        # create plotly renderer
        try:
@@ -572,7 +572,7 @@ class Logger(object):

        # if task was not started, we have to start it
        self._start_task_if_needed()
        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
        upload_uri = self.get_default_upload_destination()
        if not upload_uri:
            upload_uri = Path(get_cache_dir()) / 'debug_images'
            upload_uri.mkdir(parents=True, exist_ok=True)

@@ -619,7 +619,7 @@ class Logger(object):

        # if task was not started, we have to start it
        self._start_task_if_needed()
        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
        upload_uri = self.get_default_upload_destination()
        if not upload_uri:
            upload_uri = Path(get_cache_dir()) / 'debug_images'
            upload_uri.mkdir(parents=True, exist_ok=True)

@@ -664,7 +664,7 @@ class Logger(object):

        # if task was not started, we have to start it
        self._start_task_if_needed()
        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
        upload_uri = self.get_default_upload_destination()
        if not upload_uri:
            upload_uri = Path(get_cache_dir()) / 'debug_images'
            upload_uri.mkdir(parents=True, exist_ok=True)
@@ -705,6 +705,19 @@ class Logger(object):

        self._default_upload_destination = uri

    def get_default_upload_destination(self):
        """
        Get the uri to upload all the debug images to.

        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
        a link to the uploaded image is sent in the report
        Notice: credentials for the upload destination will be pooled from the
        global configuration file (i.e. ~/trains.conf)

        :return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
        """
        return self._default_upload_destination or self._task._get_default_report_storage_uri()

    def flush(self):
        """
        Flush cached reports and console outputs to backend.
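A sketch of the new getter in use, assuming a configured trains server; the project and task names are placeholders. The artifacts manager above resolves its upload target through this call, which falls back to the task's default report storage when no explicit destination was set.

    from trains import Task

    task = Task.init('examples', 'upload destination sketch')
    print(task.get_logger().get_default_upload_destination())
    # e.g. 's3://bucket/directory/' or a 'file:///...' path, per the docstring above
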
@@ -6,6 +6,8 @@ from tempfile import mkdtemp, mkstemp

import pyparsing
import six

from .backend_api import Session
from .backend_api.services import models
from pathlib2 import Path
from pyhocon import ConfigFactory, HOCONConverter

@@ -276,10 +278,13 @@ class BaseModel(object):
    def _set_package_tag(self):
        if self._package_tag not in self.tags:
            self.tags.append(self._package_tag)
            self._get_base_model().update(tags=self.tags)
            self._get_base_model().edit(tags=self.tags)

    @staticmethod
    def _config_dict_to_text(config):
        # if already string return as is
        if isinstance(config, six.string_types):
            return config
        if not isinstance(config, dict):
            raise ValueError("Model configuration only supports dictionary objects")
        try:

@@ -372,10 +377,12 @@ class InputModel(BaseModel):
        weights_url = StorageHelper.conform_url(weights_url)
        if not weights_url:
            raise ValueError("Please provide a valid weights_url parameter")
        extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
            if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
        result = _Model._get_default_session().send(models.GetAllRequest(
            uri=[weights_url],
            only_fields=["id", "name"],
            tags=["-" + ARCHIVED_TAG]
            **extra
        ))

        if result.response.models:

@@ -910,7 +917,7 @@ class OutputModel(BaseModel):

        if self.id:
            # update the model object (this will happen if we resumed a training task)
            result = self._get_force_base_model().update(design=config_text, task_id=self._task.id)
            result = self._get_force_base_model().edit(design=config_text)
        else:
            self._floating_data.design = _Model._wrap_design(config_text)
            result = Waitable()

@@ -936,7 +943,7 @@ class OutputModel(BaseModel):

        if self.id:
            # update the model object (this will happen if we resumed a training task)
            result = self._get_force_base_model().update(labels=labels, task_id=self._task.id)
            result = self._get_force_base_model().edit(labels=labels)
        else:
            self._floating_data.labels = labels
            result = Waitable()
@@ -309,7 +309,7 @@ class Task(_Task):
            If project is None, and the main execution task is initialized (Task.init), its project will be used.
            If project is provided but doesn't exist, it will be created.
        :param task_type: Task type to be created. (default: "training")
            Optional Task types are: "training" / "testing" / "dataset_import" / "annotation" / "annotation_manual"
            Optional Task types are: "training" / "testing" / "dataset_import" / "annotation" / "annotation_manual"
        :return: Task() object
        """
        if not project_name:

@@ -390,8 +390,10 @@ class Task(_Task):
                    task_id=default_task_id,
                    log_to_backend=True,
                )
                task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
                if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
                        or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
                        or task.output_model_id or (ARCHIVED_TAG in task_tags)
                        or (cls._development_tag not in task_tags)):
                    # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
                    # If the task is archived, or already has an output model,
                    # we shouldn't use it in development mode either

@@ -401,7 +403,7 @@ class Task(_Task):
                    # reset the task, so we can update it
                    task.reset(set_started_on_success=False, force=False)
                    # set development tags
                    task.set_tags([cls._development_tag])
                    task.set_system_tags([cls._development_tag])
                    # clear task parameters, they are not cleared by the Task reset
                    task.set_parameters({}, __update=False)
                    # clear the comment, it is not cleared on reset

@@ -410,6 +412,7 @@ class Task(_Task):
                    task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
                    task.set_model_config(config_text='')
                    task.set_model_label_enumeration({})
                    task.set_artifacts([])

            except (Exception, ValueError):
                # we failed reusing task, create a new one

@@ -480,7 +483,7 @@ class Task(_Task):
        if value and value != self.storage_uri:
            from .storage.helper import StorageHelper
            helper = StorageHelper.get(value)
            helper.check_write_permissions()
            helper.check_write_permissions(value)
        self.storage_uri = value

    @property
@@ -658,20 +661,24 @@ class Task(_Task):
        """
        self._artifacts_manager.unregister_artifact(name=name)

    def upload_artifact(self, name, artifact_object=None, artifact_file=None, metadata=None):
    def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
        """
        Add static artifact to Task. Artifact file/object will be uploaded in the background
        Raise ValueError if artifact_object is not supported

        :param str name: Artifact name. Notice! it will override previous artifact if name already exists
        :param object artifact_object: Artifact object to upload. Currently supports Numpy, PIL.Image.
            Numpy will be stored as .npz, and Image as .png file.
            Use None if uploading a file directly with 'artifact_file'.
        :param str artifact_file: path to artifact file to upload. None means not applicable.
            Notice you wither artifact object or artifact_file
        :param object artifact_object: Artifact object to upload. Currently supports:
            - string / pathlib2.Path are treated as path to artifact file to upload
            - dict will be stored as .json,
            - numpy.ndarray will be stored as .npz,
            - PIL.Image will be stored to .png file and uploaded
        :param dict metadata: Simple key/value dictionary to store on the artifact
        :return: True if artifact is supported
        :param bool delete_after_upload: If True local artifact will be deleted
            (only applies if artifact_object is a local file)
        :return: True if artifact will be uploaded
        """
        raise ValueError("Not implemented yet")
        return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
                                                       metadata=metadata, delete_after_upload=delete_after_upload)

    def is_current_task(self):
        """
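The updated docstring above is exercised end to end by examples/artifacts_toy.py at the top of this commit; a condensed sketch of the four supported artifact_object types, with placeholder file paths and assuming a configured trains server.

    # Condensed recap of the supported artifact_object types documented above
    # (paths are placeholders; each call returns True once the upload is queued).
    import numpy as np
    from PIL import Image
    from trains import Task

    task = Task.init('examples', 'upload_artifact sketch')
    task.upload_artifact('local file', artifact_object='/tmp/dancing.jpg')         # str / pathlib2.Path
    task.upload_artifact('dictionary', artifact_object={'a': 1, 'b': 2})           # dict -> .json
    task.upload_artifact('numpy', artifact_object=np.eye(10, 10))                  # ndarray -> .npz
    task.upload_artifact('image', artifact_object=Image.open('/tmp/dancing.jpg'))  # PIL.Image -> .png
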