mirror of
https://github.com/clearml/clearml
synced 2025-04-16 21:42:10 +00:00
Refactor code
This commit is contained in:
parent
ecf6a4df2a
commit
4cd8857c0d
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,8 +11,8 @@ build/
|
|||||||
dist/
|
dist/
|
||||||
*.egg-info
|
*.egg-info
|
||||||
.env
|
.env
|
||||||
venv/
|
|
||||||
.venv/
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
# example data
|
# example data
|
||||||
examples/runs/
|
examples/runs/
|
||||||
|
@ -218,7 +218,6 @@ class PipelineController(object):
|
|||||||
def serialize(obj):
|
def serialize(obj):
|
||||||
import dill
|
import dill
|
||||||
return dill.dumps(obj)
|
return dill.dumps(obj)
|
||||||
|
|
||||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||||
which represents the serialized object. This function should return the deserialized object.
|
which represents the serialized object. This function should return the deserialized object.
|
||||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||||
@ -1157,6 +1156,7 @@ class PipelineController(object):
|
|||||||
# type: (bool, str) -> bool
|
# type: (bool, str) -> bool
|
||||||
"""
|
"""
|
||||||
Evaluate whether or not the pipeline is successful
|
Evaluate whether or not the pipeline is successful
|
||||||
|
|
||||||
:param fail_on_step_fail: If True (default), evaluate the pipeline steps' status to assess if the pipeline
|
:param fail_on_step_fail: If True (default), evaluate the pipeline steps' status to assess if the pipeline
|
||||||
is successful. If False, only evaluate the controller
|
is successful. If False, only evaluate the controller
|
||||||
:param fail_condition: Must be one of the following: 'all' (default), 'failed' or 'aborted'. If 'failed', this
|
:param fail_condition: Must be one of the following: 'all' (default), 'failed' or 'aborted'. If 'failed', this
|
||||||
@ -1175,18 +1175,14 @@ class PipelineController(object):
|
|||||||
success_status = [Task.TaskStatusEnum.completed, Task.TaskStatusEnum.failed]
|
success_status = [Task.TaskStatusEnum.completed, Task.TaskStatusEnum.failed]
|
||||||
else:
|
else:
|
||||||
raise UsageError("fail_condition needs to be one of the following: 'all', 'failed', 'aborted'")
|
raise UsageError("fail_condition needs to be one of the following: 'all', 'failed', 'aborted'")
|
||||||
|
|
||||||
if self._task.status not in success_status:
|
if self._task.status not in success_status:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not fail_on_step_fail:
|
if not fail_on_step_fail:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self._update_nodes_status()
|
self._update_nodes_status()
|
||||||
for node in self._nodes.values():
|
for node in self._nodes.values():
|
||||||
if node.status not in success_status:
|
if node.status not in success_status:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def elapsed(self):
|
def elapsed(self):
|
||||||
|
@ -139,7 +139,7 @@ class TaskTrigger(BaseTrigger):
|
|||||||
raise ValueError("You must provide metric/variant/threshold")
|
raise ValueError("You must provide metric/variant/threshold")
|
||||||
valid_status = [str(s) for s in Task.TaskStatusEnum]
|
valid_status = [str(s) for s in Task.TaskStatusEnum]
|
||||||
if self.on_status and not all(s in valid_status for s in self.on_status):
|
if self.on_status and not all(s in valid_status for s in self.on_status):
|
||||||
raise ValueError("You on_status contains invalid status value: {}".format(self.on_status))
|
raise ValueError("Your on_status contains invalid status value: {}".format(self.on_status))
|
||||||
valid_signs = ['min', 'minimum', 'max', 'maximum']
|
valid_signs = ['min', 'minimum', 'max', 'maximum']
|
||||||
if self.value_sign and self.value_sign not in valid_signs:
|
if self.value_sign and self.value_sign not in valid_signs:
|
||||||
raise ValueError("Invalid value_sign `{}`, valid options are: {}".format(self.value_sign, valid_signs))
|
raise ValueError("Invalid value_sign `{}`, valid options are: {}".format(self.value_sign, valid_signs))
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
auth service
|
auth service
|
||||||
|
|
||||||
This service provides authentication management and authorization
|
This service provides authentication management and authorization
|
||||||
|
|
||||||
validation for the entire system.
|
validation for the entire system.
|
||||||
"""
|
"""
|
||||||
import six
|
import six
|
||||||
|
@ -40,8 +40,10 @@ for a very long time for a non-responding or mis-configured server
|
|||||||
"""
|
"""
|
||||||
ENV_API_EXTRA_RETRY_CODES = EnvEntry("CLEARML_API_EXTRA_RETRY_CODES")
|
ENV_API_EXTRA_RETRY_CODES = EnvEntry("CLEARML_API_EXTRA_RETRY_CODES")
|
||||||
|
|
||||||
|
|
||||||
ENV_FORCE_MAX_API_VERSION = EnvEntry("CLEARML_FORCE_MAX_API_VERSION", type=str)
|
ENV_FORCE_MAX_API_VERSION = EnvEntry("CLEARML_FORCE_MAX_API_VERSION", type=str)
|
||||||
|
|
||||||
|
|
||||||
class MissingConfigError(ValueError):
|
class MissingConfigError(ValueError):
|
||||||
def __init__(self, message=None):
|
def __init__(self, message=None):
|
||||||
if message is None:
|
if message is None:
|
||||||
|
@ -94,7 +94,6 @@ class CompoundRequest(Request):
|
|||||||
if self._item_prop_name in dict_properties:
|
if self._item_prop_name in dict_properties:
|
||||||
del dict_properties[self._item_prop_name]
|
del dict_properties[self._item_prop_name]
|
||||||
dict_.update(dict_properties)
|
dict_.update(dict_properties)
|
||||||
|
|
||||||
return dict_
|
return dict_
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
@ -134,6 +134,7 @@ class Session(TokenManager):
|
|||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
self.__class__._sessions_weakrefs.append(weakref.ref(self))
|
self.__class__._sessions_weakrefs.append(weakref.ref(self))
|
||||||
|
|
||||||
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
|
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
if self._verbose and not self._logger:
|
if self._verbose and not self._logger:
|
||||||
@ -147,7 +148,6 @@ class Session(TokenManager):
|
|||||||
self.__init_host = host
|
self.__init_host = host
|
||||||
self.__init_http_retries_config = http_retries_config
|
self.__init_http_retries_config = http_retries_config
|
||||||
self.__token_manager_kwargs = kwargs
|
self.__token_manager_kwargs = kwargs
|
||||||
|
|
||||||
if config is not None:
|
if config is not None:
|
||||||
self.config = config
|
self.config = config
|
||||||
else:
|
else:
|
||||||
@ -162,21 +162,21 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
self._ssl_error_count_verbosity = self.config.get(
|
self._ssl_error_count_verbosity = self.config.get(
|
||||||
"api.ssl_error_count_verbosity", self._ssl_error_count_verbosity)
|
"api.ssl_error_count_verbosity", self._ssl_error_count_verbosity)
|
||||||
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
|
|
||||||
|
|
||||||
|
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
|
||||||
if not self.__host:
|
if not self.__host:
|
||||||
raise ValueError("ClearML host was not set, check your configuration file or environment variable")
|
raise ValueError("ClearML host was not set, check your configuration file or environment variable")
|
||||||
|
|
||||||
self.__host = self.__host.strip("/")
|
self.__host = self.__host.strip("/")
|
||||||
self.__http_retries_config = self.__init_http_retries_config or self.config.get(
|
self.__http_retries_config = self.__init_http_retries_config or self.config.get(
|
||||||
"api.http.retries", ConfigTree()).as_plain_ordered_dict()
|
"api.http.retries", ConfigTree()).as_plain_ordered_dict()
|
||||||
|
|
||||||
self.__http_retries_config["status_forcelist"] = self._get_retry_codes()
|
self.__http_retries_config["status_forcelist"] = self._get_retry_codes()
|
||||||
self.__http_retries_config["config"] = self.config
|
self.__http_retries_config["config"] = self.config
|
||||||
self.__http_session = get_http_session_with_retry(**self.__http_retries_config)
|
self.__http_session = get_http_session_with_retry(**self.__http_retries_config)
|
||||||
self.__http_session.write_timeout = self._write_session_timeout
|
self.__http_session.write_timeout = self._write_session_timeout
|
||||||
self.__http_session.request_size_threshold = self._write_session_data_size
|
self.__http_session.request_size_threshold = self._write_session_data_size
|
||||||
self.__max_req_size = self.config.get("api.http.max_req_size", None)
|
|
||||||
|
|
||||||
|
self.__max_req_size = self.config.get("api.http.max_req_size", None)
|
||||||
if not self.__max_req_size:
|
if not self.__max_req_size:
|
||||||
raise ValueError("missing max request size")
|
raise ValueError("missing max request size")
|
||||||
|
|
||||||
@ -186,7 +186,6 @@ class Session(TokenManager):
|
|||||||
req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None)
|
req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None)
|
||||||
self.__auth_token = None
|
self.__auth_token = None
|
||||||
self._update_default_api_method()
|
self._update_default_api_method()
|
||||||
|
|
||||||
if ENV_AUTH_TOKEN.get():
|
if ENV_AUTH_TOKEN.get():
|
||||||
self.__access_key = self.__secret_key = None
|
self.__access_key = self.__secret_key = None
|
||||||
self.__auth_token = ENV_AUTH_TOKEN.get()
|
self.__auth_token = ENV_AUTH_TOKEN.get()
|
||||||
@ -203,9 +202,11 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
if not self.secret_key and not self.access_key and not self.__auth_token:
|
if not self.secret_key and not self.access_key and not self.__auth_token:
|
||||||
raise MissingConfigError()
|
raise MissingConfigError()
|
||||||
|
|
||||||
super(Session, self).__init__(
|
super(Session, self).__init__(
|
||||||
**self.__token_manager_kwargs,
|
**self.__token_manager_kwargs,
|
||||||
req_token_expiration_sec=req_token_expiration_sec,
|
token_expiration_threshold_sec=token_expiration_threshold_sec,
|
||||||
|
req_token_expiration_sec=req_token_expiration_sec
|
||||||
)
|
)
|
||||||
self.refresh_token()
|
self.refresh_token()
|
||||||
|
|
||||||
@ -633,6 +634,7 @@ class Session(TokenManager):
|
|||||||
|
|
||||||
return call_result
|
return call_result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
def _make_all_sessions_go_online(cls):
|
def _make_all_sessions_go_online(cls):
|
||||||
for active_session in cls._get_all_active_sessions():
|
for active_session in cls._get_all_active_sessions():
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
@ -647,7 +649,6 @@ class Session(TokenManager):
|
|||||||
if session:
|
if session:
|
||||||
active_sessions.append(session)
|
active_sessions.append(session)
|
||||||
new_sessions_weakrefs.append(session_weakref)
|
new_sessions_weakrefs.append(session_weakref)
|
||||||
|
|
||||||
cls._sessions_weakrefs = session_weakref
|
cls._sessions_weakrefs = session_weakref
|
||||||
return active_sessions
|
return active_sessions
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from time import time
|
|||||||
|
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from ...backend_api import Session
|
||||||
from ...backend_api.services import events as api_events
|
from ...backend_api.services import events as api_events
|
||||||
from ..base import InterfaceBase
|
from ..base import InterfaceBase
|
||||||
from ...config import config, deferred_config
|
from ...config import config, deferred_config
|
||||||
|
@ -271,9 +271,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
self._for_model = for_model
|
self._for_model = for_model
|
||||||
flush_threshold = config.get("development.worker.report_event_flush_threshold", 100)
|
flush_threshold = config.get("development.worker.report_event_flush_threshold", 100)
|
||||||
self._report_service = BackgroundReportService(
|
self._report_service = BackgroundReportService(
|
||||||
task=task, async_enable=async_enable, metrics=metrics,
|
task=task,
|
||||||
|
async_enable=async_enable,
|
||||||
|
metrics=metrics,
|
||||||
flush_frequency=self._flush_frequency,
|
flush_frequency=self._flush_frequency,
|
||||||
flush_threshold=flush_threshold, for_model=for_model)
|
flush_threshold=flush_threshold,
|
||||||
|
for_model=for_model,
|
||||||
|
)
|
||||||
self._report_service.start()
|
self._report_service.start()
|
||||||
|
|
||||||
def _set_storage_uri(self, value):
|
def _set_storage_uri(self, value):
|
||||||
@ -355,8 +359,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
:param iter: Iteration number
|
:param iter: Iteration number
|
||||||
:type iter: int
|
:type iter: int
|
||||||
"""
|
"""
|
||||||
ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value,
|
ev = ScalarEvent(
|
||||||
iter=iter)
|
metric=self._normalize_name(title),
|
||||||
|
variant=self._normalize_name(series),
|
||||||
|
value=value,
|
||||||
|
iter=iter
|
||||||
|
)
|
||||||
self._report(ev)
|
self._report(ev)
|
||||||
|
|
||||||
def report_vector(self, title, series, values, iter):
|
def report_vector(self, title, series, values, iter):
|
||||||
@ -457,8 +465,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
|||||||
elif not isinstance(plot, six.string_types):
|
elif not isinstance(plot, six.string_types):
|
||||||
raise ValueError('Plot should be a string or a dict')
|
raise ValueError('Plot should be a string or a dict')
|
||||||
|
|
||||||
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series),
|
ev = PlotEvent(
|
||||||
plot_str=plot, iter=iter)
|
metric=self._normalize_name(title),
|
||||||
|
variant=self._normalize_name(series),
|
||||||
|
plot_str=plot,
|
||||||
|
iter=iter
|
||||||
|
)
|
||||||
self._report(ev)
|
self._report(ev)
|
||||||
|
|
||||||
def report_image(self, title, series, src, iter):
|
def report_image(self, title, series, src, iter):
|
||||||
|
@ -12,7 +12,7 @@ from ..storage import StorageManager
|
|||||||
from ..storage.helper import StorageHelper
|
from ..storage.helper import StorageHelper
|
||||||
from ..utilities.async_manager import AsyncManagerMixin
|
from ..utilities.async_manager import AsyncManagerMixin
|
||||||
|
|
||||||
ModelPackage = namedtuple('ModelPackage', 'weights design')
|
ModelPackage = namedtuple("ModelPackage", "weights design")
|
||||||
|
|
||||||
|
|
||||||
class ModelDoesNotExistError(Exception):
|
class ModelDoesNotExistError(Exception):
|
||||||
@ -22,12 +22,12 @@ class ModelDoesNotExistError(Exception):
|
|||||||
class _StorageUriMixin(object):
|
class _StorageUriMixin(object):
|
||||||
@property
|
@property
|
||||||
def upload_storage_uri(self):
|
def upload_storage_uri(self):
|
||||||
""" A URI into which models are uploaded """
|
"""A URI into which models are uploaded"""
|
||||||
return self._upload_storage_uri
|
return self._upload_storage_uri
|
||||||
|
|
||||||
@upload_storage_uri.setter
|
@upload_storage_uri.setter
|
||||||
def upload_storage_uri(self, value):
|
def upload_storage_uri(self, value):
|
||||||
self._upload_storage_uri = value.rstrip('/') if value else None
|
self._upload_storage_uri = value.rstrip("/") if value else None
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
|
def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
|
||||||
@ -44,9 +44,9 @@ def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||||
""" Manager for backend model objects """
|
"""Manager for backend model objects"""
|
||||||
|
|
||||||
_EMPTY_MODEL_ID = 'empty'
|
_EMPTY_MODEL_ID = "empty"
|
||||||
|
|
||||||
_local_model_to_id_uri = {}
|
_local_model_to_id_uri = {}
|
||||||
|
|
||||||
@ -54,8 +54,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
def model_id(self):
|
def model_id(self):
|
||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
def __init__(self, upload_storage_uri, cache_dir, model_id=None,
|
def __init__(
|
||||||
upload_storage_suffix='models', session=None, log=None):
|
self,
|
||||||
|
upload_storage_uri,
|
||||||
|
cache_dir,
|
||||||
|
model_id=None,
|
||||||
|
upload_storage_suffix="models",
|
||||||
|
session=None,
|
||||||
|
log=None
|
||||||
|
):
|
||||||
super(Model, self).__init__(id=model_id, session=session, log=log)
|
super(Model, self).__init__(id=model_id, session=session, log=log)
|
||||||
self._upload_storage_suffix = upload_storage_suffix
|
self._upload_storage_suffix = upload_storage_suffix
|
||||||
if model_id == self._EMPTY_MODEL_ID:
|
if model_id == self._EMPTY_MODEL_ID:
|
||||||
@ -71,7 +78,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
def _reload(self):
|
def _reload(self):
|
||||||
""" Reload the model object """
|
"""Reload the model object"""
|
||||||
if self._offline_mode:
|
if self._offline_mode:
|
||||||
return models.Model()
|
return models.Model()
|
||||||
|
|
||||||
@ -80,11 +87,19 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
res = self.send(models.GetByIdRequest(model=self.id))
|
res = self.send(models.GetByIdRequest(model=self.id))
|
||||||
return res.response.model
|
return res.response.model
|
||||||
|
|
||||||
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
|
def _upload_model(
|
||||||
|
self, model_file, async_enable=False, target_filename=None, cb=None
|
||||||
|
):
|
||||||
if not self.upload_storage_uri:
|
if not self.upload_storage_uri:
|
||||||
raise ValueError('Model has no storage URI defined (nowhere to upload to)')
|
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
|
||||||
target_filename = target_filename or Path(model_file).name
|
target_filename = target_filename or Path(model_file).name
|
||||||
dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
|
dest_path = "/".join(
|
||||||
|
(
|
||||||
|
self.upload_storage_uri,
|
||||||
|
self._upload_storage_suffix or ".",
|
||||||
|
target_filename,
|
||||||
|
)
|
||||||
|
)
|
||||||
result = StorageHelper.get(dest_path).upload(
|
result = StorageHelper.get(dest_path).upload(
|
||||||
src_path=model_file,
|
src_path=model_file,
|
||||||
dest_path=dest_path,
|
dest_path=dest_path,
|
||||||
@ -93,19 +108,23 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return_canonized=False
|
return_canonized=False
|
||||||
)
|
)
|
||||||
if async_enable:
|
if async_enable:
|
||||||
|
|
||||||
def msg(num_results):
|
def msg(num_results):
|
||||||
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
|
self.log.info(
|
||||||
|
"Waiting for previous model to upload (%d pending, %s)"
|
||||||
|
% (num_results, dest_path)
|
||||||
|
)
|
||||||
|
|
||||||
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
||||||
return dest_path
|
return dest_path
|
||||||
|
|
||||||
def _upload_callback(self, res, cb=None):
|
def _upload_callback(self, res, cb=None):
|
||||||
if res is None:
|
if res is None:
|
||||||
self.log.debug('Starting model upload')
|
self.log.debug("Starting model upload")
|
||||||
elif res is False:
|
elif res is False:
|
||||||
self.log.info('Failed model upload')
|
self.log.info("Failed model upload")
|
||||||
else:
|
else:
|
||||||
self.log.info('Completed model upload to {}'.format(res))
|
self.log.info("Completed model upload to {}".format(res))
|
||||||
if cb:
|
if cb:
|
||||||
cb(res)
|
cb(res)
|
||||||
|
|
||||||
@ -126,12 +145,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
:return: A proper design dictionary according to design parameter.
|
:return: A proper design dictionary according to design parameter.
|
||||||
"""
|
"""
|
||||||
if isinstance(design, dict):
|
if isinstance(design, dict):
|
||||||
if 'design' not in design:
|
if "design" not in design:
|
||||||
raise ValueError('design dictionary must have \'design\' key in it')
|
raise ValueError("design dictionary must have 'design' key in it")
|
||||||
|
|
||||||
return design
|
return design
|
||||||
|
|
||||||
return {'design': design if design else ''}
|
return {"design": design if design else ""}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _unwrap_design(design):
|
def _unwrap_design(design):
|
||||||
@ -153,23 +172,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
:return: The design string according to design parameter.
|
:return: The design string according to design parameter.
|
||||||
"""
|
"""
|
||||||
if not design:
|
if not design:
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
if isinstance(design, six.string_types):
|
if isinstance(design, six.string_types):
|
||||||
return design
|
return design
|
||||||
|
|
||||||
if isinstance(design, dict):
|
if isinstance(design, dict):
|
||||||
if 'design' in design:
|
if "design" in design:
|
||||||
return design['design']
|
return design["design"]
|
||||||
|
|
||||||
return list(design.values())[0]
|
return list(design.values())[0]
|
||||||
|
|
||||||
raise ValueError('design must be a string or a dictionary with at least one value')
|
raise ValueError(
|
||||||
|
"design must be a string or a dictionary with at least one value"
|
||||||
|
)
|
||||||
|
|
||||||
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
|
def update(
|
||||||
task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
|
self,
|
||||||
upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None):
|
model_file=None,
|
||||||
""" Update model weights file and various model properties """
|
design=None,
|
||||||
|
labels=None,
|
||||||
|
name=None,
|
||||||
|
comment=None,
|
||||||
|
tags=None,
|
||||||
|
task_id=None,
|
||||||
|
project_id=None,
|
||||||
|
parent_id=None,
|
||||||
|
uri=None,
|
||||||
|
framework=None,
|
||||||
|
upload_storage_uri=None,
|
||||||
|
target_filename=None,
|
||||||
|
iteration=None,
|
||||||
|
system_tags=None
|
||||||
|
):
|
||||||
|
"""Update model weights file and various model properties"""
|
||||||
|
|
||||||
if self.id is None:
|
if self.id is None:
|
||||||
if upload_storage_uri:
|
if upload_storage_uri:
|
||||||
@ -182,7 +218,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||||
|
|
||||||
# upload model file if needed and get uri
|
# upload model file if needed and get uri
|
||||||
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
|
uri = uri or (
|
||||||
|
self._upload_model(model_file, target_filename=target_filename)
|
||||||
|
if model_file
|
||||||
|
else self.data.uri
|
||||||
|
)
|
||||||
# update fields
|
# update fields
|
||||||
design = self._wrap_design(design) if design else self.data.design
|
design = self._wrap_design(design) if design else self.data.design
|
||||||
name = name or self.data.name
|
name = name or self.data.name
|
||||||
@ -192,7 +232,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
project = project_id or self.data.project
|
project = project_id or self.data.project
|
||||||
parent = parent_id or self.data.parent
|
parent = parent_id or self.data.parent
|
||||||
tags = tags or self.data.tags
|
tags = tags or self.data.tags
|
||||||
if Session.check_min_api_version('2.3'):
|
if Session.check_min_api_version("2.3"):
|
||||||
system_tags = system_tags or self.data.system_tags
|
system_tags = system_tags or self.data.system_tags
|
||||||
|
|
||||||
self._edit(
|
self._edit(
|
||||||
@ -210,33 +250,74 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
system_tags=system_tags,
|
system_tags=system_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
def edit(
|
||||||
uri=None, framework=None, iteration=None, system_tags=None):
|
self,
|
||||||
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
design=None,
|
||||||
uri=uri, framework=framework, iteration=iteration, system_tags=system_tags)
|
labels=None,
|
||||||
|
name=None,
|
||||||
|
comment=None,
|
||||||
|
tags=None,
|
||||||
|
uri=None,
|
||||||
|
framework=None,
|
||||||
|
iteration=None,
|
||||||
|
system_tags=None
|
||||||
|
):
|
||||||
|
return self._edit(
|
||||||
|
design=design,
|
||||||
|
labels=labels,
|
||||||
|
name=name,
|
||||||
|
comment=comment,
|
||||||
|
tags=tags,
|
||||||
|
uri=uri,
|
||||||
|
framework=framework,
|
||||||
|
iteration=iteration,
|
||||||
|
system_tags=system_tags,
|
||||||
|
)
|
||||||
|
|
||||||
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
def _edit(
|
||||||
uri=None, framework=None, iteration=None, system_tags=None, **extra):
|
self,
|
||||||
|
design=None,
|
||||||
|
labels=None,
|
||||||
|
name=None,
|
||||||
|
comment=None,
|
||||||
|
tags=None,
|
||||||
|
uri=None,
|
||||||
|
framework=None,
|
||||||
|
iteration=None,
|
||||||
|
system_tags=None,
|
||||||
|
**extra
|
||||||
|
):
|
||||||
def offline_store(**kwargs):
|
def offline_store(**kwargs):
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
setattr(self.data, k, v or getattr(self.data, k, None))
|
setattr(self.data, k, v or getattr(self.data, k, None))
|
||||||
return
|
return
|
||||||
if self._offline_mode:
|
|
||||||
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
|
||||||
uri=uri, framework=framework, iteration=iteration, **extra)
|
|
||||||
|
|
||||||
if Session.check_min_api_version('2.3'):
|
if self._offline_mode:
|
||||||
|
return offline_store(
|
||||||
|
design=design,
|
||||||
|
labels=labels,
|
||||||
|
name=name,
|
||||||
|
comment=comment,
|
||||||
|
tags=tags,
|
||||||
|
uri=uri,
|
||||||
|
framework=framework,
|
||||||
|
iteration=iteration,
|
||||||
|
**extra
|
||||||
|
)
|
||||||
|
|
||||||
|
if Session.check_min_api_version("2.3"):
|
||||||
if tags is not None:
|
if tags is not None:
|
||||||
extra.update({'tags': tags})
|
extra.update({"tags": tags})
|
||||||
if system_tags is not None:
|
if system_tags is not None:
|
||||||
extra.update({'system_tags': system_tags})
|
extra.update({"system_tags": system_tags})
|
||||||
elif tags is not None or system_tags is not None:
|
elif tags is not None or system_tags is not None:
|
||||||
if tags and system_tags:
|
if tags and system_tags:
|
||||||
system_tags = system_tags[:]
|
system_tags = system_tags[:]
|
||||||
system_tags += [t for t in tags if t not in system_tags]
|
system_tags += [t for t in tags if t not in system_tags]
|
||||||
extra.update({'system_tags': system_tags or tags or self.data.system_tags})
|
extra.update({"system_tags": system_tags or tags or self.data.system_tags})
|
||||||
|
|
||||||
self.send(models.EditRequest(
|
self.send(
|
||||||
|
models.EditRequest(
|
||||||
model=self.id,
|
model=self.id,
|
||||||
uri=uri,
|
uri=uri,
|
||||||
name=name,
|
name=name,
|
||||||
@ -246,23 +327,44 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
framework=framework,
|
framework=framework,
|
||||||
iteration=iteration,
|
iteration=iteration,
|
||||||
**extra
|
**extra
|
||||||
))
|
)
|
||||||
|
)
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None,
|
def update_and_upload(
|
||||||
tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False,
|
self,
|
||||||
target_filename=None, cb=None, iteration=None):
|
model_file,
|
||||||
""" Update the given model for a given task ID """
|
design=None,
|
||||||
|
labels=None,
|
||||||
|
name=None,
|
||||||
|
comment=None,
|
||||||
|
tags=None,
|
||||||
|
task_id=None,
|
||||||
|
project_id=None,
|
||||||
|
parent_id=None,
|
||||||
|
framework=None,
|
||||||
|
async_enable=False,
|
||||||
|
target_filename=None,
|
||||||
|
cb=None,
|
||||||
|
iteration=None
|
||||||
|
):
|
||||||
|
"""Update the given model for a given task ID"""
|
||||||
if async_enable:
|
if async_enable:
|
||||||
|
|
||||||
def callback(uploaded_uri):
|
def callback(uploaded_uri):
|
||||||
if uploaded_uri is None:
|
if uploaded_uri is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# If not successful, mark model as failed_uploading
|
# If not successful, mark model as failed_uploading
|
||||||
if uploaded_uri is False:
|
if uploaded_uri is False:
|
||||||
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)
|
uploaded_uri = "{}/failed_uploading".format(
|
||||||
|
self._upload_storage_uri
|
||||||
|
)
|
||||||
|
|
||||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri)
|
Model._local_model_to_id_uri[str(model_file)] = (
|
||||||
|
self.model_id,
|
||||||
|
uploaded_uri,
|
||||||
|
)
|
||||||
|
|
||||||
self.update(
|
self.update(
|
||||||
uri=uploaded_uri,
|
uri=uploaded_uri,
|
||||||
@ -281,11 +383,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
if cb:
|
if cb:
|
||||||
cb(model_file)
|
cb(model_file)
|
||||||
|
|
||||||
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
|
uri = self._upload_model(
|
||||||
cb=callback)
|
model_file,
|
||||||
|
async_enable=async_enable,
|
||||||
|
target_filename=target_filename,
|
||||||
|
cb=callback,
|
||||||
|
)
|
||||||
return uri
|
return uri
|
||||||
else:
|
else:
|
||||||
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
|
uri = self._upload_model(
|
||||||
|
model_file, async_enable=async_enable, target_filename=target_filename
|
||||||
|
)
|
||||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||||
self.update(
|
self.update(
|
||||||
uri=uri,
|
uri=uri,
|
||||||
@ -302,7 +410,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
return uri
|
return uri
|
||||||
|
|
||||||
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
|
def update_for_task(
|
||||||
|
self, task_id, name=None, model_id=None, type_="output", iteration=None
|
||||||
|
):
|
||||||
if Session.check_min_api_version("2.13"):
|
if Session.check_min_api_version("2.13"):
|
||||||
req = tasks.AddOrUpdateModelRequest(
|
req = tasks.AddOrUpdateModelRequest(
|
||||||
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
||||||
@ -314,7 +424,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
# backwards compatibility, None
|
# backwards compatibility, None
|
||||||
req = None
|
req = None
|
||||||
else:
|
else:
|
||||||
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
|
raise ValueError(
|
||||||
|
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
|
||||||
|
)
|
||||||
|
|
||||||
if req:
|
if req:
|
||||||
self.send(req)
|
self.send(req)
|
||||||
@ -323,7 +435,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_design(self):
|
def model_design(self):
|
||||||
""" Get the model design. For now, this is stored as a single key in the design dict. """
|
"""Get the model design. For now, this is stored as a single key in the design dict."""
|
||||||
try:
|
try:
|
||||||
return self._unwrap_design(self.data.design)
|
return self._unwrap_design(self.data.design)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -364,7 +476,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def tags(self):
|
def tags(self):
|
||||||
return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags
|
return (
|
||||||
|
self.data.system_tags
|
||||||
|
if hasattr(self.data, "system_tags")
|
||||||
|
else self.data.tags
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def task(self):
|
def task(self):
|
||||||
@ -394,8 +510,10 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||||
raise ValueError, otherwise return None on failure and output log warning.
|
raise ValueError, otherwise return None on failure and output log warning.
|
||||||
|
|
||||||
:param bool force_download: If True, the base artifact will be downloaded,
|
:param bool force_download: If True, the base artifact will be downloaded,
|
||||||
even if the artifact is already cached.
|
even if the artifact is already cached.
|
||||||
|
|
||||||
:return: a local path to a downloaded copy of the model
|
:return: a local path to a downloaded copy of the model
|
||||||
"""
|
"""
|
||||||
uri = self.data.uri
|
uri = self.data.uri
|
||||||
@ -403,21 +521,29 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# check if we already downloaded the file
|
# check if we already downloaded the file
|
||||||
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
downloaded_models = [
|
||||||
|
k
|
||||||
|
for k, (i, u) in Model._local_model_to_id_uri.items()
|
||||||
|
if i == self.id and u == uri
|
||||||
|
]
|
||||||
for dl_file in downloaded_models:
|
for dl_file in downloaded_models:
|
||||||
if Path(dl_file).exists() and not force_download:
|
if Path(dl_file).exists() and not force_download:
|
||||||
return dl_file
|
return dl_file
|
||||||
# remove non existing model file
|
# remove non existing model file
|
||||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||||
|
|
||||||
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download)
|
local_download = StorageManager.get_local_copy(
|
||||||
|
uri, extract_archive=False, force_download=force_download
|
||||||
|
)
|
||||||
|
|
||||||
# save local model, so we can later query what was the original one
|
# save local model, so we can later query what was the original one
|
||||||
if local_download is not None:
|
if local_download is not None:
|
||||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||||
elif raise_on_error:
|
elif raise_on_error:
|
||||||
raise ValueError("Could not retrieve a local copy of model weights {}, "
|
raise ValueError(
|
||||||
"failed downloading {}".format(self.model_id, uri))
|
"Could not retrieve a local copy of model weights {}, "
|
||||||
|
"failed downloading {}".format(self.model_id, uri)
|
||||||
|
)
|
||||||
|
|
||||||
return local_download
|
return local_download
|
||||||
|
|
||||||
@ -426,9 +552,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return self._cache_dir
|
return self._cache_dir
|
||||||
|
|
||||||
def save_model_design_file(self):
|
def save_model_design_file(self):
|
||||||
""" Download model description file into a local file in our cache_dir """
|
"""Download model description file into a local file in our cache_dir"""
|
||||||
design = self.model_design
|
design = self.model_design
|
||||||
filename = self.data.name + '.txt'
|
filename = self.data.name + ".txt"
|
||||||
p = Path(self.cache_dir) / filename
|
p = Path(self.cache_dir) / filename
|
||||||
# we always write the original model design to file, to prevent any mishaps
|
# we always write the original model design to file, to prevent any mishaps
|
||||||
# if p.is_file():
|
# if p.is_file():
|
||||||
@ -438,11 +564,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
return str(p)
|
return str(p)
|
||||||
|
|
||||||
def get_model_package(self):
|
def get_model_package(self):
|
||||||
""" Get a named tuple containing the model's weights and design """
|
"""Get a named tuple containing the model's weights and design"""
|
||||||
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
|
return ModelPackage(
|
||||||
|
weights=self.download_model_weights(), design=self.save_model_design_file()
|
||||||
|
)
|
||||||
|
|
||||||
def get_model_design(self):
|
def get_model_design(self):
|
||||||
""" Get model description (text) """
|
"""Get model description (text)"""
|
||||||
return self.model_design
|
return self.model_design
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -465,8 +593,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
data = self.data
|
data = self.data
|
||||||
assert isinstance(data, models.Model)
|
assert isinstance(data, models.Model)
|
||||||
parent = self.id if child else None
|
parent = self.id if child else None
|
||||||
extra = {'system_tags': tags or data.system_tags} \
|
extra = (
|
||||||
if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
|
{"system_tags": tags or data.system_tags}
|
||||||
|
if Session.check_min_api_version("2.3")
|
||||||
|
else {"tags": tags or data.tags}
|
||||||
|
)
|
||||||
req = models.CreateRequest(
|
req = models.CreateRequest(
|
||||||
uri=data.uri,
|
uri=data.uri,
|
||||||
name=name,
|
name=name,
|
||||||
@ -485,8 +616,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
|||||||
|
|
||||||
def _create_empty_model(self, upload_storage_uri=None, project_id=None):
|
def _create_empty_model(self, upload_storage_uri=None, project_id=None):
|
||||||
upload_storage_uri = upload_storage_uri or self.upload_storage_uri
|
upload_storage_uri = upload_storage_uri or self.upload_storage_uri
|
||||||
name = make_message('Anonymous model %(time)s')
|
name = make_message("Anonymous model %(time)s")
|
||||||
uri = '{}/uploading_file'.format(upload_storage_uri or 'file://')
|
uri = "{}/uploading_file".format(upload_storage_uri or "file://")
|
||||||
req = models.CreateRequest(uri=uri, name=name, labels={}, project=project_id)
|
req = models.CreateRequest(uri=uri, name=name, labels={}, project=project_id)
|
||||||
res = self.send(req)
|
res = self.send(req)
|
||||||
if not res:
|
if not res:
|
||||||
|
@ -652,7 +652,6 @@ if __name__ == '__main__':
|
|||||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||||
a_function, sanitize_function=_sanitize_function
|
a_function, sanitize_function=_sanitize_function
|
||||||
)
|
)
|
||||||
|
|
||||||
# add helper functions on top.
|
# add helper functions on top.
|
||||||
for f in (helper_functions or []):
|
for f in (helper_functions or []):
|
||||||
helper_function_source, _ = CreateFromFunction.__extract_function_information(
|
helper_function_source, _ = CreateFromFunction.__extract_function_information(
|
||||||
@ -665,7 +664,6 @@ if __name__ == '__main__':
|
|||||||
if artifact_serialization_function
|
if artifact_serialization_function
|
||||||
else ("", "None")
|
else ("", "None")
|
||||||
)
|
)
|
||||||
|
|
||||||
artifact_deserialization_function_source, artifact_deserialization_function_name = (
|
artifact_deserialization_function_source, artifact_deserialization_function_name = (
|
||||||
CreateFromFunction.__extract_function_information(artifact_deserialization_function)
|
CreateFromFunction.__extract_function_information(artifact_deserialization_function)
|
||||||
if artifact_deserialization_function
|
if artifact_deserialization_function
|
||||||
@ -833,7 +831,5 @@ if __name__ == '__main__':
|
|||||||
function_source = inspect.getsource(function)
|
function_source = inspect.getsource(function)
|
||||||
if sanitize_function:
|
if sanitize_function:
|
||||||
function_source = sanitize_function(function_source)
|
function_source = sanitize_function(function_source)
|
||||||
|
|
||||||
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
|
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
|
||||||
|
|
||||||
return function_source, function_name
|
return function_source, function_name
|
@ -255,10 +255,14 @@ class ScriptRequirements(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _remove_package_versions(installed_pkgs, package_names_to_remove_version):
|
def _remove_package_versions(installed_pkgs, package_names_to_remove_version):
|
||||||
installed_pkgs = {k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
|
def _internal(_installed_pkgs):
|
||||||
for k, v in installed_pkgs.items()}
|
return {
|
||||||
|
k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
|
||||||
|
if not isinstance(v, dict) else _internal(v)
|
||||||
|
for k, v in _installed_pkgs.items()
|
||||||
|
}
|
||||||
|
|
||||||
return installed_pkgs
|
return _internal(installed_pkgs)
|
||||||
|
|
||||||
|
|
||||||
class _JupyterObserver(object):
|
class _JupyterObserver(object):
|
||||||
@ -781,6 +785,7 @@ class ScriptInfo(object):
|
|||||||
try:
|
try:
|
||||||
# we expect to find boto3 in the sagemaker env
|
# we expect to find boto3 in the sagemaker env
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
with open(cls._sagemaker_metadata_path) as f:
|
with open(cls._sagemaker_metadata_path) as f:
|
||||||
notebook_data = json.load(f)
|
notebook_data = json.load(f)
|
||||||
client = boto3.client("sagemaker")
|
client = boto3.client("sagemaker")
|
||||||
@ -799,7 +804,6 @@ class ScriptInfo(object):
|
|||||||
return jupyter_session.get("path", ""), jupyter_session.get("name", "")
|
return jupyter_session.get("path", ""), jupyter_session.get("name", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
|
cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
|
||||||
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -700,7 +700,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
|
|
||||||
the text will not be printed, because the Python process is immediately terminated.
|
the text will not be printed, because the Python process is immediately terminated.
|
||||||
|
|
||||||
:param bool ignore_errors: If True default), ignore any errors raised
|
:param bool ignore_errors: If True (default), ignore any errors raised
|
||||||
:param bool force: If True, the task status will be changed to `stopped` regardless of the current Task state.
|
:param bool force: If True, the task status will be changed to `stopped` regardless of the current Task state.
|
||||||
:param str status_message: Optional, add status change message to the stop request.
|
:param str status_message: Optional, add status change message to the stop request.
|
||||||
This message will be stored as status_message on the Task's info panel
|
This message will be stored as status_message on the Task's info panel
|
||||||
@ -718,11 +718,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
tasks.CompletedRequest(
|
tasks.CompletedRequest(
|
||||||
self.id, status_reason='completed', status_message=status_message, force=force),
|
self.id, status_reason='completed', status_message=status_message, force=force),
|
||||||
ignore_errors=ignore_errors)
|
ignore_errors=ignore_errors)
|
||||||
|
|
||||||
if self._get_runtime_properties().get("_publish_on_complete"):
|
if self._get_runtime_properties().get("_publish_on_complete"):
|
||||||
self.send(
|
self.send(
|
||||||
tasks.PublishRequest(
|
tasks.PublishRequest(
|
||||||
self.id, status_reason='completed', status_message=status_message, force=force),
|
self.id, status_reason='completed', status_message=status_message, force=force),
|
||||||
ignore_errors=ignore_errors)
|
ignore_errors=ignore_errors)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
return self.send(
|
return self.send(
|
||||||
tasks.StoppedRequest(self.id, status_reason='completed', status_message=status_message, force=force),
|
tasks.StoppedRequest(self.id, status_reason='completed', status_message=status_message, force=force),
|
||||||
@ -2387,7 +2389,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _get_runtime_properties(self):
|
def _get_runtime_properties(self):
|
||||||
# type: () -> Mapping[str, str]
|
# type: () -> Dict[str, str]
|
||||||
if not Session.check_min_api_version('2.13'):
|
if not Session.check_min_api_version('2.13'):
|
||||||
return dict()
|
return dict()
|
||||||
return dict(**self.data.runtime) if self.data.runtime else dict()
|
return dict(**self.data.runtime) if self.data.runtime else dict()
|
||||||
|
@ -93,7 +93,7 @@ def main():
|
|||||||
# Take the credentials in raw form or from api section
|
# Take the credentials in raw form or from api section
|
||||||
credentials = get_parsed_field(parsed, ["credentials"])
|
credentials = get_parsed_field(parsed, ["credentials"])
|
||||||
api_server = get_parsed_field(parsed, ["api_server", "host"])
|
api_server = get_parsed_field(parsed, ["api_server", "host"])
|
||||||
web_server = get_parsed_field(parsed, ["web_server"])
|
web_server = get_parsed_field(parsed, ["web_server"]) # TODO: if previous fails, this will fail too
|
||||||
files_server = get_parsed_field(parsed, ["files_server"])
|
files_server = get_parsed_field(parsed, ["files_server"])
|
||||||
except Exception:
|
except Exception:
|
||||||
credentials = credentials or None
|
credentials = credentials or None
|
||||||
|
@ -17,7 +17,6 @@ from attr import attrs, attrib
|
|||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
from .. import Task, StorageManager, Logger
|
from .. import Task, StorageManager, Logger
|
||||||
from ..backend_api.session.client import APIClient
|
|
||||||
from ..backend_api import Session
|
from ..backend_api import Session
|
||||||
from ..backend_interface.task.development.worker import DevWorker
|
from ..backend_interface.task.development.worker import DevWorker
|
||||||
from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project
|
from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project
|
||||||
|
@ -359,6 +359,7 @@ class Logger(object):
|
|||||||
iteration=0,
|
iteration=0,
|
||||||
table_plot=df,
|
table_plot=df,
|
||||||
extra_data={'columnwidth': [2., 1., 1., 1.]})
|
extra_data={'columnwidth': [2., 1., 1., 1.]})
|
||||||
|
|
||||||
"""
|
"""
|
||||||
mutually_exclusive(
|
mutually_exclusive(
|
||||||
UsageError, _check_none=True,
|
UsageError, _check_none=True,
|
||||||
|
555
clearml/model.py
555
clearml/model.py
File diff suppressed because it is too large
Load Diff
@ -1177,8 +1177,7 @@ class StorageHelper(object):
|
|||||||
|
|
||||||
def _do_async_upload(self, data):
|
def _do_async_upload(self, data):
|
||||||
assert isinstance(data, self._UploadData)
|
assert isinstance(data, self._UploadData)
|
||||||
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback,
|
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, verbose=True, retries=data.retries, return_canonized=data.return_canonized)
|
||||||
verbose=True, retries=data.retries, return_canonized=data.return_canonized)
|
|
||||||
|
|
||||||
def _upload_from_file(self, local_path, dest_path, extra=None):
|
def _upload_from_file(self, local_path, dest_path, extra=None):
|
||||||
if not hasattr(self._driver, 'upload_object'):
|
if not hasattr(self._driver, 'upload_object'):
|
||||||
@ -1473,7 +1472,6 @@ class _HttpDriver(_Driver):
|
|||||||
try:
|
try:
|
||||||
container = self.get_container(container_name)
|
container = self.get_container(container_name)
|
||||||
url = container_name + object_name
|
url = container_name + object_name
|
||||||
|
|
||||||
return container.session.head(url, allow_redirects=True, headers=container.get_headers(url)).ok
|
return container.session.head(url, allow_redirects=True, headers=container.get_headers(url)).ok
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import fnmatch
|
import fnmatch
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
@ -7,7 +6,6 @@ from random import random
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
from six.moves.urllib.parse import urlparse
|
|
||||||
|
|
||||||
from pathlib2 import Path
|
from pathlib2 import Path
|
||||||
|
|
||||||
@ -304,8 +302,8 @@ class StorageManager(object):
|
|||||||
if not local_folder:
|
if not local_folder:
|
||||||
local_folder = CacheManager.get_cache_manager().get_cache_folder()
|
local_folder = CacheManager.get_cache_manager().get_cache_folder()
|
||||||
local_path = str(Path(local_folder).expanduser().absolute() / bucket_path)
|
local_path = str(Path(local_folder).expanduser().absolute() / bucket_path)
|
||||||
|
|
||||||
helper = StorageHelper.get(remote_url)
|
helper = StorageHelper.get(remote_url)
|
||||||
|
|
||||||
return helper.download_to_file(
|
return helper.download_to_file(
|
||||||
remote_url,
|
remote_url,
|
||||||
local_path,
|
local_path,
|
||||||
|
@ -1988,7 +1988,9 @@ class Task(_Task):
|
|||||||
corresponding to debug sample's file name in the UI, also known as variant
|
corresponding to debug sample's file name in the UI, also known as variant
|
||||||
:param int n_last_iterations: How many debug samples iterations to fetch in reverse chronological order.
|
:param int n_last_iterations: How many debug samples iterations to fetch in reverse chronological order.
|
||||||
Leave empty to get all debug samples.
|
Leave empty to get all debug samples.
|
||||||
|
|
||||||
:raise: TypeError if `n_last_iterations` is explicitly set to anything other than a positive integer value
|
:raise: TypeError if `n_last_iterations` is explicitly set to anything other than a positive integer value
|
||||||
|
|
||||||
:return: A list of `dict`s, each dictionary containing the debug sample's URL and other metadata.
|
:return: A list of `dict`s, each dictionary containing the debug sample's URL and other metadata.
|
||||||
The URLs can be passed to :meth:`StorageManager.get_local_copy` to fetch local copies of debug samples.
|
The URLs can be passed to :meth:`StorageManager.get_local_copy` to fetch local copies of debug samples.
|
||||||
"""
|
"""
|
||||||
@ -2021,11 +2023,14 @@ class Task(_Task):
|
|||||||
|
|
||||||
def _get_debug_samples(self, title, series, n_last_iterations=None):
|
def _get_debug_samples(self, title, series, n_last_iterations=None):
|
||||||
response = self._send_debug_image_request(title, series, n_last_iterations)
|
response = self._send_debug_image_request(title, series, n_last_iterations)
|
||||||
|
|
||||||
debug_samples = []
|
debug_samples = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
scroll_id = response.response.scroll_id
|
scroll_id = response.response_data.get("scroll_id", None)
|
||||||
for metric_resp in response.response.metrics:
|
|
||||||
iterations_events = [iteration["events"] for iteration in metric_resp.iterations] # type: List[List[dict]]
|
for metric_resp in response.response_data.get("metrics", []):
|
||||||
|
iterations_events = [iteration["events"] for iteration in metric_resp.get("iterations", [])] # type: List[List[dict]]
|
||||||
flattened_events = (event
|
flattened_events = (event
|
||||||
for single_iter_events in iterations_events
|
for single_iter_events in iterations_events
|
||||||
for event in single_iter_events)
|
for event in single_iter_events)
|
||||||
@ -2037,8 +2042,8 @@ class Task(_Task):
|
|||||||
|
|
||||||
if (len(debug_samples) == n_last_iterations
|
if (len(debug_samples) == n_last_iterations
|
||||||
or all(
|
or all(
|
||||||
len(metric_resp.iterations) == 0
|
len(metric_resp.get("iterations", [])) == 0
|
||||||
for metric_resp in response.response.metrics)):
|
for metric_resp in response.response_data.get("metrics", []))):
|
||||||
break
|
break
|
||||||
|
|
||||||
return debug_samples
|
return debug_samples
|
||||||
@ -2877,13 +2882,11 @@ class Task(_Task):
|
|||||||
Set offline mode, where all data and logs are stored into local folder, for later transmission
|
Set offline mode, where all data and logs are stored into local folder, for later transmission
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
`Task.set_offline` can't move the same task from offline to online, nor can it be applied before `Task.create`.
|
`Task.set_offline` can't move the same task from offline to online, nor can it be applied before `Task.create`.
|
||||||
See below an example of **incorect** usage of `Task.set_offline`:
|
See below an example of **incorect** usage of `Task.set_offline`:
|
||||||
|
|
||||||
.. code-block:: py
|
.. code-block:: py
|
||||||
|
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
|
|
||||||
Task.set_offline(True)
|
Task.set_offline(True)
|
||||||
task = Task.create(project_name='DEBUG', task_name="offline")
|
task = Task.create(project_name='DEBUG', task_name="offline")
|
||||||
# ^^^ an error or warning is emitted, telling us that `Task.set_offline(True)`
|
# ^^^ an error or warning is emitted, telling us that `Task.set_offline(True)`
|
||||||
@ -2891,23 +2894,25 @@ class Task(_Task):
|
|||||||
Task.set_offline(False)
|
Task.set_offline(False)
|
||||||
# ^^^ an error or warning is emitted, telling us that running `Task.set_offline(False)`
|
# ^^^ an error or warning is emitted, telling us that running `Task.set_offline(False)`
|
||||||
# while the current task is not closed is not something we support
|
# while the current task is not closed is not something we support
|
||||||
|
|
||||||
data = task.export_task()
|
data = task.export_task()
|
||||||
|
|
||||||
imported_task = Task.import_task(task_data=data)
|
imported_task = Task.import_task(task_data=data)
|
||||||
|
|
||||||
The correct way to use `Task.set_offline` can be seen in the following example:
|
The correct way to use `Task.set_offline` can be seen in the following example:
|
||||||
|
|
||||||
.. code-block:: py
|
.. code-block:: py
|
||||||
|
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
|
|
||||||
Task.set_offline(True)
|
Task.set_offline(True)
|
||||||
task = Task.init(project_name='DEBUG', task_name="offline")
|
task = Task.init(project_name='DEBUG', task_name="offline")
|
||||||
task.upload_artifact("large_artifact", "test_strign")
|
task.upload_artifact("large_artifact", "test_strign")
|
||||||
task.close()
|
task.close()
|
||||||
Task.set_offline(False)
|
Task.set_offline(False)
|
||||||
|
|
||||||
imported_task = Task.import_offline_session(task.get_offline_mode_folder())
|
imported_task = Task.import_offline_session(task.get_offline_mode_folder())
|
||||||
|
|
||||||
:param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
|
:param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode:
|
if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode:
|
||||||
@ -2932,6 +2937,7 @@ class Task(_Task):
|
|||||||
# type: () -> bool
|
# type: () -> bool
|
||||||
"""
|
"""
|
||||||
Return offline-mode state, If in offline-mode, no communication to the backend is enabled.
|
Return offline-mode state, If in offline-mode, no communication to the backend is enabled.
|
||||||
|
|
||||||
:return: boolean offline-mode state
|
:return: boolean offline-mode state
|
||||||
"""
|
"""
|
||||||
return cls._offline_mode
|
return cls._offline_mode
|
||||||
@ -3542,11 +3548,9 @@ class Task(_Task):
|
|||||||
def _check_keys(dict_, warning_sent=False):
|
def _check_keys(dict_, warning_sent=False):
|
||||||
if warning_sent:
|
if warning_sent:
|
||||||
return
|
return
|
||||||
|
|
||||||
for k, v in dict_.items():
|
for k, v in dict_.items():
|
||||||
if warning_sent:
|
if warning_sent:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(k, str):
|
if not isinstance(k, str):
|
||||||
getLogger().warning(
|
getLogger().warning(
|
||||||
"Unsupported key of type '{}' found when connecting dictionary. It will be converted to str".format(
|
"Unsupported key of type '{}' found when connecting dictionary. It will be converted to str".format(
|
||||||
@ -3554,12 +3558,10 @@ class Task(_Task):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
warning_sent = True
|
warning_sent = True
|
||||||
|
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
_check_keys(v, warning_sent)
|
_check_keys(v, warning_sent)
|
||||||
|
|
||||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
|
|
||||||
_check_keys(dictionary)
|
_check_keys(dictionary)
|
||||||
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
|
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
|
||||||
self._arguments.copy_from_dict(flat_dict, prefix=name)
|
self._arguments.copy_from_dict(flat_dict, prefix=name)
|
||||||
@ -3909,7 +3911,6 @@ class Task(_Task):
|
|||||||
try:
|
try:
|
||||||
# make sure the state of the offline data is saved
|
# make sure the state of the offline data is saved
|
||||||
self._edit()
|
self._edit()
|
||||||
|
|
||||||
# create zip file
|
# create zip file
|
||||||
offline_folder = self.get_offline_mode_folder()
|
offline_folder = self.get_offline_mode_folder()
|
||||||
zip_file = offline_folder.as_posix() + '.zip'
|
zip_file = offline_folder.as_posix() + '.zip'
|
||||||
|
@ -223,11 +223,18 @@
|
|||||||
"toc_visible": true
|
"toc_visible": true
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"name": "python"
|
"name": "python",
|
||||||
|
"version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "8b483fbf9fa60c6c6195634afd5159f586a30c5c6a9d31fa17f93a17f02fdc40"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -451,4 +451,5 @@
|
|||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 0
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
|
|
@ -17,6 +17,7 @@ from enum import Enum
|
|||||||
from clearml import Task
|
from clearml import Task
|
||||||
from clearml.task_parameters import TaskParameters, param, percent_param
|
from clearml.task_parameters import TaskParameters, param, percent_param
|
||||||
|
|
||||||
|
|
||||||
# Connecting ClearML with the current process,
|
# Connecting ClearML with the current process,
|
||||||
# from here on everything is logged automatically
|
# from here on everything is logged automatically
|
||||||
task = Task.init(project_name='FirstTrial', task_name='first_trial')
|
task = Task.init(project_name='FirstTrial', task_name='first_trial')
|
||||||
@ -43,7 +44,6 @@ class IntEnumClass(Enum):
|
|||||||
C = 1
|
C = 1
|
||||||
D = 2
|
D = 2
|
||||||
|
|
||||||
|
|
||||||
parameters = {
|
parameters = {
|
||||||
'list': [1, 2, 3],
|
'list': [1, 2, 3],
|
||||||
'dict': {'a': 1, 'b': 2},
|
'dict': {'a': 1, 'b': 2},
|
Loading…
Reference in New Issue
Block a user