Add full Artifacts support, including local files/folders, pandas, numpy, images. Requires trains-server 0.11 and above

This commit is contained in:
allegroai 2019-09-23 18:41:53 +03:00
parent 4f1eeb49c6
commit b1d3fc9694
2 changed files with 187 additions and 29 deletions

View File

@ -17,8 +17,8 @@ df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69})
# change the artifact object
df.sample(frac=0.5, replace=True, random_state=1)
# or access it from anywhere using the Task
Task.current_task().artifacts['train'].sample(frac=0.5, replace=True, random_state=1)
# or access it from anywhere using the Task's get_registered_artifacts()
Task.current_task().get_registered_artifacts()['train'].sample(frac=0.5, replace=True, random_state=1)
# add and upload pandas.DataFrame (onetime snapshot of the object)
task.upload_artifact('Pandas', artifact_object=df)
@ -32,6 +32,7 @@ task.upload_artifact('Numpy Eye', np.eye(100, 100))
im = Image.open('samples/dancing.jpg')
task.upload_artifact('pillow_image', im)
# do something
sleep(1.)
print(df)

View File

@ -1,14 +1,16 @@
import hashlib
import json
import logging
import mimetypes
import os
from zipfile import ZipFile, ZIP_DEFLATED
from copy import deepcopy
from datetime import datetime
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time
import humanfriendly
import six
from pathlib2 import Path
from PIL import Image
@ -28,6 +30,103 @@ except ImportError:
np = None
class Artifact(object):
    """
    Read-Only Artifact object

    All fields are copied once, at construction time, from the api artifact object
    and are exposed through read-only properties.
    """

    @property
    def url(self):
        """
        :return: url of uploaded artifact
        """
        return self._url

    @property
    def name(self):
        """
        :return: name of artifact
        """
        return self._name

    @property
    def size(self):
        """
        :return: size in bytes of artifact
        """
        return self._size

    @property
    def type(self):
        """
        :return: type (str) of artifact
        """
        return self._type

    @property
    def mode(self):
        """
        :return: mode (str) of artifact. either "input" or "output"
        """
        return self._mode

    @property
    def hash(self):
        """
        :return: SHA2 hash (str) of artifact content.
        """
        return self._hash

    @property
    def timestamp(self):
        """
        :return: Timestamp (datetime) of uploaded artifact,
            or None if the api object carried no timestamp.
        """
        return self._timestamp

    @property
    def metadata(self):
        """
        :return: Key/Value dictionary attached to artifact.
        """
        return self._metadata

    @property
    def preview(self):
        """
        :return: string (str) representation of the artifact,
            or None if the api object carried no type data.
        """
        return self._preview

    def __init__(self, artifact_api_object):
        """
        construct read-only object from api artifact object

        :param tasks.Artifact artifact_api_object: api object the artifact fields are copied from
        """
        self._name = artifact_api_object.key
        self._size = artifact_api_object.content_size
        self._type = artifact_api_object.type
        self._mode = artifact_api_object.mode
        self._url = artifact_api_object.uri
        self._hash = artifact_api_object.hash

        # datetime.fromtimestamp(None) raises TypeError, so treat the timestamp as optional,
        # the same way display_data and type_data are treated below
        raw_timestamp = artifact_api_object.timestamp
        self._timestamp = datetime.fromtimestamp(raw_timestamp) if raw_timestamp is not None else None

        # copy the display data so later changes to the api object do not leak into this object
        display_data = artifact_api_object.display_data
        self._metadata = dict(display_data) if display_data else {}

        type_data = artifact_api_object.type_data
        self._preview = type_data.preview if type_data else None

    def get_local_copy(self):
        """
        :return: a local path to a downloaded copy of the artifact
        """
        # imported here (not at module level), as in the original code
        from trains.storage.helper import StorageHelper
        return StorageHelper.get_local_copy(self.url)

    def __repr__(self):
        return str({
            'name': self.name,
            'size': self.size,
            'type': self.type,
            'mode': self.mode,
            'url': self.url,
            'hash': self.hash,
            'timestamp': self.timestamp,
            'metadata': self.metadata,
            'preview': self.preview,
        })
