mirror of
https://github.com/clearml/clearml
synced 2025-04-15 04:52:20 +00:00
Refactor code
This commit is contained in:
parent
ecf6a4df2a
commit
4cd8857c0d
.gitignore
clearml
docs/tutorials
examples/reporting
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,8 +11,8 @@ build/
|
||||
dist/
|
||||
*.egg-info
|
||||
.env
|
||||
venv/
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# example data
|
||||
examples/runs/
|
||||
|
@ -218,7 +218,6 @@ class PipelineController(object):
|
||||
def serialize(obj):
|
||||
import dill
|
||||
return dill.dumps(obj)
|
||||
|
||||
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
|
||||
which represents the serialized object. This function should return the deserialized object.
|
||||
All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
|
||||
@ -1157,6 +1156,7 @@ class PipelineController(object):
|
||||
# type: (bool, str) -> bool
|
||||
"""
|
||||
Evaluate whether or not the pipeline is successful
|
||||
|
||||
:param fail_on_step_fail: If True (default), evaluate the pipeline steps' status to assess if the pipeline
|
||||
is successful. If False, only evaluate the controller
|
||||
:param fail_condition: Must be one of the following: 'all' (default), 'failed' or 'aborted'. If 'failed', this
|
||||
@ -1175,18 +1175,14 @@ class PipelineController(object):
|
||||
success_status = [Task.TaskStatusEnum.completed, Task.TaskStatusEnum.failed]
|
||||
else:
|
||||
raise UsageError("fail_condition needs to be one of the following: 'all', 'failed', 'aborted'")
|
||||
|
||||
if self._task.status not in success_status:
|
||||
return False
|
||||
|
||||
if not fail_on_step_fail:
|
||||
return True
|
||||
|
||||
self._update_nodes_status()
|
||||
for node in self._nodes.values():
|
||||
if node.status not in success_status:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def elapsed(self):
|
||||
|
@ -139,7 +139,7 @@ class TaskTrigger(BaseTrigger):
|
||||
raise ValueError("You must provide metric/variant/threshold")
|
||||
valid_status = [str(s) for s in Task.TaskStatusEnum]
|
||||
if self.on_status and not all(s in valid_status for s in self.on_status):
|
||||
raise ValueError("You on_status contains invalid status value: {}".format(self.on_status))
|
||||
raise ValueError("Your on_status contains invalid status value: {}".format(self.on_status))
|
||||
valid_signs = ['min', 'minimum', 'max', 'maximum']
|
||||
if self.value_sign and self.value_sign not in valid_signs:
|
||||
raise ValueError("Invalid value_sign `{}`, valid options are: {}".format(self.value_sign, valid_signs))
|
||||
|
@ -2,7 +2,6 @@
|
||||
auth service
|
||||
|
||||
This service provides authentication management and authorization
|
||||
|
||||
validation for the entire system.
|
||||
"""
|
||||
import six
|
||||
|
@ -40,8 +40,10 @@ for a very long time for a non-responding or mis-configured server
|
||||
"""
|
||||
ENV_API_EXTRA_RETRY_CODES = EnvEntry("CLEARML_API_EXTRA_RETRY_CODES")
|
||||
|
||||
|
||||
ENV_FORCE_MAX_API_VERSION = EnvEntry("CLEARML_FORCE_MAX_API_VERSION", type=str)
|
||||
|
||||
|
||||
class MissingConfigError(ValueError):
|
||||
def __init__(self, message=None):
|
||||
if message is None:
|
||||
|
@ -94,7 +94,6 @@ class CompoundRequest(Request):
|
||||
if self._item_prop_name in dict_properties:
|
||||
del dict_properties[self._item_prop_name]
|
||||
dict_.update(dict_properties)
|
||||
|
||||
return dict_
|
||||
|
||||
def validate(self):
|
||||
|
@ -134,6 +134,7 @@ class Session(TokenManager):
|
||||
**kwargs
|
||||
):
|
||||
self.__class__._sessions_weakrefs.append(weakref.ref(self))
|
||||
|
||||
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
|
||||
self._logger = logger
|
||||
if self._verbose and not self._logger:
|
||||
@ -147,7 +148,6 @@ class Session(TokenManager):
|
||||
self.__init_host = host
|
||||
self.__init_http_retries_config = http_retries_config
|
||||
self.__token_manager_kwargs = kwargs
|
||||
|
||||
if config is not None:
|
||||
self.config = config
|
||||
else:
|
||||
@ -162,21 +162,21 @@ class Session(TokenManager):
|
||||
|
||||
self._ssl_error_count_verbosity = self.config.get(
|
||||
"api.ssl_error_count_verbosity", self._ssl_error_count_verbosity)
|
||||
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
|
||||
|
||||
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
|
||||
if not self.__host:
|
||||
raise ValueError("ClearML host was not set, check your configuration file or environment variable")
|
||||
|
||||
self.__host = self.__host.strip("/")
|
||||
self.__http_retries_config = self.__init_http_retries_config or self.config.get(
|
||||
"api.http.retries", ConfigTree()).as_plain_ordered_dict()
|
||||
|
||||
self.__http_retries_config["status_forcelist"] = self._get_retry_codes()
|
||||
self.__http_retries_config["config"] = self.config
|
||||
self.__http_session = get_http_session_with_retry(**self.__http_retries_config)
|
||||
self.__http_session.write_timeout = self._write_session_timeout
|
||||
self.__http_session.request_size_threshold = self._write_session_data_size
|
||||
self.__max_req_size = self.config.get("api.http.max_req_size", None)
|
||||
|
||||
self.__max_req_size = self.config.get("api.http.max_req_size", None)
|
||||
if not self.__max_req_size:
|
||||
raise ValueError("missing max request size")
|
||||
|
||||
@ -186,7 +186,6 @@ class Session(TokenManager):
|
||||
req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None)
|
||||
self.__auth_token = None
|
||||
self._update_default_api_method()
|
||||
|
||||
if ENV_AUTH_TOKEN.get():
|
||||
self.__access_key = self.__secret_key = None
|
||||
self.__auth_token = ENV_AUTH_TOKEN.get()
|
||||
@ -203,9 +202,11 @@ class Session(TokenManager):
|
||||
|
||||
if not self.secret_key and not self.access_key and not self.__auth_token:
|
||||
raise MissingConfigError()
|
||||
|
||||
super(Session, self).__init__(
|
||||
**self.__token_manager_kwargs,
|
||||
req_token_expiration_sec=req_token_expiration_sec,
|
||||
token_expiration_threshold_sec=token_expiration_threshold_sec,
|
||||
req_token_expiration_sec=req_token_expiration_sec
|
||||
)
|
||||
self.refresh_token()
|
||||
|
||||
@ -633,6 +634,7 @@ class Session(TokenManager):
|
||||
|
||||
return call_result
|
||||
|
||||
@classmethod
|
||||
def _make_all_sessions_go_online(cls):
|
||||
for active_session in cls._get_all_active_sessions():
|
||||
# noinspection PyProtectedMember
|
||||
@ -647,7 +649,6 @@ class Session(TokenManager):
|
||||
if session:
|
||||
active_sessions.append(session)
|
||||
new_sessions_weakrefs.append(session_weakref)
|
||||
|
||||
cls._sessions_weakrefs = session_weakref
|
||||
return active_sessions
|
||||
|
||||
|
@ -7,6 +7,7 @@ from time import time
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
from ...backend_api import Session
|
||||
from ...backend_api.services import events as api_events
|
||||
from ..base import InterfaceBase
|
||||
from ...config import config, deferred_config
|
||||
|
@ -271,9 +271,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
self._for_model = for_model
|
||||
flush_threshold = config.get("development.worker.report_event_flush_threshold", 100)
|
||||
self._report_service = BackgroundReportService(
|
||||
task=task, async_enable=async_enable, metrics=metrics,
|
||||
task=task,
|
||||
async_enable=async_enable,
|
||||
metrics=metrics,
|
||||
flush_frequency=self._flush_frequency,
|
||||
flush_threshold=flush_threshold, for_model=for_model)
|
||||
flush_threshold=flush_threshold,
|
||||
for_model=for_model,
|
||||
)
|
||||
self._report_service.start()
|
||||
|
||||
def _set_storage_uri(self, value):
|
||||
@ -355,8 +359,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
:param iter: Iteration number
|
||||
:type iter: int
|
||||
"""
|
||||
ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value,
|
||||
iter=iter)
|
||||
ev = ScalarEvent(
|
||||
metric=self._normalize_name(title),
|
||||
variant=self._normalize_name(series),
|
||||
value=value,
|
||||
iter=iter
|
||||
)
|
||||
self._report(ev)
|
||||
|
||||
def report_vector(self, title, series, values, iter):
|
||||
@ -457,8 +465,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
|
||||
elif not isinstance(plot, six.string_types):
|
||||
raise ValueError('Plot should be a string or a dict')
|
||||
|
||||
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series),
|
||||
plot_str=plot, iter=iter)
|
||||
ev = PlotEvent(
|
||||
metric=self._normalize_name(title),
|
||||
variant=self._normalize_name(series),
|
||||
plot_str=plot,
|
||||
iter=iter
|
||||
)
|
||||
self._report(ev)
|
||||
|
||||
def report_image(self, title, series, src, iter):
|
||||
|
@ -12,7 +12,7 @@ from ..storage import StorageManager
|
||||
from ..storage.helper import StorageHelper
|
||||
from ..utilities.async_manager import AsyncManagerMixin
|
||||
|
||||
ModelPackage = namedtuple('ModelPackage', 'weights design')
|
||||
ModelPackage = namedtuple("ModelPackage", "weights design")
|
||||
|
||||
|
||||
class ModelDoesNotExistError(Exception):
|
||||
@ -22,12 +22,12 @@ class ModelDoesNotExistError(Exception):
|
||||
class _StorageUriMixin(object):
|
||||
@property
|
||||
def upload_storage_uri(self):
|
||||
""" A URI into which models are uploaded """
|
||||
"""A URI into which models are uploaded"""
|
||||
return self._upload_storage_uri
|
||||
|
||||
@upload_storage_uri.setter
|
||||
def upload_storage_uri(self, value):
|
||||
self._upload_storage_uri = value.rstrip('/') if value else None
|
||||
self._upload_storage_uri = value.rstrip("/") if value else None
|
||||
|
||||
|
||||
def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
|
||||
@ -44,9 +44,9 @@ def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
|
||||
|
||||
|
||||
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
""" Manager for backend model objects """
|
||||
"""Manager for backend model objects"""
|
||||
|
||||
_EMPTY_MODEL_ID = 'empty'
|
||||
_EMPTY_MODEL_ID = "empty"
|
||||
|
||||
_local_model_to_id_uri = {}
|
||||
|
||||
@ -54,8 +54,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
def model_id(self):
|
||||
return self.id
|
||||
|
||||
def __init__(self, upload_storage_uri, cache_dir, model_id=None,
|
||||
upload_storage_suffix='models', session=None, log=None):
|
||||
def __init__(
|
||||
self,
|
||||
upload_storage_uri,
|
||||
cache_dir,
|
||||
model_id=None,
|
||||
upload_storage_suffix="models",
|
||||
session=None,
|
||||
log=None
|
||||
):
|
||||
super(Model, self).__init__(id=model_id, session=session, log=log)
|
||||
self._upload_storage_suffix = upload_storage_suffix
|
||||
if model_id == self._EMPTY_MODEL_ID:
|
||||
@ -71,7 +78,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
self.reload()
|
||||
|
||||
def _reload(self):
|
||||
""" Reload the model object """
|
||||
"""Reload the model object"""
|
||||
if self._offline_mode:
|
||||
return models.Model()
|
||||
|
||||
@ -80,11 +87,19 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
res = self.send(models.GetByIdRequest(model=self.id))
|
||||
return res.response.model
|
||||
|
||||
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
|
||||
def _upload_model(
|
||||
self, model_file, async_enable=False, target_filename=None, cb=None
|
||||
):
|
||||
if not self.upload_storage_uri:
|
||||
raise ValueError('Model has no storage URI defined (nowhere to upload to)')
|
||||
raise ValueError("Model has no storage URI defined (nowhere to upload to)")
|
||||
target_filename = target_filename or Path(model_file).name
|
||||
dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
|
||||
dest_path = "/".join(
|
||||
(
|
||||
self.upload_storage_uri,
|
||||
self._upload_storage_suffix or ".",
|
||||
target_filename,
|
||||
)
|
||||
)
|
||||
result = StorageHelper.get(dest_path).upload(
|
||||
src_path=model_file,
|
||||
dest_path=dest_path,
|
||||
@ -93,19 +108,23 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return_canonized=False
|
||||
)
|
||||
if async_enable:
|
||||
|
||||
def msg(num_results):
|
||||
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
|
||||
self.log.info(
|
||||
"Waiting for previous model to upload (%d pending, %s)"
|
||||
% (num_results, dest_path)
|
||||
)
|
||||
|
||||
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
|
||||
return dest_path
|
||||
|
||||
def _upload_callback(self, res, cb=None):
|
||||
if res is None:
|
||||
self.log.debug('Starting model upload')
|
||||
self.log.debug("Starting model upload")
|
||||
elif res is False:
|
||||
self.log.info('Failed model upload')
|
||||
self.log.info("Failed model upload")
|
||||
else:
|
||||
self.log.info('Completed model upload to {}'.format(res))
|
||||
self.log.info("Completed model upload to {}".format(res))
|
||||
if cb:
|
||||
cb(res)
|
||||
|
||||
@ -126,12 +145,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
:return: A proper design dictionary according to design parameter.
|
||||
"""
|
||||
if isinstance(design, dict):
|
||||
if 'design' not in design:
|
||||
raise ValueError('design dictionary must have \'design\' key in it')
|
||||
if "design" not in design:
|
||||
raise ValueError("design dictionary must have 'design' key in it")
|
||||
|
||||
return design
|
||||
|
||||
return {'design': design if design else ''}
|
||||
return {"design": design if design else ""}
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_design(design):
|
||||
@ -153,23 +172,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
:return: The design string according to design parameter.
|
||||
"""
|
||||
if not design:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
if isinstance(design, six.string_types):
|
||||
return design
|
||||
|
||||
if isinstance(design, dict):
|
||||
if 'design' in design:
|
||||
return design['design']
|
||||
if "design" in design:
|
||||
return design["design"]
|
||||
|
||||
return list(design.values())[0]
|
||||
|
||||
raise ValueError('design must be a string or a dictionary with at least one value')
|
||||
raise ValueError(
|
||||
"design must be a string or a dictionary with at least one value"
|
||||
)
|
||||
|
||||
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
|
||||
upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None):
|
||||
""" Update model weights file and various model properties """
|
||||
def update(
|
||||
self,
|
||||
model_file=None,
|
||||
design=None,
|
||||
labels=None,
|
||||
name=None,
|
||||
comment=None,
|
||||
tags=None,
|
||||
task_id=None,
|
||||
project_id=None,
|
||||
parent_id=None,
|
||||
uri=None,
|
||||
framework=None,
|
||||
upload_storage_uri=None,
|
||||
target_filename=None,
|
||||
iteration=None,
|
||||
system_tags=None
|
||||
):
|
||||
"""Update model weights file and various model properties"""
|
||||
|
||||
if self.id is None:
|
||||
if upload_storage_uri:
|
||||
@ -182,7 +218,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||
|
||||
# upload model file if needed and get uri
|
||||
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
|
||||
uri = uri or (
|
||||
self._upload_model(model_file, target_filename=target_filename)
|
||||
if model_file
|
||||
else self.data.uri
|
||||
)
|
||||
# update fields
|
||||
design = self._wrap_design(design) if design else self.data.design
|
||||
name = name or self.data.name
|
||||
@ -192,7 +232,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
project = project_id or self.data.project
|
||||
parent = parent_id or self.data.parent
|
||||
tags = tags or self.data.tags
|
||||
if Session.check_min_api_version('2.3'):
|
||||
if Session.check_min_api_version("2.3"):
|
||||
system_tags = system_tags or self.data.system_tags
|
||||
|
||||
self._edit(
|
||||
@ -210,33 +250,74 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
def edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
uri=None, framework=None, iteration=None, system_tags=None):
|
||||
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
||||
uri=uri, framework=framework, iteration=iteration, system_tags=system_tags)
|
||||
def edit(
|
||||
self,
|
||||
design=None,
|
||||
labels=None,
|
||||
name=None,
|
||||
comment=None,
|
||||
tags=None,
|
||||
uri=None,
|
||||
framework=None,
|
||||
iteration=None,
|
||||
system_tags=None
|
||||
):
|
||||
return self._edit(
|
||||
design=design,
|
||||
labels=labels,
|
||||
name=name,
|
||||
comment=comment,
|
||||
tags=tags,
|
||||
uri=uri,
|
||||
framework=framework,
|
||||
iteration=iteration,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None,
|
||||
uri=None, framework=None, iteration=None, system_tags=None, **extra):
|
||||
def _edit(
|
||||
self,
|
||||
design=None,
|
||||
labels=None,
|
||||
name=None,
|
||||
comment=None,
|
||||
tags=None,
|
||||
uri=None,
|
||||
framework=None,
|
||||
iteration=None,
|
||||
system_tags=None,
|
||||
**extra
|
||||
):
|
||||
def offline_store(**kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self.data, k, v or getattr(self.data, k, None))
|
||||
return
|
||||
if self._offline_mode:
|
||||
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
|
||||
uri=uri, framework=framework, iteration=iteration, **extra)
|
||||
|
||||
if Session.check_min_api_version('2.3'):
|
||||
if self._offline_mode:
|
||||
return offline_store(
|
||||
design=design,
|
||||
labels=labels,
|
||||
name=name,
|
||||
comment=comment,
|
||||
tags=tags,
|
||||
uri=uri,
|
||||
framework=framework,
|
||||
iteration=iteration,
|
||||
**extra
|
||||
)
|
||||
|
||||
if Session.check_min_api_version("2.3"):
|
||||
if tags is not None:
|
||||
extra.update({'tags': tags})
|
||||
extra.update({"tags": tags})
|
||||
if system_tags is not None:
|
||||
extra.update({'system_tags': system_tags})
|
||||
extra.update({"system_tags": system_tags})
|
||||
elif tags is not None or system_tags is not None:
|
||||
if tags and system_tags:
|
||||
system_tags = system_tags[:]
|
||||
system_tags += [t for t in tags if t not in system_tags]
|
||||
extra.update({'system_tags': system_tags or tags or self.data.system_tags})
|
||||
extra.update({"system_tags": system_tags or tags or self.data.system_tags})
|
||||
|
||||
self.send(models.EditRequest(
|
||||
self.send(
|
||||
models.EditRequest(
|
||||
model=self.id,
|
||||
uri=uri,
|
||||
name=name,
|
||||
@ -246,23 +327,44 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
framework=framework,
|
||||
iteration=iteration,
|
||||
**extra
|
||||
))
|
||||
)
|
||||
)
|
||||
self.reload()
|
||||
|
||||
def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None,
|
||||
tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False,
|
||||
target_filename=None, cb=None, iteration=None):
|
||||
""" Update the given model for a given task ID """
|
||||
def update_and_upload(
|
||||
self,
|
||||
model_file,
|
||||
design=None,
|
||||
labels=None,
|
||||
name=None,
|
||||
comment=None,
|
||||
tags=None,
|
||||
task_id=None,
|
||||
project_id=None,
|
||||
parent_id=None,
|
||||
framework=None,
|
||||
async_enable=False,
|
||||
target_filename=None,
|
||||
cb=None,
|
||||
iteration=None
|
||||
):
|
||||
"""Update the given model for a given task ID"""
|
||||
if async_enable:
|
||||
|
||||
def callback(uploaded_uri):
|
||||
if uploaded_uri is None:
|
||||
return
|
||||
|
||||
# If not successful, mark model as failed_uploading
|
||||
if uploaded_uri is False:
|
||||
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)
|
||||
uploaded_uri = "{}/failed_uploading".format(
|
||||
self._upload_storage_uri
|
||||
)
|
||||
|
||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri)
|
||||
Model._local_model_to_id_uri[str(model_file)] = (
|
||||
self.model_id,
|
||||
uploaded_uri,
|
||||
)
|
||||
|
||||
self.update(
|
||||
uri=uploaded_uri,
|
||||
@ -281,11 +383,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
if cb:
|
||||
cb(model_file)
|
||||
|
||||
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename,
|
||||
cb=callback)
|
||||
uri = self._upload_model(
|
||||
model_file,
|
||||
async_enable=async_enable,
|
||||
target_filename=target_filename,
|
||||
cb=callback,
|
||||
)
|
||||
return uri
|
||||
else:
|
||||
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
|
||||
uri = self._upload_model(
|
||||
model_file, async_enable=async_enable, target_filename=target_filename
|
||||
)
|
||||
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
|
||||
self.update(
|
||||
uri=uri,
|
||||
@ -302,7 +410,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
return uri
|
||||
|
||||
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None):
|
||||
def update_for_task(
|
||||
self, task_id, name=None, model_id=None, type_="output", iteration=None
|
||||
):
|
||||
if Session.check_min_api_version("2.13"):
|
||||
req = tasks.AddOrUpdateModelRequest(
|
||||
task=task_id, name=name, type=type_, model=model_id, iteration=iteration
|
||||
@ -314,7 +424,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
# backwards compatibility, None
|
||||
req = None
|
||||
else:
|
||||
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_))
|
||||
raise ValueError(
|
||||
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
|
||||
)
|
||||
|
||||
if req:
|
||||
self.send(req)
|
||||
@ -323,7 +435,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
@property
|
||||
def model_design(self):
|
||||
""" Get the model design. For now, this is stored as a single key in the design dict. """
|
||||
"""Get the model design. For now, this is stored as a single key in the design dict."""
|
||||
try:
|
||||
return self._unwrap_design(self.data.design)
|
||||
except ValueError:
|
||||
@ -364,7 +476,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags
|
||||
return (
|
||||
self.data.system_tags
|
||||
if hasattr(self.data, "system_tags")
|
||||
else self.data.tags
|
||||
)
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
@ -394,8 +510,10 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
:param bool raise_on_error: If True and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
|
||||
:return: a local path to a downloaded copy of the model
|
||||
"""
|
||||
uri = self.data.uri
|
||||
@ -403,21 +521,29 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return None
|
||||
|
||||
# check if we already downloaded the file
|
||||
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri]
|
||||
downloaded_models = [
|
||||
k
|
||||
for k, (i, u) in Model._local_model_to_id_uri.items()
|
||||
if i == self.id and u == uri
|
||||
]
|
||||
for dl_file in downloaded_models:
|
||||
if Path(dl_file).exists() and not force_download:
|
||||
return dl_file
|
||||
# remove non existing model file
|
||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||
|
||||
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download)
|
||||
local_download = StorageManager.get_local_copy(
|
||||
uri, extract_archive=False, force_download=force_download
|
||||
)
|
||||
|
||||
# save local model, so we can later query what was the original one
|
||||
if local_download is not None:
|
||||
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
|
||||
elif raise_on_error:
|
||||
raise ValueError("Could not retrieve a local copy of model weights {}, "
|
||||
"failed downloading {}".format(self.model_id, uri))
|
||||
raise ValueError(
|
||||
"Could not retrieve a local copy of model weights {}, "
|
||||
"failed downloading {}".format(self.model_id, uri)
|
||||
)
|
||||
|
||||
return local_download
|
||||
|
||||
@ -426,9 +552,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return self._cache_dir
|
||||
|
||||
def save_model_design_file(self):
|
||||
""" Download model description file into a local file in our cache_dir """
|
||||
"""Download model description file into a local file in our cache_dir"""
|
||||
design = self.model_design
|
||||
filename = self.data.name + '.txt'
|
||||
filename = self.data.name + ".txt"
|
||||
p = Path(self.cache_dir) / filename
|
||||
# we always write the original model design to file, to prevent any mishaps
|
||||
# if p.is_file():
|
||||
@ -438,11 +564,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return str(p)
|
||||
|
||||
def get_model_package(self):
|
||||
""" Get a named tuple containing the model's weights and design """
|
||||
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
|
||||
"""Get a named tuple containing the model's weights and design"""
|
||||
return ModelPackage(
|
||||
weights=self.download_model_weights(), design=self.save_model_design_file()
|
||||
)
|
||||
|
||||
def get_model_design(self):
|
||||
""" Get model description (text) """
|
||||
"""Get model description (text)"""
|
||||
return self.model_design
|
||||
|
||||
@classmethod
|
||||
@ -465,8 +593,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
data = self.data
|
||||
assert isinstance(data, models.Model)
|
||||
parent = self.id if child else None
|
||||
extra = {'system_tags': tags or data.system_tags} \
|
||||
if Session.check_min_api_version('2.3') else {'tags': tags or data.tags}
|
||||
extra = (
|
||||
{"system_tags": tags or data.system_tags}
|
||||
if Session.check_min_api_version("2.3")
|
||||
else {"tags": tags or data.tags}
|
||||
)
|
||||
req = models.CreateRequest(
|
||||
uri=data.uri,
|
||||
name=name,
|
||||
@ -485,8 +616,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
|
||||
def _create_empty_model(self, upload_storage_uri=None, project_id=None):
|
||||
upload_storage_uri = upload_storage_uri or self.upload_storage_uri
|
||||
name = make_message('Anonymous model %(time)s')
|
||||
uri = '{}/uploading_file'.format(upload_storage_uri or 'file://')
|
||||
name = make_message("Anonymous model %(time)s")
|
||||
uri = "{}/uploading_file".format(upload_storage_uri or "file://")
|
||||
req = models.CreateRequest(uri=uri, name=name, labels={}, project=project_id)
|
||||
res = self.send(req)
|
||||
if not res:
|
||||
|
@ -652,7 +652,6 @@ if __name__ == '__main__':
|
||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||
a_function, sanitize_function=_sanitize_function
|
||||
)
|
||||
|
||||
# add helper functions on top.
|
||||
for f in (helper_functions or []):
|
||||
helper_function_source, _ = CreateFromFunction.__extract_function_information(
|
||||
@ -665,7 +664,6 @@ if __name__ == '__main__':
|
||||
if artifact_serialization_function
|
||||
else ("", "None")
|
||||
)
|
||||
|
||||
artifact_deserialization_function_source, artifact_deserialization_function_name = (
|
||||
CreateFromFunction.__extract_function_information(artifact_deserialization_function)
|
||||
if artifact_deserialization_function
|
||||
@ -833,7 +831,5 @@ if __name__ == '__main__':
|
||||
function_source = inspect.getsource(function)
|
||||
if sanitize_function:
|
||||
function_source = sanitize_function(function_source)
|
||||
|
||||
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
|
||||
|
||||
return function_source, function_name
|
@ -255,10 +255,14 @@ class ScriptRequirements(object):
|
||||
|
||||
@staticmethod
|
||||
def _remove_package_versions(installed_pkgs, package_names_to_remove_version):
|
||||
installed_pkgs = {k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
|
||||
for k, v in installed_pkgs.items()}
|
||||
def _internal(_installed_pkgs):
|
||||
return {
|
||||
k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
|
||||
if not isinstance(v, dict) else _internal(v)
|
||||
for k, v in _installed_pkgs.items()
|
||||
}
|
||||
|
||||
return installed_pkgs
|
||||
return _internal(installed_pkgs)
|
||||
|
||||
|
||||
class _JupyterObserver(object):
|
||||
@ -781,6 +785,7 @@ class ScriptInfo(object):
|
||||
try:
|
||||
# we expect to find boto3 in the sagemaker env
|
||||
import boto3
|
||||
|
||||
with open(cls._sagemaker_metadata_path) as f:
|
||||
notebook_data = json.load(f)
|
||||
client = boto3.client("sagemaker")
|
||||
@ -799,7 +804,6 @@ class ScriptInfo(object):
|
||||
return jupyter_session.get("path", ""), jupyter_session.get("name", "")
|
||||
except Exception as e:
|
||||
cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
|
@ -700,7 +700,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
|
||||
the text will not be printed, because the Python process is immediately terminated.
|
||||
|
||||
:param bool ignore_errors: If True default), ignore any errors raised
|
||||
:param bool ignore_errors: If True (default), ignore any errors raised
|
||||
:param bool force: If True, the task status will be changed to `stopped` regardless of the current Task state.
|
||||
:param str status_message: Optional, add status change message to the stop request.
|
||||
This message will be stored as status_message on the Task's info panel
|
||||
@ -718,11 +718,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
tasks.CompletedRequest(
|
||||
self.id, status_reason='completed', status_message=status_message, force=force),
|
||||
ignore_errors=ignore_errors)
|
||||
|
||||
if self._get_runtime_properties().get("_publish_on_complete"):
|
||||
self.send(
|
||||
tasks.PublishRequest(
|
||||
self.id, status_reason='completed', status_message=status_message, force=force),
|
||||
ignore_errors=ignore_errors)
|
||||
|
||||
return resp
|
||||
return self.send(
|
||||
tasks.StoppedRequest(self.id, status_reason='completed', status_message=status_message, force=force),
|
||||
@ -2387,7 +2389,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
return True
|
||||
|
||||
def _get_runtime_properties(self):
|
||||
# type: () -> Mapping[str, str]
|
||||
# type: () -> Dict[str, str]
|
||||
if not Session.check_min_api_version('2.13'):
|
||||
return dict()
|
||||
return dict(**self.data.runtime) if self.data.runtime else dict()
|
||||
|
@ -93,7 +93,7 @@ def main():
|
||||
# Take the credentials in raw form or from api section
|
||||
credentials = get_parsed_field(parsed, ["credentials"])
|
||||
api_server = get_parsed_field(parsed, ["api_server", "host"])
|
||||
web_server = get_parsed_field(parsed, ["web_server"])
|
||||
web_server = get_parsed_field(parsed, ["web_server"]) # TODO: if previous fails, this will fail too
|
||||
files_server = get_parsed_field(parsed, ["files_server"])
|
||||
except Exception:
|
||||
credentials = credentials or None
|
||||
|
@ -17,7 +17,6 @@ from attr import attrs, attrib
|
||||
from pathlib2 import Path
|
||||
|
||||
from .. import Task, StorageManager, Logger
|
||||
from ..backend_api.session.client import APIClient
|
||||
from ..backend_api import Session
|
||||
from ..backend_interface.task.development.worker import DevWorker
|
||||
from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project
|
||||
|
@ -359,6 +359,7 @@ class Logger(object):
|
||||
iteration=0,
|
||||
table_plot=df,
|
||||
extra_data={'columnwidth': [2., 1., 1., 1.]})
|
||||
|
||||
"""
|
||||
mutually_exclusive(
|
||||
UsageError, _check_none=True,
|
||||
|
555
clearml/model.py
555
clearml/model.py
File diff suppressed because it is too large
Load Diff
@ -1177,8 +1177,7 @@ class StorageHelper(object):
|
||||
|
||||
def _do_async_upload(self, data):
|
||||
assert isinstance(data, self._UploadData)
|
||||
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback,
|
||||
verbose=True, retries=data.retries, return_canonized=data.return_canonized)
|
||||
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, verbose=True, retries=data.retries, return_canonized=data.return_canonized)
|
||||
|
||||
def _upload_from_file(self, local_path, dest_path, extra=None):
|
||||
if not hasattr(self._driver, 'upload_object'):
|
||||
@ -1473,7 +1472,6 @@ class _HttpDriver(_Driver):
|
||||
try:
|
||||
container = self.get_container(container_name)
|
||||
url = container_name + object_name
|
||||
|
||||
return container.session.head(url, allow_redirects=True, headers=container.get_headers(url)).ok
|
||||
except Exception:
|
||||
return False
|
||||
|
@ -1,5 +1,4 @@
|
||||
import fnmatch
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
from multiprocessing.pool import ThreadPool
|
||||
@ -7,7 +6,6 @@ from random import random
|
||||
from time import time
|
||||
from typing import List, Optional, Union
|
||||
from zipfile import ZipFile
|
||||
from six.moves.urllib.parse import urlparse
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
@ -304,8 +302,8 @@ class StorageManager(object):
|
||||
if not local_folder:
|
||||
local_folder = CacheManager.get_cache_manager().get_cache_folder()
|
||||
local_path = str(Path(local_folder).expanduser().absolute() / bucket_path)
|
||||
|
||||
helper = StorageHelper.get(remote_url)
|
||||
|
||||
return helper.download_to_file(
|
||||
remote_url,
|
||||
local_path,
|
||||
|
@ -1988,7 +1988,9 @@ class Task(_Task):
|
||||
corresponding to debug sample's file name in the UI, also known as variant
|
||||
:param int n_last_iterations: How many debug samples iterations to fetch in reverse chronological order.
|
||||
Leave empty to get all debug samples.
|
||||
|
||||
:raise: TypeError if `n_last_iterations` is explicitly set to anything other than a positive integer value
|
||||
|
||||
:return: A list of `dict`s, each dictionary containing the debug sample's URL and other metadata.
|
||||
The URLs can be passed to :meth:`StorageManager.get_local_copy` to fetch local copies of debug samples.
|
||||
"""
|
||||
@ -2021,11 +2023,14 @@ class Task(_Task):
|
||||
|
||||
def _get_debug_samples(self, title, series, n_last_iterations=None):
|
||||
response = self._send_debug_image_request(title, series, n_last_iterations)
|
||||
|
||||
debug_samples = []
|
||||
|
||||
while True:
|
||||
scroll_id = response.response.scroll_id
|
||||
for metric_resp in response.response.metrics:
|
||||
iterations_events = [iteration["events"] for iteration in metric_resp.iterations] # type: List[List[dict]]
|
||||
scroll_id = response.response_data.get("scroll_id", None)
|
||||
|
||||
for metric_resp in response.response_data.get("metrics", []):
|
||||
iterations_events = [iteration["events"] for iteration in metric_resp.get("iterations", [])] # type: List[List[dict]]
|
||||
flattened_events = (event
|
||||
for single_iter_events in iterations_events
|
||||
for event in single_iter_events)
|
||||
@ -2037,8 +2042,8 @@ class Task(_Task):
|
||||
|
||||
if (len(debug_samples) == n_last_iterations
|
||||
or all(
|
||||
len(metric_resp.iterations) == 0
|
||||
for metric_resp in response.response.metrics)):
|
||||
len(metric_resp.get("iterations", [])) == 0
|
||||
for metric_resp in response.response_data.get("metrics", []))):
|
||||
break
|
||||
|
||||
return debug_samples
|
||||
@ -2877,13 +2882,11 @@ class Task(_Task):
|
||||
Set offline mode, where all data and logs are stored into local folder, for later transmission
|
||||
|
||||
.. note::
|
||||
|
||||
`Task.set_offline` can't move the same task from offline to online, nor can it be applied before `Task.create`.
|
||||
See below an example of **incorect** usage of `Task.set_offline`:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
from clearml import Task
|
||||
|
||||
Task.set_offline(True)
|
||||
task = Task.create(project_name='DEBUG', task_name="offline")
|
||||
# ^^^ an error or warning is emitted, telling us that `Task.set_offline(True)`
|
||||
@ -2891,23 +2894,25 @@ class Task(_Task):
|
||||
Task.set_offline(False)
|
||||
# ^^^ an error or warning is emitted, telling us that running `Task.set_offline(False)`
|
||||
# while the current task is not closed is not something we support
|
||||
|
||||
data = task.export_task()
|
||||
|
||||
imported_task = Task.import_task(task_data=data)
|
||||
|
||||
The correct way to use `Task.set_offline` can be seen in the following example:
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
from clearml import Task
|
||||
|
||||
Task.set_offline(True)
|
||||
task = Task.init(project_name='DEBUG', task_name="offline")
|
||||
task.upload_artifact("large_artifact", "test_strign")
|
||||
task.close()
|
||||
Task.set_offline(False)
|
||||
|
||||
imported_task = Task.import_offline_session(task.get_offline_mode_folder())
|
||||
|
||||
:param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
|
||||
|
||||
:return:
|
||||
"""
|
||||
if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode:
|
||||
@ -2932,6 +2937,7 @@ class Task(_Task):
|
||||
# type: () -> bool
|
||||
"""
|
||||
Return offline-mode state, If in offline-mode, no communication to the backend is enabled.
|
||||
|
||||
:return: boolean offline-mode state
|
||||
"""
|
||||
return cls._offline_mode
|
||||
@ -3542,11 +3548,9 @@ class Task(_Task):
|
||||
def _check_keys(dict_, warning_sent=False):
|
||||
if warning_sent:
|
||||
return
|
||||
|
||||
for k, v in dict_.items():
|
||||
if warning_sent:
|
||||
return
|
||||
|
||||
if not isinstance(k, str):
|
||||
getLogger().warning(
|
||||
"Unsupported key of type '{}' found when connecting dictionary. It will be converted to str".format(
|
||||
@ -3554,12 +3558,10 @@ class Task(_Task):
|
||||
)
|
||||
)
|
||||
warning_sent = True
|
||||
|
||||
if isinstance(v, dict):
|
||||
_check_keys(v, warning_sent)
|
||||
|
||||
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
|
||||
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
|
||||
_check_keys(dictionary)
|
||||
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
|
||||
self._arguments.copy_from_dict(flat_dict, prefix=name)
|
||||
@ -3909,7 +3911,6 @@ class Task(_Task):
|
||||
try:
|
||||
# make sure the state of the offline data is saved
|
||||
self._edit()
|
||||
|
||||
# create zip file
|
||||
offline_folder = self.get_offline_mode_folder()
|
||||
zip_file = offline_folder.as_posix() + '.zip'
|
||||
|
@ -223,11 +223,18 @@
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
"name": "python",
|
||||
"version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "8b483fbf9fa60c6c6195634afd5159f586a30c5c6a9d31fa17f93a17f02fdc40"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -451,4 +451,5 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ from enum import Enum
|
||||
from clearml import Task
|
||||
from clearml.task_parameters import TaskParameters, param, percent_param
|
||||
|
||||
|
||||
# Connecting ClearML with the current process,
|
||||
# from here on everything is logged automatically
|
||||
task = Task.init(project_name='FirstTrial', task_name='first_trial')
|
||||
@ -43,7 +44,6 @@ class IntEnumClass(Enum):
|
||||
C = 1
|
||||
D = 2
|
||||
|
||||
|
||||
parameters = {
|
||||
'list': [1, 2, 3],
|
||||
'dict': {'a': 1, 'b': 2},
|
Loading…
Reference in New Issue
Block a user