Add full Artifacts support, including local files/folders, pandas, numpy, and images. Requires trains-server 0.11 and above

This commit is contained in:
allegroai 2019-09-23 18:41:53 +03:00
parent 4f1eeb49c6
commit b1d3fc9694
2 changed files with 187 additions and 29 deletions

View File

@ -17,8 +17,8 @@ df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69}) task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69})
# change the artifact object # change the artifact object
df.sample(frac=0.5, replace=True, random_state=1) df.sample(frac=0.5, replace=True, random_state=1)
# or access it from anywhere using the Task # or access it from anywhere using the Task's get_registered_artifacts()
Task.current_task().artifacts['train'].sample(frac=0.5, replace=True, random_state=1) Task.current_task().get_registered_artifacts()['train'].sample(frac=0.5, replace=True, random_state=1)
# add and upload pandas.DataFrame (onetime snapshot of the object) # add and upload pandas.DataFrame (onetime snapshot of the object)
task.upload_artifact('Pandas', artifact_object=df) task.upload_artifact('Pandas', artifact_object=df)
@ -32,6 +32,7 @@ task.upload_artifact('Numpy Eye', np.eye(100, 100))
im = Image.open('samples/dancing.jpg') im = Image.open('samples/dancing.jpg')
task.upload_artifact('pillow_image', im) task.upload_artifact('pillow_image', im)
# do something # do something
sleep(1.) sleep(1.)
print(df) print(df)

View File

@ -1,14 +1,16 @@
import hashlib import hashlib
import json import json
import logging
import mimetypes import mimetypes
import os import os
from zipfile import ZipFile, ZIP_DEFLATED
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp, mkstemp from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock from threading import Thread, Event, RLock
from time import time from time import time
import humanfriendly
import six import six
from pathlib2 import Path from pathlib2 import Path
from PIL import Image from PIL import Image
@ -28,6 +30,103 @@ except ImportError:
np = None np = None
class Artifact(object):
    """
    Read-only Artifact object.

    Copies the fields of an API artifact object into private attributes and
    exposes them through read-only properties.
    """

    def __init__(self, artifact_api_object):
        """
        Construct a read-only object from an API artifact object.

        :param tasks.Artifact artifact_api_object: API object whose fields are copied
        """
        self._name = artifact_api_object.key
        self._size = artifact_api_object.content_size
        self._type = artifact_api_object.type
        self._mode = artifact_api_object.mode
        self._url = artifact_api_object.uri
        self._hash = artifact_api_object.hash
        # Guard a missing timestamp the same way display_data / type_data are
        # guarded below: datetime.fromtimestamp(None) raises TypeError.
        # `is not None` (rather than truthiness) keeps 0 a valid timestamp.
        timestamp = artifact_api_object.timestamp
        self._timestamp = datetime.fromtimestamp(timestamp) if timestamp is not None else None
        # Copy into a new dict so the property does not expose the API object's container
        self._metadata = dict(artifact_api_object.display_data) if artifact_api_object.display_data else {}
        self._preview = artifact_api_object.type_data.preview if artifact_api_object.type_data else None

    @property
    def url(self):
        """
        :return: url of uploaded artifact
        """
        return self._url

    @property
    def name(self):
        """
        :return: name of artifact
        """
        return self._name

    @property
    def size(self):
        """
        :return: size in bytes of artifact
        """
        return self._size

    @property
    def type(self):
        """
        :return: type (str) of artifact
        """
        return self._type

    @property
    def mode(self):
        """
        :return: mode (str) of artifact. either "input" or "output"
        """
        return self._mode

    @property
    def hash(self):
        """
        :return: SHA2 hash (str) of artifact content.
        """
        return self._hash

    @property
    def timestamp(self):
        """
        :return: Timestamp (datetime) of uploaded artifact,
            or None if the API object carried no timestamp.
        """
        return self._timestamp

    @property
    def metadata(self):
        """
        :return: Key/Value dictionary attached to artifact.
        """
        return self._metadata

    @property
    def preview(self):
        """
        :return: string (str) representation of the artifact,
            or None if the API object carried no type data.
        """
        return self._preview

    def get_local_copy(self):
        """
        :return: a local path to a downloaded copy of the artifact
        """
        # Imported at call time rather than at module level
        # (presumably to avoid a circular import - TODO confirm)
        from trains.storage.helper import StorageHelper
        return StorageHelper.get_local_copy(self.url)

    def __repr__(self):
        return str({'name': self.name, 'size': self.size, 'type': self.type, 'mode': self.mode, 'url': self.url,
                    'hash': self.hash, 'timestamp': self.timestamp,
                    'metadata': self.metadata, 'preview': self.preview, })