class Artifacts(object):
_flush_frequency_sec = 300.
# notice these two should match
@ -66,7 +165,7 @@ class Artifacts(object):
return self.artifact_metadata.get(name)
@property
def artifacts(self):
def registered_artifacts(self):
return self._artifacts_dict
@property
@ -113,51 +212,98 @@ class Artifacts(object):
raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
artifact_type_data = tasks.ArtifactTypeData()
use_filename_in_uri = True
override_filename_in_uri = None
override_filename_ext_in_uri = None
if np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy'
artifact_type_data.content_type = 'application/numpy'
artifact_type_data.preview = str(artifact_object.__repr__())
fd, local_filename = mkstemp(suffix='.npz')
override_filename_ext_in_uri = '.npz'
override_filename_in_uri = name+override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd)
np.savez_compressed(local_filename, **{name: artifact_object})
delete_after_upload = True
use_filename_in_uri = False
elif pd and isinstance(artifact_object, pd.DataFrame):
artifact_type = 'pandas'
artifact_type_data.content_type = 'text/csv'
artifact_type_data.preview = str(artifact_object.__repr__())
fd, local_filename = mkstemp(suffix=self._save_format)
override_filename_ext_in_uri = self._save_format
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd)
artifact_object.to_csv(local_filename, compression=self._compression)
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, Image.Image):
artifact_type = 'image'
artifact_type_data.content_type = 'image/png'
desc = str(artifact_object.__repr__())
artifact_type_data.preview = desc[1:desc.find(' at ')]
fd, local_filename = mkstemp(suffix='.png')
override_filename_ext_in_uri = '.png'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd)
artifact_object.save(local_filename)
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, dict):
artifact_type = 'JSON'
artifact_type_data.content_type = 'application/json'
preview = json.dumps(artifact_object, sort_keys=True, indent=4)
fd, local_filename = mkstemp(suffix='.json')
override_filename_ext_in_uri = '.json'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.write(fd, bytes(preview.encode()))
os.close(fd)
artifact_type_data.preview = preview
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
if isinstance(artifact_object, Path):
# check if single file
if isinstance(artifact_object, six.string_types):
artifact_object = Path(artifact_object)
artifact_object.expanduser().absolute()
create_zip_file = not artifact_object.is_file()
if artifact_object.is_dir():
# change to wildcard
artifact_object = artifact_object / '*'
if create_zip_file:
folder = Path('').joinpath(*artifact_object.parts[:-1])
wildcard = artifact_object.parts[-1]
files = list(Path(folder).rglob(wildcard))
override_filename_ext_in_uri = '.zip'
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
fd, zip_file = mkstemp(prefix=folder.parts[-1]+'.', suffix=override_filename_ext_in_uri)
try:
artifact_type_data.content_type = 'application/zip'
artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix())
with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
for filename in sorted(files):
if filename.is_file():
relative_file_name = filename.relative_to(folder).as_posix()
artifact_type_data.preview += '{} - {}\n'.format(
relative_file_name, humanfriendly.format_size(filename.stat().st_size))
zf.write(filename.as_posix(), arcname=relative_file_name)
except Exception as e:
# failed uploading folder:
LoggerRoot.get_base_logger().warning('Exception {}\nFailed zipping artifact folder {}'.format(
folder, e))
return None
finally:
os.close(fd)
artifact_object = zip_file
artifact_type = 'zip'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object
delete_after_upload = True
else:
override_filename_in_uri = artifact_object.parts[-1]
artifact_object = artifact_object.as_posix()
artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object
artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object
else:
raise ValueError("Artifact type {} not supported".format(type(artifact_object)))
@ -183,7 +329,9 @@ class Artifacts(object):
file_size = local_filename.stat().st_size
uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)
delete_after_upload=delete_after_upload,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
artifact = tasks.Artifact(key=name, type=artifact_type,
uri=uri,
@ -259,25 +407,29 @@ class Artifacts(object):
if pd_artifact is None:
return
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
override_filename_ext_in_uri = self._save_format
override_filename_in_uri = name
fd, local_csv = mkstemp(prefix=name + '.', suffix=override_filename_ext_in_uri)
os.close(fd)
local_csv = Path(local_csv)
pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload
local_csv.unlink()
try:
local_csv.unlink()
except Exception:
pass
return
self._last_artifacts_upload[name] = current_sha2
# If old trains-server, upload as debug image
if not Session.check_min_api_version('2.3'):
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
logger.report_image(title='artifacts', series=name, local_path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
return
# Find our artifact
@ -290,7 +442,9 @@ class Artifacts(object):
file_size = local_csv.stat().st_size
# upload file
uri = self._upload_local_file(local_csv, name, delete_after_upload=True)
uri = self._upload_local_file(local_csv, name, delete_after_upload=True,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
# update task artifacts
with self._task_edit_lock:
@ -312,7 +466,9 @@ class Artifacts(object):
self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
def _upload_local_file(self, local_file, name, delete_after_upload=False,
override_filename=None,
override_filename_ext=None):
"""
Upload local file and return uri of the uploaded file (uploading in the background)
"""
@ -323,7 +479,8 @@ class Artifacts(object):
image_data=None, upload_uri=upload_uri,
local_image_path=local_file.as_posix(),
delete_after_upload=delete_after_upload,
override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
override_filename=override_filename,
override_filename_ext=override_filename_ext,
override_storage_key_prefix=self._get_storage_uri_prefix())
_, uri = ev.get_target_full_upload_uri(upload_uri)