Add full artifacts support
This commit is contained in:
parent 27ca36687a
commit d7bdc746b8

examples/artifacts_toy.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from time import sleep
+
+import pandas as pd
+import numpy as np
+from PIL import Image
+from trains import Task
+
+task = Task.init('examples', 'artifacts toy')
+
+df = pd.DataFrame({'num_legs': [2, 4, 8, 0],
+                   'num_wings': [2, 0, 0, 0],
+                   'num_specimen_seen': [10, 2, 1, 8]},
+                  index=['falcon', 'dog', 'spider', 'fish'])
+
+# register Pandas object as artifact to watch
+# (it will be monitored in the background and automatically synced and uploaded)
+task.register_artifact('train', df, metadata={'counting': 'legs', 'max legs': 69})
+# change the artifact object
+df.sample(frac=0.5, replace=True, random_state=1)
+# or access it from anywhere using the Task
+Task.current_task().artifacts['train'].sample(frac=0.5, replace=True, random_state=1)
+
+# add and upload local file artifact
+task.upload_artifact('local file', artifact_object='samples/dancing.jpg')
+# add and upload dictionary (stored as JSON)
+task.upload_artifact('dictionary', df.to_dict())
+# add and upload Numpy Object (stored as .npz file)
+task.upload_artifact('Numpy Eye', np.eye(100, 100))
+# add and upload Image (stored as .png file)
+im = Image.open('samples/dancing.jpg')
+task.upload_artifact('pillow_image', im)
+
+# do something
+sleep(1.)
+print(df)
+
+# we are done
+print('Done')
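Aside (not part of the commit itself): the example uses the two new entry points differently. register_artifact() hands the manager a live pandas DataFrame that is watched in the background and re-uploaded whenever its content hash changes, while upload_artifact() is a one-shot upload of a static object or file. A minimal sketch, assuming a reachable trains server and a DataFrame df as in the example:

    from trains import Task

    task = Task.init('examples', 'artifacts toy')

    # watched artifact: snapshotted periodically, re-uploaded only when it changes
    task.register_artifact('train', df, metadata={'counting': 'legs'})

    # static artifacts: serialized once and uploaded in the background
    task.upload_artifact('stats', {'rows': len(df)})          # dict -> .json
    task.upload_artifact('raw file', 'samples/dancing.jpg')   # existing file uploaded as-is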
@@ -186,6 +186,8 @@ class UploadEvent(MetricsEventAdapter):
         self._count = self._get_metric_count(metric, variant)
         if not image_file_history_size:
             image_file_history_size = self._image_file_history_size
-        if image_file_history_size < 1:
-            self._filename = '%s_%s_%08d' % (metric, variant, self._count)
-        else:
+        self._filename = kwargs.pop('override_filename', None)
+        if not self._filename:
+            if image_file_history_size < 1:
+                self._filename = '%s_%s_%08d' % (metric, variant, self._count)
+            else:

@@ -198,6 +200,9 @@ class UploadEvent(MetricsEventAdapter):
             image_format = self._format.lower() if self._image_data is not None else \
                 '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
             self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
+
+        self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)
+
         super(UploadEvent, self).__init__(metric, variant, iter=iter, **kwargs)

     @classmethod

@@ -273,10 +278,12 @@ class UploadEvent(MetricsEventAdapter):
             delete_local_file=local_file if self._delete_after_upload else None,
         )

-    def get_target_full_upload_uri(self, storage_uri, storage_key_prefix):
+    def get_target_full_upload_uri(self, storage_uri, storage_key_prefix=None):
         e_storage_uri = self._upload_uri or storage_uri
         # if we have an entry (with or without a stream), we'll generate the URL and store it in the event
         filename = self._upload_filename
+        if self._override_storage_key_prefix or not storage_key_prefix:
+            storage_key_prefix = self._override_storage_key_prefix
         key = '/'.join(x for x in (storage_key_prefix, self.metric, self.variant, filename.strip('/')) if x)
         url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
         return key, url
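Aside (illustration, not in the commit): get_target_full_upload_uri() builds the storage key as prefix/metric/variant/filename and prepends the storage URI; the artifacts manager passes override_storage_key_prefix so artifact files land under a per-task prefix instead of the default one. With assumed values:

    storage_uri = 's3://my-bucket/trains'                     # assumed upload destination
    storage_key_prefix = 'examples/artifacts toy.4d9d9e26'    # assumed per-task override prefix
    metric, variant, filename = 'artifacts', 'train', 'train.csv.gz'

    key = '/'.join(x for x in (storage_key_prefix, metric, variant, filename.strip('/')) if x)
    url = '/'.join(x.strip('/') for x in (storage_uri, key))
    print(key)   # examples/artifacts toy.4d9d9e26/artifacts/train/train.csv.gz
    print(url)   # s3://my-bucket/trains/examples/artifacts toy.4d9d9e26/artifacts/train/train.csv.gz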
@@ -6,6 +6,7 @@ from tempfile import mkstemp
 import six
 from pathlib2 import Path

+from ..backend_api import Session
 from ..backend_api.services import models
 from .base import IdObjectBase
 from .util import make_message

@@ -185,18 +186,21 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         design = self._wrap_design(design) if design else self.data.design
         name = name or self.data.name
         comment = comment or self.data.comment
-        tags = tags or self.data.tags
         labels = labels or self.data.labels
         task = task_id or self.data.task
         project = project_id or self.data.project
         parent = parent_id or self.data.parent
+        if tags:
+            extra = {'system_tags': tags or self.data.system_tags} \
+                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
+        else:
+            extra = {}

         self.send(models.EditRequest(
             model=self.id,
             uri=uri,
             name=name,
             comment=comment,
-            tags=tags,
             labels=labels,
             design=design,
             task=task,

@@ -204,6 +208,27 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
             parent=parent,
             framework=framework or self.data.framework,
             iteration=iteration,
+            **extra
+        ))
+        self.reload()
+
+    def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
+             uri=None, framework=None, iteration=None):
+        if tags:
+            extra = {'system_tags': tags or self.data.system_tags} \
+                if hasattr(self.data, 'system_tags') else {'tags': tags or self.data.tags}
+        else:
+            extra = {}
+        self.send(models.EditRequest(
+            model=self.id,
+            uri=uri,
+            name=name,
+            comment=comment,
+            labels=labels,
+            design=self._wrap_design(design) if design else None,
+            framework=framework,
+            iteration=iteration,
+            **extra
         ))
         self.reload()
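Aside: the pattern repeated throughout these hunks builds an extra kwargs dict that targets system_tags on servers that support it (api version 2.3 and up, or a data object exposing system_tags) and falls back to the legacy tags field otherwise, then splats it into the request. A standalone sketch with a hypothetical helper name:

    def tags_kwargs(tags, supports_system_tags):
        # hypothetical helper mirroring the version-gated pattern above
        if not tags:
            return {}
        return {'system_tags': tags} if supports_system_tags else {'tags': tags}

    print(tags_kwargs(['development'], True))    # {'system_tags': ['development']}
    print(tags_kwargs(['development'], False))   # {'tags': ['development']}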
@@ -239,7 +264,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
             if cb:
                 cb(model_file)

-            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, cb=callback)
+            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
+                                     cb=callback)
             return uri
         else:
             uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)

@@ -264,12 +290,16 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         if self._data:
             name = name or self.data.name
             comment = comment or self.data.comment
-            tags = tags or self.data.tags
+            tags = tags or (self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags)
             uri = (uri or self.data.uri) if not override_model_id else None

+        if tags:
+            extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
+        else:
+            extra = {}
         res = self.send(
-            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
-                                        override_model_id=override_model_id))
+            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
+                                        override_model_id=override_model_id, **extra))
         if self.id is None:
             # update the model id. in case it was just created, this will trigger a reload of the model object
             self.id = res.response.id

@@ -295,8 +325,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         else:
             uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
             self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
-            _ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
-                                                      override_model_id=override_model_id, iteration=iteration))
+            if tags:
+                extra = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
+            else:
+                extra = {}
+            _ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment,
+                                                      override_model_id=override_model_id, iteration=iteration,
+                                                      **extra))
             return uri

     def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):

@@ -337,7 +372,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):

     @property
     def tags(self):
-        return self.data.tags
+        return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags

     @property
     def locked(self):

@@ -402,18 +437,20 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
         data = self.data
         assert isinstance(data, models.Model)
         parent = self.id if child else None
+        extra = {'system_tags': tags or data.system_tags} \
+            if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
         req = models.CreateRequest(
             uri=data.uri,
             name=name,
             labels=data.labels,
             comment=comment or data.comment,
-            tags=tags or data.tags,
             framework=data.framework,
             design=data.design,
             ready=ready,
             project=data.project,
             parent=parent,
             task=task,
+            **extra
         )
         res = self.send(req)
         return res.response.id
@@ -6,6 +6,7 @@ from enum import Enum
 from threading import RLock, Thread

 import six
+from six.moves.urllib.parse import quote

 from ...backend_interface.task.development.worker import DevWorker
 from ...backend_api import Session

@@ -96,6 +97,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             # this is an existing task, let's try to verify stuff
             self._validate()

+        self._project_name = (self.project, project_name)
+
         if running_remotely() or DevWorker.report_stdout:
             log_to_backend = False
         self._log_to_backend = log_to_backend

@@ -223,14 +226,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         project_id = get_or_create_project(self, project_name, created_msg)

         tags = [self._development_tag] if not running_remotely() else []
+        extra_properties = {'system_tags': tags} if Session.check_min_api_version('2.3') else {'tags': tags}
         req = tasks.CreateRequest(
             name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
             type=tasks.TaskTypeEnum(task_type.value),
             comment=created_msg,
             project=project_id,
             input={'view': {}},
-            tags=tags,
+            **extra_properties
         )
         res = self.send(req)

@@ -369,7 +372,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         return self._reporter

     def _get_output_destination_suffix(self, extra_path=None):
-        return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
+        return '/'.join(quote(x, safe='[]{}()$^,.; -_+-=') for x in
+                        (self.get_project_name(), '%s.%s' % (self.name, self.data.id), extra_path) if x)

     def _reload(self):
         """ Reload the task object from the backend """
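Aside: the output destination suffix changes from a bare 'task_<id>' to '<project name>/<task name>.<task id>', with every component URL-quoted (keeping a small safe set) so names containing spaces or other odd characters still yield a usable storage key. A sketch with assumed names:

    from six.moves.urllib.parse import quote

    project_name, task_name, task_id = 'examples', 'artifacts toy', '4d9d9e26f5c54cfc'
    suffix = '/'.join(quote(x, safe='[]{}()$^,.; -_+-=') for x in
                      (project_name, '%s.%s' % (task_name, task_id), None) if x)
    print(suffix)   # examples/artifacts toy.4d9d9e26f5c54cfc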
@@ -430,6 +434,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         Note that this method only updates the model's metadata using the API and does not upload any data. Use this
         method to update the output model when you have a local model URI (e.g. storing the weights file locally and
         providing a file://path/to/file URI)
+
         :param model_uri: URI for the updated model weights file
         :type model_uri: str
         :param name: Optional updated model name

@@ -448,6 +453,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         Update the task's output model weights file. File is first uploaded to the preconfigured output destination (see
         task's output.destination property or call setup_upload()), than the model object associated with the task is
         updated using an API call with the URI of the uploaded file (and other values provided by additional arguments)
+
         :param model_file: Path to the updated model weights file
         :type model_file: str
         :param name: Optional updated model name
@@ -632,6 +638,21 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         execution.model_labels = enumeration
         self._edit(execution=execution)

+    def set_artifacts(self, artifacts_list=None):
+        """
+        List of artifacts (tasks.Artifact) to update the task
+
+        :param list artifacts_list: list of artifacts (type tasks.Artifact)
+        """
+        if not Session.check_min_api_version('2.3'):
+            return False
+        if not (isinstance(artifacts_list, (list, tuple))
+                and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
+            raise ValueError('Expected artifacts to [tasks.Artifacts]')
+        execution = self.data.execution
+        execution.artifacts = artifacts_list
+        self._edit(execution=execution)
+
     def _set_model_design(self, design=None):
         execution = self.data.execution
         if design is not None:
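Aside: set_artifacts() expects a list of tasks.Artifact entries and stores them on execution.artifacts via _edit(); upload_artifact() in the artifacts manager (further down) is what builds those entries. A sketch of a single entry using the same fields as the hunks below, with assumed values:

    from trains.backend_api.services import tasks   # same service module imported above

    entry = tasks.Artifact(
        key='dictionary', type='JSON',
        uri='s3://my-bucket/examples/artifacts toy.4d9d9e26/artifacts/dictionary/dictionary.json',
        content_size=512, hash='<sha256 of the file>', timestamp=1554300000,
        type_data=tasks.ArtifactTypeData(content_type='application/json'),
        display_data=[('source', 'example')])
    # task.set_artifacts([entry])   # returns False on servers older than api 2.3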
@@ -677,7 +698,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if self.project is None:
             return None

-        if self._project_name and self._project_name[0] == self.project:
+        if self._project_name and self._project_name[1] is not None and self._project_name[0] == self.project:
             return self._project_name[1]

         res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)

@@ -689,8 +710,20 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def get_tags(self):
         return self._get_task_property("tags")

+    def set_system_tags(self, tags):
+        assert isinstance(tags, (list, tuple))
+        if Session.check_min_api_version('2.3'):
+            self._set_task_property("system_tags", tags)
+            self._edit(system_tags=self.data.system_tags)
+        else:
+            self._set_task_property("tags", tags)
+            self._edit(tags=self.data.tags)
+
     def set_tags(self, tags):
         assert isinstance(tags, (list, tuple))
+        if not Session.check_min_api_version('2.3'):
+            # not supported
+            return
         self._set_task_property("tags", tags)
         self._edit(tags=self.data.tags)
@@ -25,7 +25,7 @@ def get_or_create_project(session, project_name, description=None):
     return res.response.id


-def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True):
+def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):
     if not results:
         if not raise_on_error:
             return None

@@ -38,6 +38,13 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True, sort_by_date=True):
     if len(results) > 1:
         log.warn('More than one {entity} found when searching for `{query}`'
                  ' (showing first {show_results} {entity}s follow)'.format(**locals()))
+        if sort_by_date:
+            # sort results based on timestamp and return the newest one
+            if hasattr(results[0], 'last_update'):
+                results = sorted(results, key=lambda x: int(x.last_update.strftime('%s')), reverse=True)
+            elif hasattr(results[0], 'created'):
+                results = sorted(results, key=lambda x: int(x.created.strftime('%s')), reverse=True)
+
         for obj in (o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
             log.warn('Found {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
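Aside: with sort_by_date enabled, multiple matches are now ordered newest first (by last_update, falling back to created) before the caller picks one, so a stale duplicate no longer wins; note that strftime('%s') relies on platform support, as in the hunk above. A small stand-in example:

    from datetime import datetime

    class Result(object):                      # stand-in for an API result object
        def __init__(self, name, last_update):
            self.name, self.last_update = name, last_update

    results = [Result('old copy', datetime(2019, 1, 1)), Result('new copy', datetime(2019, 4, 1))]
    results = sorted(results, key=lambda x: int(x.last_update.strftime('%s')), reverse=True)
    print(results[0].name)   # new copy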
@@ -1,18 +1,31 @@
 import hashlib
+import json
+import logging
+import mimetypes
+import os
 from copy import deepcopy
 from multiprocessing.pool import ThreadPool
-from tempfile import mkdtemp
-from threading import Thread, Event
+from tempfile import mkdtemp, mkstemp
+from threading import Thread, Event, RLock
+from time import time

-import numpy as np
+import six
 from pathlib2 import Path
+from PIL import Image

+from ..backend_interface.metrics.events import UploadEvent
+from ..backend_api import Session
 from ..debugging.log import LoggerRoot
+from ..backend_api.services import tasks

 try:
     import pandas as pd
 except ImportError:
     pd = None
+try:
+    import numpy as np
+except ImportError:
+    np = None


 class Artifacts(object):

@@ -22,6 +35,7 @@ class Artifacts(object):
     _compression = 'gzip'
     # hashing constants
     _hash_block_size = 65536
+    _pd_artifact_type = 'data-audit-table'

     class _ProxyDictWrite(dict):
         """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """

@@ -33,7 +47,7 @@ class Artifacts(object):

         def __setitem__(self, key, value):
             # check that value is of type pandas
-            if isinstance(value, np.ndarray) or (pd and isinstance(value, pd.DataFrame)):
+            if pd and isinstance(value, pd.DataFrame):
                 super(Artifacts._ProxyDictWrite, self).__setitem__(key, value)

                 if self._artifacts_manager:

@@ -72,6 +86,9 @@ class Artifacts(object):
         self._thread_pool = ThreadPool()
         self._summary = ''
         self._temp_folder = []
+        self._task_artifact_list = []
+        self._task_edit_lock = RLock()
+        self._storage_prefix = None

     def register_artifact(self, name, artifact, metadata=None):
         # currently we support pandas.DataFrame (which we will upload as csv.gz)
@@ -86,6 +103,94 @@ class Artifacts(object):
             self._unregister_request.add(name)
         self.flush()

+    def upload_artifact(self, name, artifact_object=None, metadata=None, delete_after_upload=False):
+        if not Session.check_min_api_version('2.3'):
+            LoggerRoot.get_base_logger().warning('Artifacts not supported by your TRAINS-server version, '
+                                                 'please upgrade to the latest server version')
+            return False
+
+        if name in self._artifacts_dict:
+            raise ValueError("Artifact by the name of {} is already registered, use register_artifact".format(name))
+
+        artifact_type_data = tasks.ArtifactTypeData()
+        use_filename_in_uri = True
+        if np and isinstance(artifact_object, np.ndarray):
+            artifact_type = 'numpy'
+            artifact_type_data.content_type = 'application/numpy'
+            artifact_type_data.preview = str(artifact_object.__repr__())
+            fd, local_filename = mkstemp(suffix='.npz')
+            os.close(fd)
+            np.savez_compressed(local_filename, **{name: artifact_object})
+            delete_after_upload = True
+            use_filename_in_uri = False
+        elif isinstance(artifact_object, Image.Image):
+            artifact_type = 'image'
+            artifact_type_data.content_type = 'image/png'
+            desc = str(artifact_object.__repr__())
+            artifact_type_data.preview = desc[1:desc.find(' at ')]
+            fd, local_filename = mkstemp(suffix='.png')
+            os.close(fd)
+            artifact_object.save(local_filename)
+            delete_after_upload = True
+            use_filename_in_uri = False
+        elif isinstance(artifact_object, dict):
+            artifact_type = 'JSON'
+            artifact_type_data.content_type = 'application/json'
+            preview = json.dumps(artifact_object, sort_keys=True, indent=4)
+            fd, local_filename = mkstemp(suffix='.json')
+            os.write(fd, bytes(preview.encode()))
+            os.close(fd)
+            artifact_type_data.preview = preview
+            delete_after_upload = True
+            use_filename_in_uri = False
+        elif isinstance(artifact_object, six.string_types) or isinstance(artifact_object, Path):
+            if isinstance(artifact_object, Path):
+                artifact_object = artifact_object.as_posix()
+            artifact_type = 'custom'
+            artifact_type_data.content_type = mimetypes.guess_type(artifact_object)[0]
+            local_filename = artifact_object
+        else:
+            raise ValueError("Artifact type {} not supported".format(type(artifact_object)))
+
+        # remove from existing list, if exists
+        for artifact in self._task_artifact_list:
+            if artifact.key == name:
+                if artifact.type == self._pd_artifact_type:
+                    raise ValueError("Artifact of name {} already registered, "
+                                     "use register_artifact instead".format(name))
+
+                self._task_artifact_list.remove(artifact)
+                break
+
+        # check that the file to upload exists
+        local_filename = Path(local_filename).absolute()
+        if not local_filename.exists() or not local_filename.is_file():
+            LoggerRoot.get_base_logger().warning('Artifact upload failed, cannot find file {}'.format(
+                local_filename.as_posix()))
+            return False
+
+        file_hash, _ = self.sha256sum(local_filename.as_posix())
+        timestamp = int(time())
+        file_size = local_filename.stat().st_size
+
+        uri = self._upload_local_file(local_filename, name,
+                                      delete_after_upload=delete_after_upload, use_filename=use_filename_in_uri)
+
+        artifact = tasks.Artifact(key=name, type=artifact_type,
+                                  uri=uri,
+                                  content_size=file_size,
+                                  hash=file_hash,
+                                  timestamp=timestamp,
+                                  type_data=artifact_type_data,
+                                  display_data=[(str(k), str(v)) for k, v in metadata.items()] if metadata else None)
+
+        # update task artifacts
+        with self._task_edit_lock:
+            self._task_artifact_list.append(artifact)
+            self._task.set_artifacts(self._task_artifact_list)
+
+        return True
+
     def flush(self):
         # start the thread if it hasn't already:
         self._start()
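Aside: the branches above serialize each supported object type to a temporary file and record a matching ArtifactTypeData; a compact summary of that dispatch (illustration only):

    dispatch = {
        'numpy.ndarray': ('numpy',  'application/numpy',        '.npz via np.savez_compressed'),
        'PIL.Image':     ('image',  'image/png',                '.png via Image.save'),
        'dict':          ('JSON',   'application/json',         '.json via json.dumps'),
        'str / Path':    ('custom', 'mimetypes.guess_type(..)', 'existing file uploaded as-is'),
    }
    for obj, (atype, ctype, stored) in dispatch.items():
        print('%-14s -> type=%-6s content_type=%-26s %s' % (obj, atype, ctype, stored))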
@@ -119,19 +224,20 @@ class Artifacts(object):
         while not self._exit_flag:
             self._flush_event.wait(self._flush_frequency_sec)
             self._flush_event.clear()
-            try:
-                artifact_keys = list(self._artifacts_dict.keys())
-                for name in artifact_keys:
-                    self._upload_artifacts(name)
-            except Exception as e:
-                LoggerRoot.get_base_logger().warning(str(e))
+            artifact_keys = list(self._artifacts_dict.keys())
+            for name in artifact_keys:
+                try:
+                    self._upload_data_audit_artifacts(name)
+                except Exception as e:
+                    LoggerRoot.get_base_logger().warning(str(e))

             # create summary
             self._summary = self._get_statistics()

-    def _upload_artifacts(self, name):
+    def _upload_data_audit_artifacts(self, name):
         logger = self._task.get_logger()
-        artifact = self._artifacts_dict.get(name)
+        pd_artifact = self._artifacts_dict.get(name)
+        pd_metadata = self._artifacts_dict.get_metadata(name)

         # remove from artifacts watch list
         if name in self._unregister_request:

@@ -141,15 +247,15 @@ class Artifacts(object):
                 pass
             self._artifacts_dict.unregister_artifact(name)

-        if artifact is None:
+        if pd_artifact is None:
             return

         local_csv = (Path(self._get_temp_folder()) / (name + self._save_format)).absolute()
         if local_csv.exists():
             # we are still uploading... get another temp folder
             local_csv = (Path(self._get_temp_folder(force_new=True)) / (name + self._save_format)).absolute()
-        artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
-        current_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
+        pd_artifact.to_csv(local_csv.as_posix(), index=False, compression=self._compression)
+        current_sha2, file_sha2 = self.sha256sum(local_csv.as_posix(), skip_header=32)
         if name in self._last_artifacts_upload:
             previous_sha2 = self._last_artifacts_upload[name]
             if previous_sha2 == current_sha2:

@@ -157,19 +263,75 @@ class Artifacts(object):
                 local_csv.unlink()
                 return
         self._last_artifacts_upload[name] = current_sha2
-        # now upload and delete at the end.
+
+        # If old trains-server, upload as debug image
+        if not Session.check_min_api_version('2.3'):
             logger.report_image_and_upload(title='artifacts', series=name, path=local_csv.as_posix(),
                                            delete_after_upload=True, iteration=self._task.get_last_iteration(),
                                            max_image_history=2)
+            return
+
+        # Find our artifact
+        artifact = None
+        for an_artifact in self._task_artifact_list:
+            if an_artifact.key == name:
+                artifact = an_artifact
+                break
+
+        file_size = local_csv.stat().st_size
+
+        # upload file
+        uri = self._upload_local_file(local_csv, name, delete_after_upload=True)
+
+        # update task artifacts
+        with self._task_edit_lock:
+            if not artifact:
+                artifact = tasks.Artifact(key=name, type=self._pd_artifact_type)
+                self._task_artifact_list.append(artifact)
+            artifact_type_data = tasks.ArtifactTypeData()
+
+            artifact_type_data.data_hash = current_sha2
+            artifact_type_data.content_type = "text/csv"
+            artifact_type_data.preview = str(pd_artifact.__repr__())+'\n\n'+self._get_statistics({name: pd_artifact})
+
+            artifact.type_data = artifact_type_data
+            artifact.uri = uri
+            artifact.content_size = file_size
+            artifact.hash = file_sha2
+            artifact.timestamp = int(time())
+            artifact.display_data = [(str(k), str(v)) for k, v in pd_metadata.items()] if pd_metadata else None
+
+            self._task.set_artifacts(self._task_artifact_list)
+
+    def _upload_local_file(self, local_file, name, delete_after_upload=False, use_filename=True):
+        """
+        Upload local file and return uri of the uploaded file (uploading in the background)
+        """
+        upload_uri = self._task.get_logger().get_default_upload_destination()
+        if not isinstance(local_file, Path):
+            local_file = Path(local_file)
+        ev = UploadEvent(metric='artifacts', variant=name,
+                         image_data=None, upload_uri=upload_uri,
+                         local_image_path=local_file.as_posix(),
+                         delete_after_upload=delete_after_upload,
+                         override_filename=os.path.splitext(local_file.name)[0] if use_filename else None,
+                         override_storage_key_prefix=self._get_storage_uri_prefix())
+        _, uri = ev.get_target_full_upload_uri(upload_uri)
+
+        # send for upload
+        self._task.reporter._report(ev)
+
+        return uri
+
-    def _get_statistics(self):
+    def _get_statistics(self, artifacts_dict=None):
         summary = ''
+        artifacts_dict = artifacts_dict or self._artifacts_dict
         thread_pool = ThreadPool()

         try:
             # build hash row sets
             artifacts_summary = []
-            for a_name, a_df in self._artifacts_dict.items():
+            for a_name, a_df in artifacts_dict.items():
                 if not pd or not isinstance(a_df, pd.DataFrame):
                     continue

@@ -206,16 +368,30 @@ class Artifacts(object):
             return new_temp
         return self._temp_folder[0]

+    def _get_storage_uri_prefix(self):
+        if not self._storage_prefix:
+            self._storage_prefix = self._task._get_output_destination_suffix()
+        return self._storage_prefix
+
     @staticmethod
     def sha256sum(filename, skip_header=0):
         # create sha2 of the file, notice we skip the header of the file (32 bytes)
         # because sometimes that is the only change
         h = hashlib.sha256()
+        file_hash = hashlib.sha256()
         b = bytearray(Artifacts._hash_block_size)
         mv = memoryview(b)
+        try:
             with open(filename, 'rb', buffering=0) as f:
                 # skip header
-                f.read(skip_header)
+                if skip_header:
+                    file_hash.update(f.read(skip_header))
                 for n in iter(lambda: f.readinto(mv), 0):
                     h.update(mv[:n])
-        return h.hexdigest()
+                    if skip_header:
+                        file_hash.update(mv[:n])
+        except Exception as e:
+            LoggerRoot.get_base_logger().warning(str(e))
+            return None, None
+
+        return h.hexdigest(), file_hash.hexdigest() if skip_header else None
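Aside: sha256sum() now returns a pair: the first digest skips the first skip_header bytes (for the csv.gz snapshots, where the compressed header may be the only thing that changed between identical snapshots) and is used for change detection, while the second covers the whole file and is stored as the artifact hash; on error it returns (None, None). An equivalent sketch of the two digests:

    import hashlib

    def two_digests(filename, skip_header=32):
        content_hash, file_hash = hashlib.sha256(), hashlib.sha256()
        with open(filename, 'rb') as f:
            header = f.read(skip_header)
            body = f.read()
        file_hash.update(header)
        file_hash.update(body)        # digest of the whole file (artifact hash)
        content_hash.update(body)     # digest without the header (change detection)
        return content_hash.hexdigest(), file_hash.hexdigest()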
@@ -109,7 +109,7 @@ class PatchedMatplotlib:

         # update api version
         from ..backend_api import Session
-        PatchedMatplotlib._support_image_plot = Session.api_version > '2.1'
+        PatchedMatplotlib._support_image_plot = Session.check_min_api_version('2.2')

         # create plotly renderer
         try:

@@ -572,7 +572,7 @@ class Logger(object):

         # if task was not started, we have to start it
         self._start_task_if_needed()
-        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
+        upload_uri = self.get_default_upload_destination()
         if not upload_uri:
             upload_uri = Path(get_cache_dir()) / 'debug_images'
             upload_uri.mkdir(parents=True, exist_ok=True)

@@ -619,7 +619,7 @@ class Logger(object):

         # if task was not started, we have to start it
         self._start_task_if_needed()
-        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
+        upload_uri = self.get_default_upload_destination()
         if not upload_uri:
             upload_uri = Path(get_cache_dir()) / 'debug_images'
             upload_uri.mkdir(parents=True, exist_ok=True)

@@ -664,7 +664,7 @@ class Logger(object):

         # if task was not started, we have to start it
         self._start_task_if_needed()
-        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
+        upload_uri = self.get_default_upload_destination()
        if not upload_uri:
             upload_uri = Path(get_cache_dir()) / 'debug_images'
             upload_uri.mkdir(parents=True, exist_ok=True)

@@ -705,6 +705,19 @@ class Logger(object):

         self._default_upload_destination = uri

+    def get_default_upload_destination(self):
+        """
+        Get the uri to upload all the debug images to.
+
+        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
+        a link to the uploaded image is sent in the report
+        Notice: credentials for the upload destination will be pooled from the
+        global configuration file (i.e. ~/trains.conf)
+
+        :return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
+        """
+        return self._default_upload_destination or self._task._get_default_report_storage_uri()
+
     def flush(self):
         """
         Flush cached reports and console outputs to backend.
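Aside: the three report-and-upload paths now share the public getter added above, and the artifacts manager uses it to pick the upload bucket. A usage sketch, assuming the matching setter that stores _default_upload_destination (shown in the surrounding context):

    logger = task.get_logger()
    logger.set_default_upload_destination('s3://my-bucket/debug-images/')   # assumed setter name
    print(logger.get_default_upload_destination())   # s3://my-bucket/debug-images/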
@@ -6,6 +6,8 @@ from tempfile import mkdtemp, mkstemp

 import pyparsing
 import six
+
+from .backend_api import Session
 from .backend_api.services import models
 from pathlib2 import Path
 from pyhocon import ConfigFactory, HOCONConverter

@@ -276,10 +278,13 @@ class BaseModel(object):
     def _set_package_tag(self):
         if self._package_tag not in self.tags:
             self.tags.append(self._package_tag)
-            self._get_base_model().update(tags=self.tags)
+            self._get_base_model().edit(tags=self.tags)

     @staticmethod
     def _config_dict_to_text(config):
+        # if already string return as is
+        if isinstance(config, six.string_types):
+            return config
         if not isinstance(config, dict):
             raise ValueError("Model configuration only supports dictionary objects")
         try:

@@ -372,10 +377,12 @@ class InputModel(BaseModel):
         weights_url = StorageHelper.conform_url(weights_url)
         if not weights_url:
             raise ValueError("Please provide a valid weights_url parameter")
+        extra = {'system_tags': ["-" + ARCHIVED_TAG]} \
+            if Session.check_min_api_version('2.3') else {'tags': ["-" + ARCHIVED_TAG]}
         result = _Model._get_default_session().send(models.GetAllRequest(
             uri=[weights_url],
             only_fields=["id", "name"],
-            tags=["-" + ARCHIVED_TAG]
+            **extra
         ))

         if result.response.models:

@@ -910,7 +917,7 @@ class OutputModel(BaseModel):

         if self.id:
             # update the model object (this will happen if we resumed a training task)
-            result = self._get_force_base_model().update(design=config_text, task_id=self._task.id)
+            result = self._get_force_base_model().edit(design=config_text)
         else:
             self._floating_data.design = _Model._wrap_design(config_text)
             result = Waitable()

@@ -936,7 +943,7 @@ class OutputModel(BaseModel):

         if self.id:
             # update the model object (this will happen if we resumed a training task)
-            result = self._get_force_base_model().update(labels=labels, task_id=self._task.id)
+            result = self._get_force_base_model().edit(labels=labels)
         else:
             self._floating_data.labels = labels
             result = Waitable()
@@ -390,8 +390,10 @@ class Task(_Task):
                     task_id=default_task_id,
                     log_to_backend=True,
                 )
+                task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
                 if ((str(task.status) in (str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
-                        or (ARCHIVED_TAG in task.data.tags) or task.output_model_id):
+                        or task.output_model_id or (ARCHIVED_TAG in task_tags)
+                        or (cls._development_tag not in task_tags)):
                     # If the task is published or closed, we shouldn't reset it so we can't use it in dev mode
                     # If the task is archived, or already has an output model,
                     # we shouldn't use it in development mode either

@@ -401,7 +403,7 @@ class Task(_Task):
                     # reset the task, so we can update it
                     task.reset(set_started_on_success=False, force=False)
                     # set development tags
-                    task.set_tags([cls._development_tag])
+                    task.set_system_tags([cls._development_tag])
                     # clear task parameters, they are not cleared by the Task reset
                     task.set_parameters({}, __update=False)
                     # clear the comment, it is not cleared on reset

@@ -410,6 +412,7 @@ class Task(_Task):
                     task.set_input_model(model_id='', update_task_design=False, update_task_labels=False)
                     task.set_model_config(config_text='')
                     task.set_model_label_enumeration({})
+                    task.set_artifacts([])

             except (Exception, ValueError):
                 # we failed reusing task, create a new one

@@ -480,7 +483,7 @@ class Task(_Task):
         if value and value != self.storage_uri:
             from .storage.helper import StorageHelper
             helper = StorageHelper.get(value)
-            helper.check_write_permissions()
+            helper.check_write_permissions(value)
         self.storage_uri = value

     @property

@@ -658,20 +661,24 @@ class Task(_Task):
         """
         self._artifacts_manager.unregister_artifact(name=name)

-    def upload_artifact(self, name, artifact_object=None, artifact_file=None, metadata=None):
+    def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
         """
         Add static artifact to Task. Artifact file/object will be uploaded in the background
+        Raise ValueError if artifact_object is not supported

         :param str name: Artifact name. Notice! it will override previous artifact if name already exists
-        :param object artifact_object: Artifact object to upload. Currently supports Numpy, PIL.Image.
-            Numpy will be stored as .npz, and Image as .png file.
-            Use None if uploading a file directly with 'artifact_file'.
-        :param str artifact_file: path to artifact file to upload. None means not applicable.
-            Notice you wither artifact object or artifact_file
+        :param object artifact_object: Artifact object to upload. Currently supports:
+            - string / pathlib2.Path are treated as path to artifact file to upload
+            - dict will be stored as .json,
+            - numpy.ndarray will be stored as .npz,
+            - PIL.Image will be stored to .png file and uploaded
         :param dict metadata: Simple key/value dictionary to store on the artifact
-        :return: True if artifact is supported
+        :param bool delete_after_upload: If True local artifact will be deleted
+            (only applies if artifact_object is a local file)
+        :return: True if artifact will be uploaded
         """
-        raise ValueError("Not implemented yet")
+        return self._artifacts_manager.upload_artifact(name=name, artifact_object=artifact_object,
+                                                       metadata=metadata, delete_after_upload=delete_after_upload)

     def is_current_task(self):
         """
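Aside: with the public Task.upload_artifact() now wired to the artifacts manager, the call patterns from the updated docstring look like this (values assumed):

    import numpy as np

    task.upload_artifact('embeddings', np.random.randn(100, 16))            # stored as .npz
    task.upload_artifact('config', {'lr': 0.01, 'batch': 32})               # stored as .json
    task.upload_artifact('report', 'outputs/report.html',
                         metadata={'stage': 'final'}, delete_after_upload=True)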