class Artifacts(object): class Artifacts(object):
_flush_frequency_sec = 300. _flush_frequency_sec = 300.
# notice these two should match # notice these two should match
@ -66,7 +165,7 @@ class Artifacts(object):
return self.artifact_metadata.get(name) return self.artifact_metadata.get(name)
@property @property
def artifacts(self): def registered_artifacts(self):
return self._artifacts_dict return self._artifacts_dict
@property @property
@ -113,51 +212,98 @@ class Artifacts(object):
raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name)) raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
artifact_type_data = tasks.ArtifactTypeData() artifact_type_data = tasks.ArtifactTypeData()
use_filename_in_uri = True override_filename_in_uri = None
override_filename_ext_in_uri = None
if np and isinstance(artifact_object, np.ndarray): if np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy' artifact_type = 'numpy'
artifact_type_data.content_type = 'application/numpy' artifact_type_data.content_type = 'application/numpy'
artifact_type_data.preview = str(artifact_object.__repr__()) artifact_type_data.preview = str(artifact_object.__repr__())
fd, local_filename = mkstemp(suffix='.npz') override_filename_ext_in_uri = '.npz'
override_filename_in_uri = name+override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd) os.close(fd)
np.savez_compressed(local_filename, **{name: artifact_object}) np.savez_compressed(local_filename, **{name: artifact_object})
delete_after_upload = True delete_after_upload = True
use_filename_in_uri = False
elif pd and isinstance(artifact_object, pd.DataFrame): elif pd and isinstance(artifact_object, pd.DataFrame):
artifact_type = 'pandas' artifact_type = 'pandas'
artifact_type_data.content_type = 'text/csv' artifact_type_data.content_type = 'text/csv'
artifact_type_data.preview = str(artifact_object.__repr__()) artifact_type_data.preview = str(artifact_object.__repr__())
fd, local_filename = mkstemp(suffix=self._save_format) override_filename_ext_in_uri = self._save_format
override_filename_in_uri = name
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd) os.close(fd)
artifact_object.to_csv(local_filename, compression=self._compression) artifact_object.to_csv(local_filename, compression=self._compression)
delete_after_upload = True delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, Image.Image): elif isinstance(artifact_object, Image.Image):
artifact_type = 'image' artifact_type = 'image'
artifact_type_data.content_type = 'image/png' artifact_type_data.content_type = 'image/png'
desc = str(artifact_object.__repr__()) desc = str(artifact_object.__repr__())
artifact_type_data.preview = desc[1:desc.find(' at ')] artifact_type_data.preview = desc[1:desc.find(' at ')]
fd, local_filename = mkstemp(suffix='.png') override_filename_ext_in_uri = '.png'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.close(fd) os.close(fd)
artifact_object.save(local_filename) artifact_object.save(local_filename)
delete_after_upload = True delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, dict): elif isinstance(artifact_object, dict):
artifact_type = 'JSON' artifact_type = 'JSON'
artifact_type_data.content_type = 'application/json' artifact_type_data.content_type = 'application/json'
preview = json.dumps(artifact_object, sort_keys=True, indent=4) preview = json.dumps(artifact_object, sort_keys=True, indent=4)
fd, local_filename = mkstemp(suffix='.json') override_filename_ext_in_uri = '.json'
override_filename_in_uri = name + override_filename_ext_in_uri
fd, local_filename = mkstemp(prefix=name+'.', suffix=override_filename_ext_in_uri)
os.write(fd, bytes(preview.encode())) os.write(fd, bytes(preview.encode()))
os.close(fd) os.close(fd)
artifact_type_data.preview = preview artifact_type_data.preview = preview
delete_after_upload = True delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path): elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
if isinstance(artifact_object, Path): # check if single file
if isinstance(artifact_object, six.string_types):
artifact_object = Path(artifact_object)
artifact_object.expanduser().absolute()
create_zip_file = not artifact_object.is_file()
if artifact_object.is_dir():
# change to wildcard
artifact_object = artifact_object / '*'
if create_zip_file:
folder = Path('').joinpath(*artifact_object.parts[:-1])
wildcard = artifact_object.parts[-1]
files = list(Path(folder).rglob(wildcard))
override_filename_ext_in_uri = '.zip'
override_filename_in_uri = folder.parts[-1] + override_filename_ext_in_uri
fd, zip_file = mkstemp(prefix=folder.parts[-1]+'.', suffix=override_filename_ext_in_uri)
try:
artifact_type_data.content_type = 'application/zip'
artifact_type_data.preview = 'Archive content {}:\n'.format(artifact_object.as_posix())
with ZipFile(zip_file, 'w', allowZip64=True, compression=ZIP_DEFLATED) as zf:
for filename in sorted(files):
if filename.is_file():
relative_file_name = filename.relative_to(folder).as_posix()
artifact_type_data.preview += '{} - {}\n'.format(
relative_file_name, humanfriendly.format_size(filename.stat().st_size))
zf.write(filename.as_posix(), arcname=relative_file_name)
except Exception as e:
# failed uploading folder:
LoggerRoot.get_base_logger().warning('Exception {}\nFailed zipping artifact folder {}'.format(
folder, e))
return None
finally:
os.close(fd)
artifact_object = zip_file
artifact_type = 'zip'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object
delete_after_upload = True
else:
override_filename_in_uri = artifact_object.parts[-1]
artifact_object = artifact_object.as_posix() artifact_object = artifact_object.as_posix()
artifact_type = 'custom' artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0] artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object local_filename = artifact_object
else: else:
raise ValueError("Artifact type {} not supported".format(type(artifact_object))) raise ValueError("Artifact type {} not supported".format(type(artifact_object)))
@ -183,7 +329,9 @@ class Artifacts(object):
file_size = local_filename.stat().st_size file_size = local_filename.stat().st_size
uri = self._upload_local_file(local_filename, name, uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri) delete_after_upload=delete_after_upload,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
artifact = tasks.Artifact(key=name, type=artifact_type, artifact = tasks.Artifact(key=name, type=artifact_type,
uri=uri, uri=uri,
@ -259,25 +407,29 @@ class Artifacts(object):
if pd_artifact is None: if pd_artifact is None:
return return
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute() override_filename_ext_in_uri = self._save_format
if local_csv.exists(): override_filename_in_uri = name
# we are still uploading... get another temp folder fd, local_csv = mkstemp(prefix=name + '.', suffix=override_filename_ext_in_uri)
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute() os.close(fd)
local_csv = Path(local_csv)
pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression) pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32) current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload: if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name] previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2: if previous_sha2 == current_sha2:
# nothing to do, we can skip the upload # nothing to do, we can skip the upload
local_csv.unlink() try:
local_csv.unlink()
except Exception:
pass
return return
self._last_artifacts_upload[name] = current_sha2 self._last_artifacts_upload[name] = current_sha2
# If old trains-server, upload as debug image # If old trains-server, upload as debug image
if not Session.check_min_api_version('2.3'): if not Session.check_min_api_version('2.3'):
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(), logger.report_image(title='artifacts', series=name, local_path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(), delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2) max_image_history=2)
return return
# Find our artifact # Find our artifact
@ -290,7 +442,9 @@ class Artifacts(object):
file_size = local_csv.stat().st_size file_size = local_csv.stat().st_size
# upload file # upload file
uri = self._upload_local_file(local_csv, name, delete_after_upload=True) uri = self._upload_local_file(local_csv, name, delete_after_upload=True,
override_filename=override_filename_in_uri,
override_filename_ext=override_filename_ext_in_uri)
# update task artifacts # update task artifacts
with self._task_edit_lock: with self._task_edit_lock:
@ -312,7 +466,9 @@ class Artifacts(object):
self._task.set_artifacts(self._task_artifact_list) self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True): def _upload_local_file(self, local_file, name, delete_after_upload=False,
override_filename=None,
override_filename_ext=None):
""" """
Upload local file and return uri of the uploaded file (uploading in the background) Upload local file and return uri of the uploaded file (uploading in the background)
""" """
@ -323,7 +479,8 @@ class Artifacts(object):
image_data=None, upload_uri=upload_uri, image_data=None, upload_uri=upload_uri,
local_image_path=local_file.as_posix(), local_image_path=local_file.as_posix(),
delete_after_upload=delete_after_upload, delete_after_upload=delete_after_upload,
override_filename=os.path.splitext(local_file.name)[0] if use_filename else None, override_filename=override_filename,
override_filename_ext=override_filename_ext,
override_storage_key_prefix=self._get_storage_uri_prefix()) override_storage_key_prefix=self._get_storage_uri_prefix())
_, uri = ev.get_target_full_upload_uri(upload_uri) _, uri = ev.get_target_full_upload_uri(upload_uri)