Refactor code

This commit is contained in:
allegroai 2023-03-27 13:38:11 +03:00
parent ecf6a4df2a
commit 4cd8857c0d
23 changed files with 1116 additions and 781 deletions

2
.gitignore vendored
View File

@ -11,8 +11,8 @@ build/
dist/ dist/
*.egg-info *.egg-info
.env .env
venv/
.venv/ .venv/
venv/
# example data # example data
examples/runs/ examples/runs/

View File

@ -218,7 +218,6 @@ class PipelineController(object):
def serialize(obj): def serialize(obj):
import dill import dill
return dill.dumps(obj) return dill.dumps(obj)
:param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`, :param artifact_deserialization_function: A deserialization function that takes one parameter of type `bytes`,
which represents the serialized object. This function should return the deserialized object. which represents the serialized object. This function should return the deserialized object.
All parameter/return artifacts fetched by the pipeline will be deserialized using this function. All parameter/return artifacts fetched by the pipeline will be deserialized using this function.
@ -1157,6 +1156,7 @@ class PipelineController(object):
# type: (bool, str) -> bool # type: (bool, str) -> bool
""" """
Evaluate whether or not the pipeline is successful Evaluate whether or not the pipeline is successful
:param fail_on_step_fail: If True (default), evaluate the pipeline steps' status to assess if the pipeline :param fail_on_step_fail: If True (default), evaluate the pipeline steps' status to assess if the pipeline
is successful. If False, only evaluate the controller is successful. If False, only evaluate the controller
:param fail_condition: Must be one of the following: 'all' (default), 'failed' or 'aborted'. If 'failed', this :param fail_condition: Must be one of the following: 'all' (default), 'failed' or 'aborted'. If 'failed', this
@ -1175,18 +1175,14 @@ class PipelineController(object):
success_status = [Task.TaskStatusEnum.completed, Task.TaskStatusEnum.failed] success_status = [Task.TaskStatusEnum.completed, Task.TaskStatusEnum.failed]
else: else:
raise UsageError("fail_condition needs to be one of the following: 'all', 'failed', 'aborted'") raise UsageError("fail_condition needs to be one of the following: 'all', 'failed', 'aborted'")
if self._task.status not in success_status: if self._task.status not in success_status:
return False return False
if not fail_on_step_fail: if not fail_on_step_fail:
return True return True
self._update_nodes_status() self._update_nodes_status()
for node in self._nodes.values(): for node in self._nodes.values():
if node.status not in success_status: if node.status not in success_status:
return False return False
return True return True
def elapsed(self): def elapsed(self):

View File

@ -139,7 +139,7 @@ class TaskTrigger(BaseTrigger):
raise ValueError("You must provide metric/variant/threshold") raise ValueError("You must provide metric/variant/threshold")
valid_status = [str(s) for s in Task.TaskStatusEnum] valid_status = [str(s) for s in Task.TaskStatusEnum]
if self.on_status and not all(s in valid_status for s in self.on_status): if self.on_status and not all(s in valid_status for s in self.on_status):
raise ValueError("You on_status contains invalid status value: {}".format(self.on_status)) raise ValueError("Your on_status contains invalid status value: {}".format(self.on_status))
valid_signs = ['min', 'minimum', 'max', 'maximum'] valid_signs = ['min', 'minimum', 'max', 'maximum']
if self.value_sign and self.value_sign not in valid_signs: if self.value_sign and self.value_sign not in valid_signs:
raise ValueError("Invalid value_sign `{}`, valid options are: {}".format(self.value_sign, valid_signs)) raise ValueError("Invalid value_sign `{}`, valid options are: {}".format(self.value_sign, valid_signs))

View File

@ -2,7 +2,6 @@
auth service auth service
This service provides authentication management and authorization This service provides authentication management and authorization
validation for the entire system. validation for the entire system.
""" """
import six import six

View File

@ -40,8 +40,10 @@ for a very long time for a non-responding or mis-configured server
""" """
ENV_API_EXTRA_RETRY_CODES = EnvEntry("CLEARML_API_EXTRA_RETRY_CODES") ENV_API_EXTRA_RETRY_CODES = EnvEntry("CLEARML_API_EXTRA_RETRY_CODES")
ENV_FORCE_MAX_API_VERSION = EnvEntry("CLEARML_FORCE_MAX_API_VERSION", type=str) ENV_FORCE_MAX_API_VERSION = EnvEntry("CLEARML_FORCE_MAX_API_VERSION", type=str)
class MissingConfigError(ValueError): class MissingConfigError(ValueError):
def __init__(self, message=None): def __init__(self, message=None):
if message is None: if message is None:

View File

@ -94,7 +94,6 @@ class CompoundRequest(Request):
if self._item_prop_name in dict_properties: if self._item_prop_name in dict_properties:
del dict_properties[self._item_prop_name] del dict_properties[self._item_prop_name]
dict_.update(dict_properties) dict_.update(dict_properties)
return dict_ return dict_
def validate(self): def validate(self):

View File

@ -134,6 +134,7 @@ class Session(TokenManager):
**kwargs **kwargs
): ):
self.__class__._sessions_weakrefs.append(weakref.ref(self)) self.__class__._sessions_weakrefs.append(weakref.ref(self))
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get() self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
self._logger = logger self._logger = logger
if self._verbose and not self._logger: if self._verbose and not self._logger:
@ -147,7 +148,6 @@ class Session(TokenManager):
self.__init_host = host self.__init_host = host
self.__init_http_retries_config = http_retries_config self.__init_http_retries_config = http_retries_config
self.__token_manager_kwargs = kwargs self.__token_manager_kwargs = kwargs
if config is not None: if config is not None:
self.config = config self.config = config
else: else:
@ -162,21 +162,21 @@ class Session(TokenManager):
self._ssl_error_count_verbosity = self.config.get( self._ssl_error_count_verbosity = self.config.get(
"api.ssl_error_count_verbosity", self._ssl_error_count_verbosity) "api.ssl_error_count_verbosity", self._ssl_error_count_verbosity)
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
self.__host = self.__init_host or self.get_api_server_host(config=self.config)
if not self.__host: if not self.__host:
raise ValueError("ClearML host was not set, check your configuration file or environment variable") raise ValueError("ClearML host was not set, check your configuration file or environment variable")
self.__host = self.__host.strip("/") self.__host = self.__host.strip("/")
self.__http_retries_config = self.__init_http_retries_config or self.config.get( self.__http_retries_config = self.__init_http_retries_config or self.config.get(
"api.http.retries", ConfigTree()).as_plain_ordered_dict() "api.http.retries", ConfigTree()).as_plain_ordered_dict()
self.__http_retries_config["status_forcelist"] = self._get_retry_codes() self.__http_retries_config["status_forcelist"] = self._get_retry_codes()
self.__http_retries_config["config"] = self.config self.__http_retries_config["config"] = self.config
self.__http_session = get_http_session_with_retry(**self.__http_retries_config) self.__http_session = get_http_session_with_retry(**self.__http_retries_config)
self.__http_session.write_timeout = self._write_session_timeout self.__http_session.write_timeout = self._write_session_timeout
self.__http_session.request_size_threshold = self._write_session_data_size self.__http_session.request_size_threshold = self._write_session_data_size
self.__max_req_size = self.config.get("api.http.max_req_size", None)
self.__max_req_size = self.config.get("api.http.max_req_size", None)
if not self.__max_req_size: if not self.__max_req_size:
raise ValueError("missing max request size") raise ValueError("missing max request size")
@ -186,7 +186,6 @@ class Session(TokenManager):
req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None) req_token_expiration_sec = self.config.get("api.auth.req_token_expiration_sec", None)
self.__auth_token = None self.__auth_token = None
self._update_default_api_method() self._update_default_api_method()
if ENV_AUTH_TOKEN.get(): if ENV_AUTH_TOKEN.get():
self.__access_key = self.__secret_key = None self.__access_key = self.__secret_key = None
self.__auth_token = ENV_AUTH_TOKEN.get() self.__auth_token = ENV_AUTH_TOKEN.get()
@ -203,9 +202,11 @@ class Session(TokenManager):
if not self.secret_key and not self.access_key and not self.__auth_token: if not self.secret_key and not self.access_key and not self.__auth_token:
raise MissingConfigError() raise MissingConfigError()
super(Session, self).__init__( super(Session, self).__init__(
**self.__token_manager_kwargs, **self.__token_manager_kwargs,
req_token_expiration_sec=req_token_expiration_sec, token_expiration_threshold_sec=token_expiration_threshold_sec,
req_token_expiration_sec=req_token_expiration_sec
) )
self.refresh_token() self.refresh_token()
@ -633,6 +634,7 @@ class Session(TokenManager):
return call_result return call_result
@classmethod
def _make_all_sessions_go_online(cls): def _make_all_sessions_go_online(cls):
for active_session in cls._get_all_active_sessions(): for active_session in cls._get_all_active_sessions():
# noinspection PyProtectedMember # noinspection PyProtectedMember
@ -647,7 +649,6 @@ class Session(TokenManager):
if session: if session:
active_sessions.append(session) active_sessions.append(session)
new_sessions_weakrefs.append(session_weakref) new_sessions_weakrefs.append(session_weakref)
cls._sessions_weakrefs = session_weakref cls._sessions_weakrefs = session_weakref
return active_sessions return active_sessions

View File

@ -7,6 +7,7 @@ from time import time
from pathlib2 import Path from pathlib2 import Path
from ...backend_api import Session
from ...backend_api.services import events as api_events from ...backend_api.services import events as api_events
from ..base import InterfaceBase from ..base import InterfaceBase
from ...config import config, deferred_config from ...config import config, deferred_config

View File

@ -271,9 +271,13 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
self._for_model = for_model self._for_model = for_model
flush_threshold = config.get("development.worker.report_event_flush_threshold", 100) flush_threshold = config.get("development.worker.report_event_flush_threshold", 100)
self._report_service = BackgroundReportService( self._report_service = BackgroundReportService(
task=task, async_enable=async_enable, metrics=metrics, task=task,
async_enable=async_enable,
metrics=metrics,
flush_frequency=self._flush_frequency, flush_frequency=self._flush_frequency,
flush_threshold=flush_threshold, for_model=for_model) flush_threshold=flush_threshold,
for_model=for_model,
)
self._report_service.start() self._report_service.start()
def _set_storage_uri(self, value): def _set_storage_uri(self, value):
@ -355,8 +359,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param iter: Iteration number :param iter: Iteration number
:type iter: int :type iter: int
""" """
ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value, ev = ScalarEvent(
iter=iter) metric=self._normalize_name(title),
variant=self._normalize_name(series),
value=value,
iter=iter
)
self._report(ev) self._report(ev)
def report_vector(self, title, series, values, iter): def report_vector(self, title, series, values, iter):
@ -457,8 +465,12 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
elif not isinstance(plot, six.string_types): elif not isinstance(plot, six.string_types):
raise ValueError('Plot should be a string or a dict') raise ValueError('Plot should be a string or a dict')
ev = PlotEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), ev = PlotEvent(
plot_str=plot, iter=iter) metric=self._normalize_name(title),
variant=self._normalize_name(series),
plot_str=plot,
iter=iter
)
self._report(ev) self._report(ev)
def report_image(self, title, series, src, iter): def report_image(self, title, series, src, iter):

View File

@ -12,7 +12,7 @@ from ..storage import StorageManager
from ..storage.helper import StorageHelper from ..storage.helper import StorageHelper
from ..utilities.async_manager import AsyncManagerMixin from ..utilities.async_manager import AsyncManagerMixin
ModelPackage = namedtuple('ModelPackage', 'weights design') ModelPackage = namedtuple("ModelPackage", "weights design")
class ModelDoesNotExistError(Exception): class ModelDoesNotExistError(Exception):
@ -22,12 +22,12 @@ class ModelDoesNotExistError(Exception):
class _StorageUriMixin(object): class _StorageUriMixin(object):
@property @property
def upload_storage_uri(self): def upload_storage_uri(self):
""" A URI into which models are uploaded """ """A URI into which models are uploaded"""
return self._upload_storage_uri return self._upload_storage_uri
@upload_storage_uri.setter @upload_storage_uri.setter
def upload_storage_uri(self, value): def upload_storage_uri(self, value):
self._upload_storage_uri = value.rstrip('/') if value else None self._upload_storage_uri = value.rstrip("/") if value else None
def create_dummy_model(upload_storage_uri=None, *args, **kwargs): def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
@ -44,9 +44,9 @@ def create_dummy_model(upload_storage_uri=None, *args, **kwargs):
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin): class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
""" Manager for backend model objects """ """Manager for backend model objects"""
_EMPTY_MODEL_ID = 'empty' _EMPTY_MODEL_ID = "empty"
_local_model_to_id_uri = {} _local_model_to_id_uri = {}
@ -54,8 +54,15 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def model_id(self): def model_id(self):
return self.id return self.id
def __init__(self, upload_storage_uri, cache_dir, model_id=None, def __init__(
upload_storage_suffix='models', session=None, log=None): self,
upload_storage_uri,
cache_dir,
model_id=None,
upload_storage_suffix="models",
session=None,
log=None
):
super(Model, self).__init__(id=model_id, session=session, log=log) super(Model, self).__init__(id=model_id, session=session, log=log)
self._upload_storage_suffix = upload_storage_suffix self._upload_storage_suffix = upload_storage_suffix
if model_id == self._EMPTY_MODEL_ID: if model_id == self._EMPTY_MODEL_ID:
@ -71,7 +78,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
self.reload() self.reload()
def _reload(self): def _reload(self):
""" Reload the model object """ """Reload the model object"""
if self._offline_mode: if self._offline_mode:
return models.Model() return models.Model()
@ -80,11 +87,19 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
res = self.send(models.GetByIdRequest(model=self.id)) res = self.send(models.GetByIdRequest(model=self.id))
return res.response.model return res.response.model
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None): def _upload_model(
self, model_file, async_enable=False, target_filename=None, cb=None
):
if not self.upload_storage_uri: if not self.upload_storage_uri:
raise ValueError('Model has no storage URI defined (nowhere to upload to)') raise ValueError("Model has no storage URI defined (nowhere to upload to)")
target_filename = target_filename or Path(model_file).name target_filename = target_filename or Path(model_file).name
dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename)) dest_path = "/".join(
(
self.upload_storage_uri,
self._upload_storage_suffix or ".",
target_filename,
)
)
result = StorageHelper.get(dest_path).upload( result = StorageHelper.get(dest_path).upload(
src_path=model_file, src_path=model_file,
dest_path=dest_path, dest_path=dest_path,
@ -93,19 +108,23 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return_canonized=False return_canonized=False
) )
if async_enable: if async_enable:
def msg(num_results): def msg(num_results):
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path)) self.log.info(
"Waiting for previous model to upload (%d pending, %s)"
% (num_results, dest_path)
)
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg) self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
return dest_path return dest_path
def _upload_callback(self, res, cb=None): def _upload_callback(self, res, cb=None):
if res is None: if res is None:
self.log.debug('Starting model upload') self.log.debug("Starting model upload")
elif res is False: elif res is False:
self.log.info('Failed model upload') self.log.info("Failed model upload")
else: else:
self.log.info('Completed model upload to {}'.format(res)) self.log.info("Completed model upload to {}".format(res))
if cb: if cb:
cb(res) cb(res)
@ -126,12 +145,12 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:return: A proper design dictionary according to design parameter. :return: A proper design dictionary according to design parameter.
""" """
if isinstance(design, dict): if isinstance(design, dict):
if 'design' not in design: if "design" not in design:
raise ValueError('design dictionary must have \'design\' key in it') raise ValueError("design dictionary must have 'design' key in it")
return design return design
return {'design': design if design else ''} return {"design": design if design else ""}
@staticmethod @staticmethod
def _unwrap_design(design): def _unwrap_design(design):
@ -153,23 +172,40 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:return: The design string according to design parameter. :return: The design string according to design parameter.
""" """
if not design: if not design:
return '' return ""
if isinstance(design, six.string_types): if isinstance(design, six.string_types):
return design return design
if isinstance(design, dict): if isinstance(design, dict):
if 'design' in design: if "design" in design:
return design['design'] return design["design"]
return list(design.values())[0] return list(design.values())[0]
raise ValueError('design must be a string or a dictionary with at least one value') raise ValueError(
"design must be a string or a dictionary with at least one value"
)
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None, def update(
task_id=None, project_id=None, parent_id=None, uri=None, framework=None, self,
upload_storage_uri=None, target_filename=None, iteration=None, system_tags=None): model_file=None,
""" Update model weights file and various model properties """ design=None,
labels=None,
name=None,
comment=None,
tags=None,
task_id=None,
project_id=None,
parent_id=None,
uri=None,
framework=None,
upload_storage_uri=None,
target_filename=None,
iteration=None,
system_tags=None
):
"""Update model weights file and various model properties"""
if self.id is None: if self.id is None:
if upload_storage_uri: if upload_storage_uri:
@ -182,7 +218,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
# upload model file if needed and get uri # upload model file if needed and get uri
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri) uri = uri or (
self._upload_model(model_file, target_filename=target_filename)
if model_file
else self.data.uri
)
# update fields # update fields
design = self._wrap_design(design) if design else self.data.design design = self._wrap_design(design) if design else self.data.design
name = name or self.data.name name = name or self.data.name
@ -192,7 +232,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
project = project_id or self.data.project project = project_id or self.data.project
parent = parent_id or self.data.parent parent = parent_id or self.data.parent
tags = tags or self.data.tags tags = tags or self.data.tags
if Session.check_min_api_version('2.3'): if Session.check_min_api_version("2.3"):
system_tags = system_tags or self.data.system_tags system_tags = system_tags or self.data.system_tags
self._edit( self._edit(
@ -210,33 +250,74 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
system_tags=system_tags, system_tags=system_tags,
) )
def edit(self, design=None, labels=None, name=None, comment=None, tags=None, def edit(
uri=None, framework=None, iteration=None, system_tags=None): self,
return self._edit(design=design, labels=labels, name=name, comment=comment, tags=tags, design=None,
uri=uri, framework=framework, iteration=iteration, system_tags=system_tags) labels=None,
name=None,
comment=None,
tags=None,
uri=None,
framework=None,
iteration=None,
system_tags=None
):
return self._edit(
design=design,
labels=labels,
name=name,
comment=comment,
tags=tags,
uri=uri,
framework=framework,
iteration=iteration,
system_tags=system_tags,
)
def _edit(self, design=None, labels=None, name=None, comment=None, tags=None, def _edit(
uri=None, framework=None, iteration=None, system_tags=None, **extra): self,
design=None,
labels=None,
name=None,
comment=None,
tags=None,
uri=None,
framework=None,
iteration=None,
system_tags=None,
**extra
):
def offline_store(**kwargs): def offline_store(**kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self.data, k, v or getattr(self.data, k, None)) setattr(self.data, k, v or getattr(self.data, k, None))
return return
if self._offline_mode:
return offline_store(design=design, labels=labels, name=name, comment=comment, tags=tags,
uri=uri, framework=framework, iteration=iteration, **extra)
if Session.check_min_api_version('2.3'): if self._offline_mode:
return offline_store(
design=design,
labels=labels,
name=name,
comment=comment,
tags=tags,
uri=uri,
framework=framework,
iteration=iteration,
**extra
)
if Session.check_min_api_version("2.3"):
if tags is not None: if tags is not None:
extra.update({'tags': tags}) extra.update({"tags": tags})
if system_tags is not None: if system_tags is not None:
extra.update({'system_tags': system_tags}) extra.update({"system_tags": system_tags})
elif tags is not None or system_tags is not None: elif tags is not None or system_tags is not None:
if tags and system_tags: if tags and system_tags:
system_tags = system_tags[:] system_tags = system_tags[:]
system_tags += [t for t in tags if t not in system_tags] system_tags += [t for t in tags if t not in system_tags]
extra.update({'system_tags': system_tags or tags or self.data.system_tags}) extra.update({"system_tags": system_tags or tags or self.data.system_tags})
self.send(models.EditRequest( self.send(
models.EditRequest(
model=self.id, model=self.id,
uri=uri, uri=uri,
name=name, name=name,
@ -246,23 +327,44 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
framework=framework, framework=framework,
iteration=iteration, iteration=iteration,
**extra **extra
)) )
)
self.reload() self.reload()
def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None, def update_and_upload(
tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False, self,
target_filename=None, cb=None, iteration=None): model_file,
""" Update the given model for a given task ID """ design=None,
labels=None,
name=None,
comment=None,
tags=None,
task_id=None,
project_id=None,
parent_id=None,
framework=None,
async_enable=False,
target_filename=None,
cb=None,
iteration=None
):
"""Update the given model for a given task ID"""
if async_enable: if async_enable:
def callback(uploaded_uri): def callback(uploaded_uri):
if uploaded_uri is None: if uploaded_uri is None:
return return
# If not successful, mark model as failed_uploading # If not successful, mark model as failed_uploading
if uploaded_uri is False: if uploaded_uri is False:
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri) uploaded_uri = "{}/failed_uploading".format(
self._upload_storage_uri
)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uploaded_uri) Model._local_model_to_id_uri[str(model_file)] = (
self.model_id,
uploaded_uri,
)
self.update( self.update(
uri=uploaded_uri, uri=uploaded_uri,
@ -281,11 +383,17 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
if cb: if cb:
cb(model_file) cb(model_file)
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, uri = self._upload_model(
cb=callback) model_file,
async_enable=async_enable,
target_filename=target_filename,
cb=callback,
)
return uri return uri
else: else:
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename) uri = self._upload_model(
model_file, async_enable=async_enable, target_filename=target_filename
)
Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri) Model._local_model_to_id_uri[str(model_file)] = (self.model_id, uri)
self.update( self.update(
uri=uri, uri=uri,
@ -302,7 +410,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return uri return uri
def update_for_task(self, task_id, name=None, model_id=None, type_="output", iteration=None): def update_for_task(
self, task_id, name=None, model_id=None, type_="output", iteration=None
):
if Session.check_min_api_version("2.13"): if Session.check_min_api_version("2.13"):
req = tasks.AddOrUpdateModelRequest( req = tasks.AddOrUpdateModelRequest(
task=task_id, name=name, type=type_, model=model_id, iteration=iteration task=task_id, name=name, type=type_, model=model_id, iteration=iteration
@ -314,7 +424,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
# backwards compatibility, None # backwards compatibility, None
req = None req = None
else: else:
raise ValueError("Type '{}' unsupported (use either 'input' or 'output')".format(type_)) raise ValueError(
"Type '{}' unsupported (use either 'input' or 'output')".format(type_)
)
if req: if req:
self.send(req) self.send(req)
@ -323,7 +435,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
@property @property
def model_design(self): def model_design(self):
""" Get the model design. For now, this is stored as a single key in the design dict. """ """Get the model design. For now, this is stored as a single key in the design dict."""
try: try:
return self._unwrap_design(self.data.design) return self._unwrap_design(self.data.design)
except ValueError: except ValueError:
@ -364,7 +476,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
@property @property
def tags(self): def tags(self):
return self.data.system_tags if hasattr(self.data, 'system_tags') else self.data.tags return (
self.data.system_tags
if hasattr(self.data, "system_tags")
else self.data.tags
)
@property @property
def task(self): def task(self):
@ -394,8 +510,10 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:param bool raise_on_error: If True and the artifact could not be downloaded, :param bool raise_on_error: If True and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning. raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the base artifact will be downloaded, :param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached. even if the artifact is already cached.
:return: a local path to a downloaded copy of the model :return: a local path to a downloaded copy of the model
""" """
uri = self.data.uri uri = self.data.uri
@ -403,21 +521,29 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return None return None
# check if we already downloaded the file # check if we already downloaded the file
downloaded_models = [k for k, (i, u) in Model._local_model_to_id_uri.items() if i == self.id and u == uri] downloaded_models = [
k
for k, (i, u) in Model._local_model_to_id_uri.items()
if i == self.id and u == uri
]
for dl_file in downloaded_models: for dl_file in downloaded_models:
if Path(dl_file).exists() and not force_download: if Path(dl_file).exists() and not force_download:
return dl_file return dl_file
# remove non existing model file # remove non existing model file
Model._local_model_to_id_uri.pop(dl_file, None) Model._local_model_to_id_uri.pop(dl_file, None)
local_download = StorageManager.get_local_copy(uri, extract_archive=False, force_download=force_download) local_download = StorageManager.get_local_copy(
uri, extract_archive=False, force_download=force_download
)
# save local model, so we can later query what was the original one # save local model, so we can later query what was the original one
if local_download is not None: if local_download is not None:
Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri) Model._local_model_to_id_uri[str(local_download)] = (self.model_id, uri)
elif raise_on_error: elif raise_on_error:
raise ValueError("Could not retrieve a local copy of model weights {}, " raise ValueError(
"failed downloading {}".format(self.model_id, uri)) "Could not retrieve a local copy of model weights {}, "
"failed downloading {}".format(self.model_id, uri)
)
return local_download return local_download
@ -426,9 +552,9 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return self._cache_dir return self._cache_dir
def save_model_design_file(self): def save_model_design_file(self):
""" Download model description file into a local file in our cache_dir """ """Download model description file into a local file in our cache_dir"""
design = self.model_design design = self.model_design
filename = self.data.name + '.txt' filename = self.data.name + ".txt"
p = Path(self.cache_dir) / filename p = Path(self.cache_dir) / filename
# we always write the original model design to file, to prevent any mishaps # we always write the original model design to file, to prevent any mishaps
# if p.is_file(): # if p.is_file():
@ -438,11 +564,13 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return str(p) return str(p)
def get_model_package(self): def get_model_package(self):
""" Get a named tuple containing the model's weights and design """ """Get a named tuple containing the model's weights and design"""
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file()) return ModelPackage(
weights=self.download_model_weights(), design=self.save_model_design_file()
)
def get_model_design(self): def get_model_design(self):
""" Get model description (text) """ """Get model description (text)"""
return self.model_design return self.model_design
@classmethod @classmethod
@ -465,8 +593,11 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
data = self.data data = self.data
assert isinstance(data, models.Model) assert isinstance(data, models.Model)
parent = self.id if child else None parent = self.id if child else None
extra = {'system_tags': tags or data.system_tags} \ extra = (
if Session.check_min_api_version('2.3') else {'tags': tags or data.tags} {"system_tags": tags or data.system_tags}
if Session.check_min_api_version("2.3")
else {"tags": tags or data.tags}
)
req = models.CreateRequest( req = models.CreateRequest(
uri=data.uri, uri=data.uri,
name=name, name=name,
@ -485,8 +616,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
def _create_empty_model(self, upload_storage_uri=None, project_id=None): def _create_empty_model(self, upload_storage_uri=None, project_id=None):
upload_storage_uri = upload_storage_uri or self.upload_storage_uri upload_storage_uri = upload_storage_uri or self.upload_storage_uri
name = make_message('Anonymous model %(time)s') name = make_message("Anonymous model %(time)s")
uri = '{}/uploading_file'.format(upload_storage_uri or 'file://') uri = "{}/uploading_file".format(upload_storage_uri or "file://")
req = models.CreateRequest(uri=uri, name=name, labels={}, project=project_id) req = models.CreateRequest(uri=uri, name=name, labels={}, project=project_id)
res = self.send(req) res = self.send(req)
if not res: if not res:

View File

@ -652,7 +652,6 @@ if __name__ == '__main__':
function_source, function_name = CreateFromFunction.__extract_function_information( function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function a_function, sanitize_function=_sanitize_function
) )
# add helper functions on top. # add helper functions on top.
for f in (helper_functions or []): for f in (helper_functions or []):
helper_function_source, _ = CreateFromFunction.__extract_function_information( helper_function_source, _ = CreateFromFunction.__extract_function_information(
@ -665,7 +664,6 @@ if __name__ == '__main__':
if artifact_serialization_function if artifact_serialization_function
else ("", "None") else ("", "None")
) )
artifact_deserialization_function_source, artifact_deserialization_function_name = ( artifact_deserialization_function_source, artifact_deserialization_function_name = (
CreateFromFunction.__extract_function_information(artifact_deserialization_function) CreateFromFunction.__extract_function_information(artifact_deserialization_function)
if artifact_deserialization_function if artifact_deserialization_function
@ -833,7 +831,5 @@ if __name__ == '__main__':
function_source = inspect.getsource(function) function_source = inspect.getsource(function)
if sanitize_function: if sanitize_function:
function_source = sanitize_function(function_source) function_source = sanitize_function(function_source)
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source) function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
return function_source, function_name return function_source, function_name

View File

@ -255,10 +255,14 @@ class ScriptRequirements(object):
@staticmethod @staticmethod
def _remove_package_versions(installed_pkgs, package_names_to_remove_version): def _remove_package_versions(installed_pkgs, package_names_to_remove_version):
installed_pkgs = {k: (v[0], None if str(k) in package_names_to_remove_version else v[1]) def _internal(_installed_pkgs):
for k, v in installed_pkgs.items()} return {
k: (v[0], None if str(k) in package_names_to_remove_version else v[1])
if not isinstance(v, dict) else _internal(v)
for k, v in _installed_pkgs.items()
}
return installed_pkgs return _internal(installed_pkgs)
class _JupyterObserver(object): class _JupyterObserver(object):
@ -781,6 +785,7 @@ class ScriptInfo(object):
try: try:
# we expect to find boto3 in the sagemaker env # we expect to find boto3 in the sagemaker env
import boto3 import boto3
with open(cls._sagemaker_metadata_path) as f: with open(cls._sagemaker_metadata_path) as f:
notebook_data = json.load(f) notebook_data = json.load(f)
client = boto3.client("sagemaker") client = boto3.client("sagemaker")
@ -799,7 +804,6 @@ class ScriptInfo(object):
return jupyter_session.get("path", ""), jupyter_session.get("name", "") return jupyter_session.get("path", ""), jupyter_session.get("name", "")
except Exception as e: except Exception as e:
cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e)) cls._get_logger().warning("Failed finding Notebook in SageMaker environment. Error is: '{}'".format(e))
return None, None return None, None
@classmethod @classmethod

View File

@ -700,7 +700,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
the text will not be printed, because the Python process is immediately terminated. the text will not be printed, because the Python process is immediately terminated.
:param bool ignore_errors: If True default), ignore any errors raised :param bool ignore_errors: If True (default), ignore any errors raised
:param bool force: If True, the task status will be changed to `stopped` regardless of the current Task state. :param bool force: If True, the task status will be changed to `stopped` regardless of the current Task state.
:param str status_message: Optional, add status change message to the stop request. :param str status_message: Optional, add status change message to the stop request.
This message will be stored as status_message on the Task's info panel This message will be stored as status_message on the Task's info panel
@ -718,11 +718,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
tasks.CompletedRequest( tasks.CompletedRequest(
self.id, status_reason='completed', status_message=status_message, force=force), self.id, status_reason='completed', status_message=status_message, force=force),
ignore_errors=ignore_errors) ignore_errors=ignore_errors)
if self._get_runtime_properties().get("_publish_on_complete"): if self._get_runtime_properties().get("_publish_on_complete"):
self.send( self.send(
tasks.PublishRequest( tasks.PublishRequest(
self.id, status_reason='completed', status_message=status_message, force=force), self.id, status_reason='completed', status_message=status_message, force=force),
ignore_errors=ignore_errors) ignore_errors=ignore_errors)
return resp return resp
return self.send( return self.send(
tasks.StoppedRequest(self.id, status_reason='completed', status_message=status_message, force=force), tasks.StoppedRequest(self.id, status_reason='completed', status_message=status_message, force=force),
@ -2387,7 +2389,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return True return True
def _get_runtime_properties(self): def _get_runtime_properties(self):
# type: () -> Mapping[str, str] # type: () -> Dict[str, str]
if not Session.check_min_api_version('2.13'): if not Session.check_min_api_version('2.13'):
return dict() return dict()
return dict(**self.data.runtime) if self.data.runtime else dict() return dict(**self.data.runtime) if self.data.runtime else dict()

View File

@ -93,7 +93,7 @@ def main():
# Take the credentials in raw form or from api section # Take the credentials in raw form or from api section
credentials = get_parsed_field(parsed, ["credentials"]) credentials = get_parsed_field(parsed, ["credentials"])
api_server = get_parsed_field(parsed, ["api_server", "host"]) api_server = get_parsed_field(parsed, ["api_server", "host"])
web_server = get_parsed_field(parsed, ["web_server"]) web_server = get_parsed_field(parsed, ["web_server"]) # TODO: if previous fails, this will fail too
files_server = get_parsed_field(parsed, ["files_server"]) files_server = get_parsed_field(parsed, ["files_server"])
except Exception: except Exception:
credentials = credentials or None credentials = credentials or None

View File

@ -17,7 +17,6 @@ from attr import attrs, attrib
from pathlib2 import Path from pathlib2 import Path
from .. import Task, StorageManager, Logger from .. import Task, StorageManager, Logger
from ..backend_api.session.client import APIClient
from ..backend_api import Session from ..backend_api import Session
from ..backend_interface.task.development.worker import DevWorker from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project from ..backend_interface.util import mutually_exclusive, exact_match_regex, get_or_create_project, rename_project

View File

@ -359,6 +359,7 @@ class Logger(object):
iteration=0, iteration=0,
table_plot=df, table_plot=df,
extra_data={'columnwidth': [2., 1., 1., 1.]}) extra_data={'columnwidth': [2., 1., 1., 1.]})
""" """
mutually_exclusive( mutually_exclusive(
UsageError, _check_none=True, UsageError, _check_none=True,

File diff suppressed because it is too large Load Diff

View File

@ -1177,8 +1177,7 @@ class StorageHelper(object):
def _do_async_upload(self, data): def _do_async_upload(self, data):
assert isinstance(data, self._UploadData) assert isinstance(data, self._UploadData)
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, verbose=True, retries=data.retries, return_canonized=data.return_canonized)
verbose=True, retries=data.retries, return_canonized=data.return_canonized)
def _upload_from_file(self, local_path, dest_path, extra=None): def _upload_from_file(self, local_path, dest_path, extra=None):
if not hasattr(self._driver, 'upload_object'): if not hasattr(self._driver, 'upload_object'):
@ -1473,7 +1472,6 @@ class _HttpDriver(_Driver):
try: try:
container = self.get_container(container_name) container = self.get_container(container_name)
url = container_name + object_name url = container_name + object_name
return container.session.head(url, allow_redirects=True, headers=container.get_headers(url)).ok return container.session.head(url, allow_redirects=True, headers=container.get_headers(url)).ok
except Exception: except Exception:
return False return False

View File

@ -1,5 +1,4 @@
import fnmatch import fnmatch
import os
import shutil import shutil
import tarfile import tarfile
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
@ -7,7 +6,6 @@ from random import random
from time import time from time import time
from typing import List, Optional, Union from typing import List, Optional, Union
from zipfile import ZipFile from zipfile import ZipFile
from six.moves.urllib.parse import urlparse
from pathlib2 import Path from pathlib2 import Path
@ -304,8 +302,8 @@ class StorageManager(object):
if not local_folder: if not local_folder:
local_folder = CacheManager.get_cache_manager().get_cache_folder() local_folder = CacheManager.get_cache_manager().get_cache_folder()
local_path = str(Path(local_folder).expanduser().absolute() / bucket_path) local_path = str(Path(local_folder).expanduser().absolute() / bucket_path)
helper = StorageHelper.get(remote_url) helper = StorageHelper.get(remote_url)
return helper.download_to_file( return helper.download_to_file(
remote_url, remote_url,
local_path, local_path,

View File

@ -1988,7 +1988,9 @@ class Task(_Task):
corresponding to debug sample's file name in the UI, also known as variant corresponding to debug sample's file name in the UI, also known as variant
:param int n_last_iterations: How many debug samples iterations to fetch in reverse chronological order. :param int n_last_iterations: How many debug samples iterations to fetch in reverse chronological order.
Leave empty to get all debug samples. Leave empty to get all debug samples.
:raise: TypeError if `n_last_iterations` is explicitly set to anything other than a positive integer value :raise: TypeError if `n_last_iterations` is explicitly set to anything other than a positive integer value
:return: A list of `dict`s, each dictionary containing the debug sample's URL and other metadata. :return: A list of `dict`s, each dictionary containing the debug sample's URL and other metadata.
The URLs can be passed to :meth:`StorageManager.get_local_copy` to fetch local copies of debug samples. The URLs can be passed to :meth:`StorageManager.get_local_copy` to fetch local copies of debug samples.
""" """
@ -2021,11 +2023,14 @@ class Task(_Task):
def _get_debug_samples(self, title, series, n_last_iterations=None): def _get_debug_samples(self, title, series, n_last_iterations=None):
response = self._send_debug_image_request(title, series, n_last_iterations) response = self._send_debug_image_request(title, series, n_last_iterations)
debug_samples = [] debug_samples = []
while True: while True:
scroll_id = response.response.scroll_id scroll_id = response.response_data.get("scroll_id", None)
for metric_resp in response.response.metrics:
iterations_events = [iteration["events"] for iteration in metric_resp.iterations] # type: List[List[dict]] for metric_resp in response.response_data.get("metrics", []):
iterations_events = [iteration["events"] for iteration in metric_resp.get("iterations", [])] # type: List[List[dict]]
flattened_events = (event flattened_events = (event
for single_iter_events in iterations_events for single_iter_events in iterations_events
for event in single_iter_events) for event in single_iter_events)
@ -2037,8 +2042,8 @@ class Task(_Task):
if (len(debug_samples) == n_last_iterations if (len(debug_samples) == n_last_iterations
or all( or all(
len(metric_resp.iterations) == 0 len(metric_resp.get("iterations", [])) == 0
for metric_resp in response.response.metrics)): for metric_resp in response.response_data.get("metrics", []))):
break break
return debug_samples return debug_samples
@ -2877,13 +2882,11 @@ class Task(_Task):
Set offline mode, where all data and logs are stored into local folder, for later transmission Set offline mode, where all data and logs are stored into local folder, for later transmission
.. note:: .. note::
`Task.set_offline` can't move the same task from offline to online, nor can it be applied before `Task.create`. `Task.set_offline` can't move the same task from offline to online, nor can it be applied before `Task.create`.
See below an example of **incorect** usage of `Task.set_offline`: See below an example of **incorect** usage of `Task.set_offline`:
.. code-block:: py .. code-block:: py
from clearml import Task from clearml import Task
Task.set_offline(True) Task.set_offline(True)
task = Task.create(project_name='DEBUG', task_name="offline") task = Task.create(project_name='DEBUG', task_name="offline")
# ^^^ an error or warning is emitted, telling us that `Task.set_offline(True)` # ^^^ an error or warning is emitted, telling us that `Task.set_offline(True)`
@ -2891,23 +2894,25 @@ class Task(_Task):
Task.set_offline(False) Task.set_offline(False)
# ^^^ an error or warning is emitted, telling us that running `Task.set_offline(False)` # ^^^ an error or warning is emitted, telling us that running `Task.set_offline(False)`
# while the current task is not closed is not something we support # while the current task is not closed is not something we support
data = task.export_task() data = task.export_task()
imported_task = Task.import_task(task_data=data) imported_task = Task.import_task(task_data=data)
The correct way to use `Task.set_offline` can be seen in the following example: The correct way to use `Task.set_offline` can be seen in the following example:
.. code-block:: py .. code-block:: py
from clearml import Task from clearml import Task
Task.set_offline(True) Task.set_offline(True)
task = Task.init(project_name='DEBUG', task_name="offline") task = Task.init(project_name='DEBUG', task_name="offline")
task.upload_artifact("large_artifact", "test_strign") task.upload_artifact("large_artifact", "test_strign")
task.close() task.close()
Task.set_offline(False) Task.set_offline(False)
imported_task = Task.import_offline_session(task.get_offline_mode_folder()) imported_task = Task.import_offline_session(task.get_offline_mode_folder())
:param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled. :param offline_mode: If True, offline-mode is turned on, and no communication to the backend is enabled.
:return: :return:
""" """
if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode: if running_remotely() or bool(offline_mode) == InterfaceBase._offline_mode:
@ -2932,6 +2937,7 @@ class Task(_Task):
# type: () -> bool # type: () -> bool
""" """
Return offline-mode state, If in offline-mode, no communication to the backend is enabled. Return offline-mode state, If in offline-mode, no communication to the backend is enabled.
:return: boolean offline-mode state :return: boolean offline-mode state
""" """
return cls._offline_mode return cls._offline_mode
@ -3542,11 +3548,9 @@ class Task(_Task):
def _check_keys(dict_, warning_sent=False): def _check_keys(dict_, warning_sent=False):
if warning_sent: if warning_sent:
return return
for k, v in dict_.items(): for k, v in dict_.items():
if warning_sent: if warning_sent:
return return
if not isinstance(k, str): if not isinstance(k, str):
getLogger().warning( getLogger().warning(
"Unsupported key of type '{}' found when connecting dictionary. It will be converted to str".format( "Unsupported key of type '{}' found when connecting dictionary. It will be converted to str".format(
@ -3554,12 +3558,10 @@ class Task(_Task):
) )
) )
warning_sent = True warning_sent = True
if isinstance(v, dict): if isinstance(v, dict):
_check_keys(v, warning_sent) _check_keys(v, warning_sent)
if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()): if not running_remotely() or not (self.is_main_task() or self._is_remote_main_task()):
self._arguments.copy_from_dict(flatten_dictionary(dictionary), prefix=name)
_check_keys(dictionary) _check_keys(dictionary)
flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()} flat_dict = {str(k): v for k, v in flatten_dictionary(dictionary).items()}
self._arguments.copy_from_dict(flat_dict, prefix=name) self._arguments.copy_from_dict(flat_dict, prefix=name)
@ -3909,7 +3911,6 @@ class Task(_Task):
try: try:
# make sure the state of the offline data is saved # make sure the state of the offline data is saved
self._edit() self._edit()
# create zip file # create zip file
offline_folder = self.get_offline_mode_folder() offline_folder = self.get_offline_mode_folder()
zip_file = offline_folder.as_posix() + '.zip' zip_file = offline_folder.as_posix() + '.zip'

View File

@ -223,11 +223,18 @@
"toc_visible": true "toc_visible": true
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": ".venv",
"language": "python",
"name": "python3" "name": "python3"
}, },
"language_info": { "language_info": {
"name": "python" "name": "python",
"version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]"
},
"vscode": {
"interpreter": {
"hash": "8b483fbf9fa60c6c6195634afd5159f586a30c5c6a9d31fa17f93a17f02fdc40"
}
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -451,4 +451,5 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 0
} }

View File

@ -17,6 +17,7 @@ from enum import Enum
from clearml import Task from clearml import Task
from clearml.task_parameters import TaskParameters, param, percent_param from clearml.task_parameters import TaskParameters, param, percent_param
# Connecting ClearML with the current process, # Connecting ClearML with the current process,
# from here on everything is logged automatically # from here on everything is logged automatically
task = Task.init(project_name='FirstTrial', task_name='first_trial') task = Task.init(project_name='FirstTrial', task_name='first_trial')
@ -43,7 +44,6 @@ class IntEnumClass(Enum):
C = 1 C = 1
D = 2 D = 2
parameters = { parameters = {
'list': [1, 2, 3], 'list': [1, 2, 3],
'dict': {'a': 1, 'b': 2}, 'dict': {'a': 1, 'b': 2},