Add full artifacts support

allegroai 2019-09-13 17:09:24 +03:00
parent 27ca36687a
commit d7bdc746b8
10 changed files with 397 additions and 72 deletions

examples/artifacts_toy.py Normal file
View File

@ -0,0 +1,38 @@
from time import sleep
import pandas as pd
import numpy as np
from PIL import Image
from trains import Task
task = Task.init('examples', 'artifacts toy')
df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
'num_wings': [2, 0, 0, 0],
'num_specimen_seen': [10, 2, 1, 8]},
index=['falcon', 'dog', 'spider', 'fish'])
# register Pandas object as artifact to watch
# (it will be monitored in the background and automatically synced and uploaded)
task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69})
# change the artifact object
df.sample(frac=0.5, replace=True, random_state=1)
# or access it from anywhere using the Task
Task.current_task().artifacts['train'].sample(frac=0.5, replace=True, random_state=1)
# add and upload local file artifact
task.upload_artifact('local file', artifact_object='samples/dancing.jpg')
# add and upload dictionary (stored as .json file)
task.upload_artifact('dictionary', df.to_dict())
# add and upload Numpy Object (stored as .npz file)
task.upload_artifact('Numpy Eye', np.eye(100, 100))
# add and upload Image (stored as .png file)
im = Image.open('samples/dancing.jpg')
task.upload_artifact('pillow_image', im)
# do something
sleep(1.)
print(df)
# we are done
print('Done')
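
The script above assumes a local image at samples/dancing.jpg. If that file is not available, a minimal stand-in can be generated first (illustrative helper, not part of this commit):

import os
import numpy as np
from PIL import Image

os.makedirs('samples', exist_ok=True)
# write a small random RGB image so the Image.open / upload_artifact calls above have a file to work with
Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)).save('samples/dancing.jpg')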

View File

@ -186,10 +186,12 @@ class UploadEvent(MetricsEventAdapter):
self._count = self._get_metric_count(metric, variant)
if not image_file_history_size:
image_file_history_size = self._image_file_history_size
if image_file_history_size < 1:
self._filename = '%s_%s_%08d' % (metric, variant, self._count)
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._filename = kwargs.pop('override_filename', None)
if not self._filename:
if image_file_history_size < 1:
self._filename = '%s_%s_%08d' % (metric, variant, self._count)
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._upload_uri = upload_uri
self._delete_after_upload = delete_after_upload
@ -198,6 +200,9 @@ class UploadEvent(MetricsEventAdapter):
image_format = self._format.lower() if self._image_data is not None else \
'.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)
super(UploadEvent, self).__init__(metric, variant, iter=iter, **kwargs)
@classmethod
@ -273,10 +278,12 @@ class UploadEvent(MetricsEventAdapter):
delete_local_file=local_file if self._delete_after_upload else None,
)
def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
def get_target_full_upload_uri(self, storage_uri, storage_key_prefix=None):
e_storage_uri = self._upload_uri or storage_uri
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
filename = self._upload_filename
if self._override_storage_key_prefix or not storage_key_prefix:
storage_key_prefix = self._override_storage_key_prefix
key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
return key, url
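
As a rough, plain-Python sketch of the key/URL composition above (values are illustrative, not part of the commit): when an artifact upload sets override_storage_key_prefix, that prefix replaces the task's default storage key prefix before the metric, variant and filename are appended.

# illustrative values only
e_storage_uri = 's3://bucket/output'
storage_key_prefix = 'examples/artifacts toy.<task_id>'   # hypothetical override prefix
metric, variant, filename = 'artifacts', 'train', 'train.csv.gz'
key = '/'.join(x for x in (storage_key_prefix, metric, variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
# key -> 'examples/artifacts toy.<task_id>/artifacts/train/train.csv.gz'
# url -> 's3://bucket/output/examples/artifacts toy.<task_id>/artifacts/train/train.csv.gz'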

View File

@ -6,6 +6,7 @@ from tempfile import mkstemp
import six
from pathlib2 import Path
from ..backend_api import Session
from ..backend_api.services import models
from .base import IdObjectBase
from .util import make_message
@ -185,18 +186,21 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
design = self._wrap_design(design) if design else self.data.design
name = name or self.data.name
comment = comment or self.data.comment
tags = tags or self.data.tags
labels = labels or self.data.labels
task = task_id or self.data.task
project = project_id or self.data.project
parent = parent_id or self.data.parent
if tags:
extra = {'system_tags': tags or self.data.system_tags} \
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
else:
extra = {}
self.send(models.EditRequest(
model=self.id,
uri=uri,
name=name,
comment=comment,
tags=tags,
labels=labels,
design=design,
task=task,
@ -204,6 +208,27 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
parent=parent,
framework=framework or self.data.framework,
iteration=iteration,
**extra
))
self.reload()
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
uri=None, framework=None, iteration=None):
if tags:
extra = {'system_tags': tags or self.data.system_tags} \
if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
else:
extra = {}
self.send(models.EditRequest(
model=self.id,
uri=uri,
name=name,
comment=comment,
labels=labels,
design=self._wrap_design(design) if design else None,
framework=framework,
iteration=iteration,
**extra
))
self.reload()
@ -239,7 +264,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if cb:
cb(model_file)
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, cb=callback)
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
cb=callback)
return uri
else:
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
@ -264,12 +290,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if self._data:
name = name or self.data.name
comment = comment or self.data.comment
tags = tags or self.data.tags
tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags)
uri = (uri or self.data.uri) if not override_model_id else None
if tags:
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
else:
extra = {}
res = self.send(
models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
override_model_id=override_model_id))
models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
override_model_id=override_model_id, **extra))
if self.id is None:
# update the model id. in case it was just created, this will trigger a reload of the model object
self.id = res.response.id
@ -295,8 +325,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
else:
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
_ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
override_model_id=override_model_id, iteration=iteration))
if tags:
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
else:
extra = {}
_ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
override_model_id=override_model_id, iteration=iteration,
**extra))
return uri
def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):
@ -337,7 +372,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
@property
def tags(self):
return self.data.tags
return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags
@property
def locked(self):
@ -402,18 +437,20 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
data = self.data
assert isinstance(data, models.Model)
parent = self.id if child else None
extra = {'system_tags': tags or data.system_tags} \
if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
req = models.CreateRequest(
uri=data.uri,
name=name,
labels=data.labels,
comment=comment or data.comment,
tags=tags or data.tags,
framework=data.framework,
design=data.design,
ready=ready,
project=data.project,
parent=parent,
task=task,
**extra
)
res = self.send(req)
return res.response.id

View File

@ -6,6 +6,7 @@ from enum import Enum
from threading import RLock, Thread
import six
from six.moves.urllib.parse import quote
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
@ -96,6 +97,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# this is an existing task, let's try to verify stuff
self._validate()
self._project_name = (self.project, project_name)
if running_remotely() or DevWorker.report_stdout:
log_to_backend = False
self._log_to_backend = log_to_backend
@ -223,14 +226,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
project_id = get_or_create_project(self, project_name, created_msg)
tags = [self._development_tag] if not running_remotely() else []
extra_properties = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
req = tasks.CreateRequest(
name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
type=tasks.TaskTypeEnum(task_type.value),
comment=created_msg,
project=project_id,
input={'view': {}},
tags=tags,
**extra_properties
)
res = self.send(req)
@ -369,7 +372,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._reporter
def _get_output_destination_suffix(self, extra_path=None):
return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
return '/'.join(quote(x, safe='[]{}()$^,.; -_+-=') for x in
(self.get_project_name(), '%s.%s' % (self.name, self.data.id), extra_path) if x)
def _reload(self):
""" Reload the task object from the backend """
@ -427,9 +431,10 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def update_output_model(self, model_uri, name=None, comment=None, tags=None):
"""
Update the task's output model.
Note that this method only updates the model's metadata using the API and does not upload any data. Use this
method to update the output model when you have a local model URI (e.g. storing the weights file locally and
providing a file://path/to/file URI)
:param model_uri: URI for the updated model weights file
:type model_uri: str
:param name: Optional updated model name
@ -446,8 +451,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self, model_file, name=None, comment=None, tags=None, async_enable=False, cb=None, iteration=None):
"""
Update the task's output model weights file. File is first uploaded to the preconfigured output destination (see
task's output.destination property or call setup_upload()), then the model object associated with the task is
updated using an API call with the URI of the uploaded file (and other values provided by additional arguments)
:param model_file: Path to the updated model weights file
:type model_file: str
:param name: Optional updated model name
@ -632,6 +638,21 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
execution.model_labels = enumeration
self._edit(execution=execution)
def set_artifacts(self, artifacts_list=None):
"""
Set the list of artifacts (tasks.Artifact) on the task
:param list artifacts_list: list of artifacts (type tasks.Artifact)
"""
if not Session.check_min_api_version('2.3'):
return False
if not (isinstance(artifacts_list, (list, tuple))
and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
raise ValueError('Expected artifacts_list as a list of tasks.Artifact objects')
execution = self.data.execution
execution.artifacts = artifacts_list
self._edit(execution=execution)
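
A minimal sketch of what a caller hands to set_artifacts(), built from the same tasks.Artifact fields that upload_artifact() fills in later in this commit (all values illustrative; assumes `task` is an initialized Task and the server supports API 2.3, otherwise the call is a no-op):

from time import time
from trains.backend_api.services import tasks

type_data = tasks.ArtifactTypeData()
type_data.content_type = 'text/csv'
artifact = tasks.Artifact(
    key='predictions',                               # artifact name
    type='custom',                                   # free-form type string
    uri='s3://bucket/output/predictions.csv',        # where the file was uploaded
    content_size=1024,
    hash='<sha256 of the uploaded file>',
    timestamp=int(time()),
    type_data=type_data,
    display_data=[('rows', '100')],                  # key/value metadata pairs
)
task.set_artifacts([artifact])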
def _set_model_design(self, design=None):
execution = self.data.execution
if design is not None:
@ -677,7 +698,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if self.project is None:
return None
if self._project_name and self._project_name[0] == self.project:
if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
return self._project_name[1]
res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
@ -689,8 +710,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def get_tags(self):
return self._get_task_property("tags")
def set_system_tags(self, tags):
assert isinstance(tags, (list, tuple))
if Session.check_min_api_version('2.3'):
self._set_task_property("system_tags", tags)
self._edit(system_tags=self.data.system_tags)
else:
self._set_task_property("tags", tags)
self._edit(tags=self.data.tags)
def set_tags(self, tags):
assert isinstance(tags, (list, tuple))
if not Session.check_min_api_version('2.3'):
# not supported
return
self._set_task_property("tags", tags)
self._edit(tags=self.data.tags)
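
The tags handling above follows the same compatibility pattern used throughout this commit: servers with API version 2.3 and up keep user tags and reserved/system tags in separate fields, while older servers only have `tags`. A condensed sketch of the guard:

from trains.backend_api import Session

tags = ['development']
# API >= 2.3: reserved tags (e.g. 'development', 'archived') go to system_tags; otherwise fall back to tags
extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
# 'extra' is then expanded into the relevant CreateRequest / EditRequest / UpdateForTaskRequest call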

View File

@ -25,7 +25,7 @@ def get_or_create_project(session, project_name, description=None):
return res.response.id
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True):
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):
if not results:
if not raise_on_error:
return None
@ -38,6 +38,13 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
if len(results) > 1:
log.warn('More than one {entity} found when searching for `{query}`'
' (showing first {show_results} {entity}s follow)'.format(**locals()))
if sort_by_date:
# sort results based on timestamp and return the newest one
if hasattr(results[0], 'last_update'):
results = sorted(results, key=lambda x: int(x.last_update.strftime('%s')), reverse=True)
elif hasattr(results[0], 'created'):
results = sorted(results, key=lambda x: int(x.created.strftime('%s')), reverse=True)
for obj in (o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
log.warn('Found {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))

View File

@ -1,18 +1,31 @@
import hashlib
import json
import logging
import mimetypes
import os
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from tempfile import mkdtemp
from threading import Thread, Event
from tempfile import mkdtemp, mkstemp
from threading import Thread, Event, RLock
from time import time
import numpy as np
import six
from pathlib2 import Path
from PIL import Image
from ..backend_interface.metrics.events import UploadEvent
from ..backend_api import Session
from ..debugging.log import LoggerRoot
from ..backend_api.services import tasks
try:
import pandas as pd
except ImportError:
pd = None
try:
import numpy as np
except ImportError:
np = None
class Artifacts(object):
@ -22,6 +35,7 @@ class Artifacts(object):
_compression = 'gzip'
# hashing constants
_hash_block_size = 65536
_pd_artifact_type = 'data-audit-table'
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
@ -33,7 +47,7 @@ class Artifacts(object):
def __setitem__(self, key, value):
# check that value is of type pandas
if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
if pd and isinstance(value, pd.DataFrame):
super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)
if self._artifacts_manager:
@ -72,6 +86,9 @@ class Artifacts(object):
self._thread_pool = ThreadPool()
self._summary = ''
self._temp_folder = []
self._task_artifact_list = []
self._task_edit_lock = RLock()
self._storage_prefix = None
def register_artifact(self, name, artifact, metadata=None):
# currently we support pandas.DataFrame (which we will upload as csv.gz)
@ -86,6 +103,94 @@ class Artifacts(object):
self._unregister_request.add(name)
self.flush()
def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
if not Session.check_min_api_version('2.3'):
LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
'please upgrade to the latest server version')
return False
if name in self._artifacts_dict:
raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
artifact_type_data = tasks.ArtifactTypeData()
use_filename_in_uri = True
if np and isinstance(artifact_object, np.ndarray):
artifact_type = 'numpy'
artifact_type_data.content_type = 'application/numpy'
artifact_type_data.preview = str(artifact_object.__repr__())
fd, local_filename = mkstemp(suffix='.npz')
os.close(fd)
np.savez_compressed(local_filename, **{name: artifact_object})
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, Image.Image):
artifact_type = 'image'
artifact_type_data.content_type = 'image/png'
desc = str(artifact_object.__repr__())
artifact_type_data.preview = desc[1:desc.find(' at ')]
fd, local_filename = mkstemp(suffix='.png')
os.close(fd)
artifact_object.save(local_filename)
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, dict):
artifact_type = 'JSON'
artifact_type_data.content_type = 'application/json'
preview = json.dumps(artifact_object, sort_keys=True, indent=4)
fd, local_filename = mkstemp(suffix='.json')
os.write(fd, bytes(preview.encode()))
os.close(fd)
artifact_type_data.preview = preview
delete_after_upload = True
use_filename_in_uri = False
elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
if isinstance(artifact_object, Path):
artifact_object = artifact_object.as_posix()
artifact_type = 'custom'
artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
local_filename = artifact_object
else:
raise ValueError("Artifact type {} not supported".format(type(artifact_object)))
# remove from existing list, if exists
for artifact in self._task_artifact_list:
if artifact.key == name:
if artifact.type == self._pd_artifact_type:
raise ValueError("Artifact of name {} already registered, "
"use register_artifact instead".format(name))
self._task_artifact_list.remove(artifact)
break
# check that the file to upload exists
local_filename = Path(local_filename).absolute()
if not local_filename.exists() or not local_filename.is_file():
LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
local_filename.as_posix()))
return False
file_hash, _ = self.sha256sum(local_filename.as_posix())
timestamp = int(time())
file_size = local_filename.stat().st_size
uri = self._upload_local_file(local_filename, name,
delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)
artifact = tasks.Artifact(key=name, type=artifact_type,
uri=uri,
content_size=file_size,
hash=file_hash,
timestamp=timestamp,
type_data=artifact_type_data,
display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)
# update task artifacts
with self._task_edit_lock:
self._task_artifact_list.append(artifact)
self._task.set_artifacts(self._task_artifact_list)
return True
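
A minimal sketch of the replacement semantics implemented above (assuming `task` is an initialized Task): uploading under an existing name drops the previous entry from the task's artifact list, unless that name belongs to a registered data-audit artifact, in which case a ValueError is raised.

task.upload_artifact('report', artifact_object={'epoch': 1})   # serialized to a temporary .json file and uploaded
task.upload_artifact('report', artifact_object={'epoch': 2})   # removes and replaces the previous 'report' entry
# names taken by register_artifact() are data-audit artifacts and cannot be overwritten here:
# task.upload_artifact('train', artifact_object={...}) would raise ValueError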
def flush(self):
# start the thread if it hasn't already:
self._start()
@ -119,19 +224,20 @@ class Artifacts(object):
while not self._exit_flag:
self._flush_event.wait(self._flush_frequency_sec)
self._flush_event.clear()
try:
artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
self._upload_artifacts(name)
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
artifact_keys = list(self._artifacts_dict.keys())
for name in artifact_keys:
try:
self._upload_data_audit_artifacts(name)
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
# create summary
self._summary = self._get_statistics()
def _upload_artifacts(self, name):
def _upload_data_audit_artifacts(self, name):
logger = self._task.get_logger()
artifact = self._artifacts_dict.get(name)
pd_artifact = self._artifacts_dict.get(name)
pd_metadata = self._artifacts_dict.get_metadata(name)
# remove from artifacts watch list
if name in self._unregister_request:
@ -141,15 +247,15 @@ class Artifacts(object):
pass
self._artifacts_dict.unregister_artifact(name)
if artifact is None:
if pd_artifact is None:
return
local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
if local_csv.exists():
# we are still uploading... get another temp folder
local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
if name in self._last_artifacts_upload:
previous_sha2 = self._last_artifacts_upload[name]
if previous_sha2 == current_sha2:
@ -157,19 +263,75 @@ class Artifacts(object):
local_csv.unlink()
return
self._last_artifacts_upload[name] = current_sha2
# now upload and delete at the end.
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
def _get_statistics(self):
# If old trains-server, upload as debug image
if not Session.check_min_api_version('2.3'):
logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
delete_after_upload=True, iteration=self._task.get_last_iteration(),
max_image_history=2)
return
# Find our artifact
artifact = None
for an_artifact in self._task_artifact_list:
if an_artifact.key == name:
artifact = an_artifact
break
file_size = local_csv.stat().st_size
# upload file
uri = self._upload_local_file(local_csv, name, delete_after_upload=True)
# update task artifacts
with self._task_edit_lock:
if not artifact:
artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
self._task_artifact_list.append(artifact)
artifact_type_data = tasks.ArtifactTypeData()
artifact_type_data.data_hash = current_sha2
artifact_type_data.content_type = "text/csv"
artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})
artifact.type_data = artifact_type_data
artifact.uri = uri
artifact.content_size = file_size
artifact.hash = file_sha2
artifact.timestamp = int(time())
artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None
self._task.set_artifacts(self._task_artifact_list)
def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
"""
Upload local file and return uri of the uploaded file (uploading in the background)
"""
upload_uri = self._task.get_logger().get_default_upload_destination()
if not isinstance(local_file, Path):
local_file = Path(local_file)
ev = UploadEvent(metric='artifacts', variant=name,
image_data=None, upload_uri=upload_uri,
local_image_path=local_file.as_posix(),
delete_after_upload=delete_after_upload,
override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
override_storage_key_prefix=self._get_storage_uri_prefix())
_, uri = ev.get_target_full_upload_uri(upload_uri)
# send for upload
self._task.reporter._report(ev)
return uri
def _get_statistics(self, artifacts_dict=None):
summary = ''
artifacts_dict = artifacts_dict or self._artifacts_dict
thread_pool = ThreadPool()
try:
# build hash row sets
artifacts_summary = []
for a_name, a_df in self._artifacts_dict.items():
for a_name, a_df in artifacts_dict.items():
if not pd or not isinstance(a_df, pd.DataFrame):
continue
@ -206,16 +368,30 @@ class Artifacts(object):
return new_temp
return self._temp_folder[0]
def _get_storage_uri_prefix(self):
if not self._storage_prefix:
self._storage_prefix = self._task._get_output_destination_suffix()
return self._storage_prefix
@staticmethod
def sha256sum(filename, skip_header=0):
# create sha256 of the file, optionally skipping the first skip_header bytes (e.g. a 32-byte header)
# because sometimes the header is the only change; when skipping, also hash the entire file
h = hashlib.sha256()
file_hash = hashlib.sha256()
b = bytearray(Artifacts._hash_block_size)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
# skip header
f.read(skip_header)
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
return h.hexdigest()
try:
with open(filename, 'rb', buffering=0) as f:
# skip header
if skip_header:
file_hash.update(f.read(skip_header))
for n in iter(lambda: f.readinto(mv), 0):
h.update(mv[:n])
if skip_header:
file_hash.update(mv[:n])
except Exception as e:
LoggerRoot.get_base_logger().warning(str(e))
return None, None
return h.hexdigest(), file_hash.hexdigest() if skip_header else None
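
Illustrative use of the staticmethod above (assuming Artifacts is imported from this module): the first digest skips the first skip_header bytes, so a change confined to the header (e.g. a gzip timestamp) does not force a re-upload; the second digest covers the whole file and is None when skip_header is 0; both are None if the file could not be read.

content_sha2, file_sha2 = Artifacts.sha256sum('/tmp/train.csv.gz', skip_header=32)
# content_sha2: hash of everything after the 32-byte header (used for change detection)
# file_sha2:    hash of the entire file (stored on the uploaded artifact)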

View File

@ -109,7 +109,7 @@ class PatchedMatplotlib:
# update api version
from ..backend_api import Session
PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')
# create plotly renderer
try:

View File

@ -572,7 +572,7 @@ class Logger(object):
# if task was not started, we have to start it
self._start_task_if_needed()
upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri.mkdir(parents=True, exist_ok=True)
@ -619,7 +619,7 @@ class Logger(object):
# if task was not started, we have to start it
self._start_task_if_needed()
upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri.mkdir(parents=True, exist_ok=True)
@ -664,7 +664,7 @@ class Logger(object):
# if task was not started, we have to start it
self._start_task_if_needed()
upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri.mkdir(parents=True, exist_ok=True)
@ -705,6 +705,19 @@ class Logger(object):
self._default_upload_destination = uri
def get_default_upload_destination(self):
"""
Get the URI to which all debug images are uploaded.
Images are uploaded separately to the destination storage (e.g. s3, gs, file) and then
a link to the uploaded image is sent in the report
Notice: credentials for the upload destination will be pulled from the
global configuration file (i.e. ~/trains.conf)
:return: URI (str), for example: 's3://bucket/directory/' or 'file:///tmp/debug/'
"""
return self._default_upload_destination or self._task._get_default_report_storage_uri()
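
A short usage sketch for the new getter (bucket name is illustrative; assumes valid credentials in ~/trains.conf):

from trains import Task

task = Task.init('examples', 'upload destination check')
logger = task.get_logger()
logger.set_default_upload_destination('s3://my-bucket/debug_images/')   # optional explicit override
print(logger.get_default_upload_destination())                          # -> 's3://my-bucket/debug_images/'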
def flush(self):
"""
Flush cached reports and console outputs to backend.

View File

@ -6,6 +6,8 @@ from tempfile import mkdtemp, mkstemp
import pyparsing
import six
from .backend_api import Session
from .backend_api.services import models
from pathlib2 import Path
from pyhocon import ConfigFactory, HOCONConverter
@ -276,10 +278,13 @@ class BaseModel(object):
def _set_package_tag(self):
if self._package_tag not in self.tags:
self.tags.append(self._package_tag)
self._get_base_model().update(tags=self.tags)
self._get_base_model().edit(tags=self.tags)
@staticmethod
def _config_dict_to_text(config):
# if already string return as is
if isinstance(config, six.string_types):
return config
if not isinstance(config, dict):
raise ValueError("Model configuration only supports dictionary objects")
try:
@ -372,10 +377,12 @@ class InputModel(BaseModel):
weights_url = StorageHelper.conform_url(weights_url)
if not weights_url:
raise ValueError("Please provide a valid weights_url parameter")
extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
result = _Model._get_default_session().send(models.GetAllRequest(
uri=[weights_url],
only_fields=["id", "name"],
tags=["-" + ARCHIVED_TAG]
**extra
))
if result.response.models:
@ -910,7 +917,7 @@ class OutputModel(BaseModel):
if self.id:
# update the model object (this will happen if we resumed a training task)
result = self._get_force_base_model().update(design=config_text, task_id=self._task.id)
result = self._get_force_base_model().edit(design=config_text)
else:
self._floating_data.design = _Model._wrap_design(config_text)
result = Waitable()
@ -936,7 +943,7 @@ class OutputModel(BaseModel):
if self.id:
# update the model object (this will happen if we resumed a training task)
result = self._get_force_base_model().update(labels=labels, task_id=self._task.id)
result = self._get_force_base_model().edit(labels=labels)
else:
self._floating_data.labels = labels
result = Waitable()

View File

@ -309,7 +309,7 @@ class Task(_Task):
If project is None, and the main execution task is initialized (Task.init), its project will be used.
If project is provided but doesn't exist, it will be created.
:param task_type: Task type to be created. (default: "training")
Optional Task types are: "training" / "testing" / "dataset_import" / "annotation" / "annotation_manual"
:return: Task() object
"""
if not project_name:
@ -390,8 +390,10 @@ class Task(_Task):
task_id=default_task_id,
log_to_backend=True,
)
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
or task.output_model_id or (ARCHIVED_TAG in task_tags)
or (cls._development_tag not in task_tags)):
# If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
# If the task is archived, or already has an output model,
# we shouldn't use it in development mode either
@ -401,7 +403,7 @@ class Task(_Task):
# reset the task, so we can update it
task.reset(set_started_on_success=False, force=False)
# set development tags
task.set_tags([cls._development_tag])
task.set_system_tags([cls._development_tag])
# clear task parameters, they are not cleared by the Task reset
task.set_parameters({}, __update=False)
# clear the comment, it is not cleared on reset
@ -410,6 +412,7 @@ class Task(_Task):
task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
task.set_model_config(config_text='')
task.set_model_label_enumeration({})
task.set_artifacts([])
except (Exception, ValueError):
# we failed reusing task, create a new one
@ -480,7 +483,7 @@ class Task(_Task):
if value and value != self.storage_uri:
from .storage.helper import StorageHelper
helper = StorageHelper.get(value)
helper.check_write_permissions()
helper.check_write_permissions(value)
self.storage_uri = value
@property
@ -658,20 +661,24 @@ class Task(_Task):
"""
self._artifacts_manager.unregister_artifact(name=name)
def upload_artifact(self, name, artifact_object=None, artifact_file=None, metadata=None):
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
"""
Add a static artifact to the Task. The artifact file/object will be uploaded in the background
Raise ValueError if artifact_object is not supported
:param str name: Artifact name. Note: it will override a previous artifact if the name already exists
:param object artifact_object: Artifact object to upload. Currently supports Numpy, PIL.Image.
Numpy will be stored as .npz, and Image as .png file.
Use None if uploading a file directly with 'artifact_file'.
:param str artifact_file: path to artifact file to upload. None means not applicable.
Notice: provide either artifact_object or artifact_file
:param object artifact_object: Artifact object to upload. Currently supports:
- string / pathlib2.Path are treated as path to artifact file to upload
- dict will be stored as .json,
- numpy.ndarray will be stored as .npz,
- PIL.Image will be stored to .png file and uploaded
:param dict metadata: Simple key/value dictionary to store on the artifact
:return: True if artifact is supported
:param bool delete_after_upload: If True, the local artifact will be deleted
(only applies if artifact_object is a local file)
:return: True if artifact will be uploaded
"""
raise ValueError("Not implemented yet")
return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
metadata=metadata, delete_after_upload=delete_after_upload)
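
A short sketch of the new delete_after_upload flag (assuming `task` is the current Task): a locally generated file is queued for upload and removed once the background upload finishes.

import pandas as pd

predictions = pd.DataFrame({'y_pred': [0, 1, 1]})
predictions.to_csv('/tmp/predictions.csv', index=False)
task.upload_artifact('predictions', artifact_object='/tmp/predictions.csv', delete_after_upload=True)
# /tmp/predictions.csv is deleted after the upload completes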
def is_current_task(self):
"""