Remove deprecated function call StorageHelper._test_bucket_config()

This commit is contained in:
allegroai 2020-05-31 11:55:58 +03:00
parent 0a0d816bd5
commit 7440799bb0

View File

@ -7,8 +7,10 @@ from enum import Enum
from tempfile import gettempdir from tempfile import gettempdir
from multiprocessing import RLock from multiprocessing import RLock
from threading import Thread from threading import Thread
from typing import Optional, Any, Sequence, Callable, Mapping, Union
try: try:
# noinspection PyCompatibility
from collections.abc import Iterable from collections.abc import Iterable
except ImportError: except ImportError:
from collections import Iterable from collections import Iterable
@ -202,8 +204,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self.log.warning(msg) self.log.warning(msg)
if raise_errors: if raise_errors:
raise Exception(msg) raise Exception(msg)
else:
StorageHelper._test_bucket_config(conf=conf, log=self.log, raise_on_error=raise_errors)
except StorageError: except StorageError:
raise raise
except Exception as ex: except Exception as ex:
@ -220,6 +220,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _update_repository(self): def _update_repository(self):
def check_package_update(): def check_package_update():
# noinspection PyBroadException
try: try:
# check latest version # check latest version
from ...utilities.check_updates import CheckPackageUpdates from ...utilities.check_updates import CheckPackageUpdates
@ -229,7 +230,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
sep = os.linesep sep = os.linesep
self.get_logger().report_text( self.get_logger().report_text(
'{} new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format( '{} new package available: UPGRADE to v{} is recommended!\nRelease Notes:\n{}'.format(
Session._client[0][0].upper(), latest_version[0], sep.join(latest_version[2])), Session.get_clients()[0][0].upper(), latest_version[0], sep.join(latest_version[2])),
) )
else: else:
self.get_logger().report_text( self.get_logger().report_text(
@ -310,6 +311,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def storage_uri(self): def storage_uri(self):
# type: () -> Optional[str]
if self._storage_uri: if self._storage_uri:
return self._storage_uri return self._storage_uri
if running_remotely(): if running_remotely():
@ -319,55 +321,68 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@storage_uri.setter @storage_uri.setter
def storage_uri(self, value): def storage_uri(self, value):
# type: (str) -> ()
self._set_storage_uri(value) self._set_storage_uri(value)
@property @property
def task_id(self): def task_id(self):
# type: () -> str
return self.id return self.id
@property @property
def name(self): def name(self):
# type: () -> str
return self.data.name or '' return self.data.name or ''
@name.setter @name.setter
def name(self, value): def name(self, value):
# type: (str) -> ()
self.set_name(value) self.set_name(value)
@property @property
def task_type(self): def task_type(self):
# type: () -> str
return self.data.type return self.data.type
@property @property
def project(self): def project(self):
# type: () -> str
return self.data.project return self.data.project
@property @property
def parent(self): def parent(self):
# type: () -> str
return self.data.parent return self.data.parent
@property @property
def input_model_id(self): def input_model_id(self):
# type: () -> str
return self.data.execution.model return self.data.execution.model
@property @property
def output_model_id(self): def output_model_id(self):
# type: () -> str
return self.data.output.model return self.data.output.model
@property @property
def comment(self): def comment(self):
# type: () -> str
return self.data.comment or '' return self.data.comment or ''
@comment.setter @comment.setter
def comment(self, value): def comment(self, value):
# type: (str) -> ()
self.set_comment(value) self.set_comment(value)
@property @property
def cache_dir(self): def cache_dir(self):
# type: () -> Path
""" The cache directory which is used to store the Task related files. """ """ The cache directory which is used to store the Task related files. """
return Path(get_cache_dir()) / self.id return Path(get_cache_dir()) / self.id
@property @property
def status(self): def status(self):
# type: () -> str
""" """
The Task's status. To keep the Task updated. The Task's status. To keep the Task updated.
Trains reloads the Task status information only, when this value is accessed. Trains reloads the Task status information only, when this value is accessed.
@ -378,11 +393,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def _status(self): def _status(self):
# type: () -> str
""" Return the task's cached status (don't reload if we don't have to) """ """ Return the task's cached status (don't reload if we don't have to) """
return str(self.data.status) return str(self.data.status)
@property @property
def input_model(self): def input_model(self):
# type: () -> Optional[Model]
""" A model manager used to handle the input model object """ """ A model manager used to handle the input model object """
model_id = self._get_task_property('execution.model', raise_on_error=False) model_id = self._get_task_property('execution.model', raise_on_error=False)
if not model_id: if not model_id:
@ -398,15 +415,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def output_model(self): def output_model(self):
# type: () -> Optional[Model]
""" A model manager used to manage the output model object """ """ A model manager used to manage the output model object """
if self._output_model is None: if self._output_model is None:
self._output_model = self._get_output_model(upload_required=True) self._output_model = self._get_output_model(upload_required=True)
return self._output_model return self._output_model
def create_output_model(self): def create_output_model(self):
# type: () -> Model
return self._get_output_model(upload_required=False, force=True) return self._get_output_model(upload_required=False, force=True)
def _get_output_model(self, upload_required=True, force=False): def _get_output_model(self, upload_required=True, force=False):
# type: (bool, bool) -> Model
return Model( return Model(
session=self.session, session=self.session,
model_id=None if force else self._get_task_property( model_id=None if force else self._get_task_property(
@ -419,11 +439,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def metrics_manager(self): def metrics_manager(self):
# type: () -> Metrics
""" A metrics manager used to manage the metrics related to this task """ """ A metrics manager used to manage the metrics related to this task """
return self._get_metrics_manager(self.get_output_destination()) return self._get_metrics_manager(self.get_output_destination())
@property @property
def reporter(self): def reporter(self):
# type: () -> Reporter
""" """
Returns a simple metrics reporter instance Returns a simple metrics reporter instance
""" """
@ -432,6 +454,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._reporter return self._reporter
def _get_metrics_manager(self, storage_uri): def _get_metrics_manager(self, storage_uri):
# type: (str) -> Metrics
if self._metrics_manager is None: if self._metrics_manager is None:
self._metrics_manager = Metrics( self._metrics_manager = Metrics(
session=self.session, session=self.session,
@ -443,6 +466,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._metrics_manager return self._metrics_manager
def _setup_reporter(self): def _setup_reporter(self):
# type: () -> Reporter
try: try:
storage_uri = self.get_output_destination(log_on_error=False) storage_uri = self.get_output_destination(log_on_error=False)
except ValueError: except ValueError:
@ -451,10 +475,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._reporter return self._reporter
def _get_output_destination_suffix(self, extra_path=None): def _get_output_destination_suffix(self, extra_path=None):
# type: (Optional[str]) -> str
return '/'.join(quote(x, safe="'[]{}()$^,.; -_+-=") for x in return '/'.join(quote(x, safe="'[]{}()$^,.; -_+-=") for x in
(self.get_project_name(), '%s.%s' % (self.name, self.data.id), extra_path) if x) (self.get_project_name(), '%s.%s' % (self.name, self.data.id), extra_path) if x)
def _reload(self): def _reload(self):
# type: () -> Any
""" Reload the task object from the backend """ """ Reload the task object from the backend """
with self._edit_lock: with self._edit_lock:
if self._reload_skip_flag and self._data: if self._reload_skip_flag and self._data:
@ -463,6 +489,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return res.response.task return res.response.task
def reset(self, set_started_on_success=True): def reset(self, set_started_on_success=True):
# type: (bool) -> ()
""" Reset the task. Task will be reloaded following a successful reset. """ """ Reset the task. Task will be reloaded following a successful reset. """
self.send(tasks.ResetRequest(task=self.id)) self.send(tasks.ResetRequest(task=self.id))
if set_started_on_success: if set_started_on_success:
@ -474,25 +501,30 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self.reload() self.reload()
def started(self, ignore_errors=True): def started(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal that this Task started. """ """ The signal that this Task started. """
return self.send(tasks.StartedRequest(self.id), ignore_errors=ignore_errors) return self.send(tasks.StartedRequest(self.id), ignore_errors=ignore_errors)
def stopped(self, ignore_errors=True): def stopped(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal that this Task stopped. """ """ The signal that this Task stopped. """
return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors) return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors)
def completed(self, ignore_errors=True): def completed(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal indicating that this Task completed. """ """ The signal indicating that this Task completed. """
if hasattr(tasks, 'CompletedRequest'): if hasattr(tasks, 'CompletedRequest') and callable(tasks.CompletedRequest):
return self.send(tasks.CompletedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors) return self.send(tasks.CompletedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
return self.send(tasks.StoppedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors) return self.send(tasks.StoppedRequest(self.id, status_reason='completed'), ignore_errors=ignore_errors)
def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None): def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None):
# type: (bool, Optional[str], Optional[str]) -> ()
""" The signal that this Task stopped. """ """ The signal that this Task stopped. """
return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message), return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message),
ignore_errors=ignore_errors) ignore_errors=ignore_errors)
def publish(self, ignore_errors=True): def publish(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal that this Task will be published """ """ The signal that this Task will be published """
if str(self.status) != str(tasks.TaskStatusEnum.stopped): if str(self.status) != str(tasks.TaskStatusEnum.stopped):
raise ValueError("Can't publish, Task is not stopped") raise ValueError("Can't publish, Task is not stopped")
@ -501,6 +533,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return resp return resp
def update_model_desc(self, new_model_desc_file=None): def update_model_desc(self, new_model_desc_file=None):
# type: (Optional[str]) -> ()
""" Change the Task's model description. """ """ Change the Task's model description. """
with self._edit_lock: with self._edit_lock:
self.reload() self.reload()
@ -516,6 +549,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return res.response return res.response
def update_output_model(self, model_uri, name=None, comment=None, tags=None): def update_output_model(self, model_uri, name=None, comment=None, tags=None):
# type: (str, Optional[str], Optional[str], Optional[Sequence[str]]) -> ()
""" """
Update the Task's output model. Use this method to update the output model when you have a local model URI, Update the Task's output model. Use this method to update the output model when you have a local model URI,
for example, storing the weights file locally, and specifying a ``file://path/to/file`` URI) for example, storing the weights file locally, and specifying a ``file://path/to/file`` URI)
@ -536,33 +570,38 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._get_output_model(upload_required=False).update_for_task(model_uri, self.id, name, comment, tags) self._get_output_model(upload_required=False).update_for_task(model_uri, self.id, name, comment, tags)
def update_output_model_and_upload( def update_output_model_and_upload(
self, model_file, name=None, comment=None, tags=None, async_enable=False, cb=None, iteration=None): self,
model_file, # type: str
name=None, # type: Optional[str]
comment=None, # type: Optional[str]
tags=None, # type: Optional[Sequence[str]]
async_enable=False, # type: bool
cb=None, # type: Optional[Callable[[Optional[bool]], bool]]
iteration=None, # type: Optional[int]
):
# type: (...) -> str
""" """
Update the Task's output model weights file. First, Trains uploads the file to the preconfigured output Update the Task's output model weights file. First, Trains uploads the file to the preconfigured output
destination (see the Task's ``output.destination`` property or call the ``setup_upload()`` method), destination (see the Task's ``output.destination`` property or call the ``setup_upload()`` method),
then Trains updates the model object associated with the Task an API call. The API call uses with the URI then Trains updates the model object associated with the Task an API call. The API call uses with the URI
of the uploaded file, and other values provided by additional arguments. of the uploaded file, and other values provided by additional arguments.
:param model_file: The path to the updated model weights file. :param str model_file: The path to the updated model weights file.
:type model_file: str :param str name: The updated model name. (Optional)
:param name: The updated model name. (Optional) :param str comment: The updated model description. (Optional)
:type name: str :param list tags: The updated model tags. (Optional)
:param comment: The updated model description. (Optional) :param bool async_enable: Request asynchronous upload?
:type comment: str
:param tags: The updated model tags. (Optional)
:type tags: [str]
:param async_enable: Request asynchronous upload?
- ``True`` - The API call returns immediately, while the upload and update are scheduled in another thread. - ``True`` - The API call returns immediately, while the upload and update are scheduled in another thread.
- ``False`` - The API call blocks until the upload completes, and the API call updating the model returns. - ``False`` - The API call blocks until the upload completes, and the API call updating the model returns.
(Default) (Default)
:type async_enable: bool :param callable cb: Asynchronous callback. A callback. If ``async_enable`` is set to ``True``,
:param cb: Asynchronous callback. A callback. If ``async_enable`` is set to ``True``, this is a callback that this is a callback that is invoked once the asynchronous upload and update complete.
is invoked once the asynchronous upload and update complete. :param int iteration: iteration number for the current stored model (Optional)
:return: The URI of the uploaded weights file. If ``async_enable`` is set to ``True``, this is the expected URI, :return str: The URI of the uploaded weights file. If ``async_enable`` is set to ``True``,
as the upload is probably still in progress. this is the expected URI, as the upload is probably still in progress.
""" """
self._conditionally_start_task() self._conditionally_start_task()
uri = self.output_model.update_for_task_and_upload( uri = self.output_model.update_for_task_and_upload(
@ -572,15 +611,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return uri return uri
def _conditionally_start_task(self): def _conditionally_start_task(self):
# type: () -> ()
if str(self.status) == str(tasks.TaskStatusEnum.created): if str(self.status) == str(tasks.TaskStatusEnum.created):
self.started() self.started()
@property @property
def labels_stats(self): def labels_stats(self):
# type: () -> dict
""" Get accumulated label stats for the current/last frames iteration """ """ Get accumulated label stats for the current/last frames iteration """
return self._curr_label_stats return self._curr_label_stats
def _accumulate_label_stats(self, roi_stats, reset=False): def _accumulate_label_stats(self, roi_stats, reset=False):
# type: (dict, bool) -> ()
if reset: if reset:
self._curr_label_stats = {} self._curr_label_stats = {}
for label in roi_stats: for label in roi_stats:
@ -590,6 +632,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._curr_label_stats[label] = roi_stats[label] self._curr_label_stats[label] = roi_stats[label]
def set_input_model(self, model_id=None, model_name=None, update_task_design=True, update_task_labels=True): def set_input_model(self, model_id=None, model_name=None, update_task_design=True, update_task_labels=True):
# type: (str, Optional[str], bool, bool) -> ()
""" """
Set a new input model for the Task. The model must be "ready" (status is ``Published``) to be used as the Set a new input model for the Task. The model must be "ready" (status is ``Published``) to be used as the
Task's input model. Task's input model.
@ -651,6 +694,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=self.data.execution) self._edit(execution=self.data.execution)
def set_parameters(self, *args, **kwargs): def set_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> ()
""" """
Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not Set the parameters for a Task. This method sets a complete group of key-value parameter pairs, but does not
support parameter descriptions (the input is a dictionary of key-value pairs). support parameter descriptions (the input is a dictionary of key-value pairs).
@ -696,6 +740,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=execution) self._edit(execution=execution)
def set_parameter(self, name, value, description=None): def set_parameter(self, name, value, description=None):
# type: (str, str, Optional[str]) -> ()
""" """
Set a single Task parameter. This overrides any previous value for this parameter. Set a single Task parameter. This overrides any previous value for this parameter.
@ -706,9 +751,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
.. note:: .. note::
The ``description`` is not yet in use. The ``description`` is not yet in use.
""" """
# not supported yet
if description:
# noinspection PyUnusedLocal
description = None
self.set_parameters({name: value}, __update=True) self.set_parameters({name: value}, __update=True)
def get_parameter(self, name, default=None): def get_parameter(self, name, default=None):
# type: (str, Any) -> Any
""" """
Get a value for a parameter. Get a value for a parameter.
@ -720,6 +770,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return params.get(name, default) return params.get(name, default)
def update_parameters(self, *args, **kwargs): def update_parameters(self, *args, **kwargs):
# type: (*dict, **Any) -> ()
""" """
Update the parameters for a Task. This method updates a complete group of key-value parameter pairs, but does Update the parameters for a Task. This method updates a complete group of key-value parameter pairs, but does
not support parameter descriptions (the input is a dictionary of key-value pairs). not support parameter descriptions (the input is a dictionary of key-value pairs).
@ -731,6 +782,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self.set_parameters(__update=True, *args, **kwargs) self.set_parameters(__update=True, *args, **kwargs)
def set_model_label_enumeration(self, enumeration=None): def set_model_label_enumeration(self, enumeration=None):
# type: (Mapping[str, int]) -> ()
""" """
Set a dictionary of labels (text) to ids (integers) {str(label): integer(id)} Set a dictionary of labels (text) to ids (integers) {str(label): integer(id)}
@ -749,11 +801,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=execution) self._edit(execution=execution)
def _set_default_docker_image(self): def _set_default_docker_image(self):
# type: () -> ()
if not DOCKER_IMAGE_ENV_VAR.exists(): if not DOCKER_IMAGE_ENV_VAR.exists():
return return
self.set_base_docker(DOCKER_IMAGE_ENV_VAR.get(default="")) self.set_base_docker(DOCKER_IMAGE_ENV_VAR.get(default=""))
def set_base_docker(self, docker_cmd): def set_base_docker(self, docker_cmd):
# type: (str) -> ()
""" """
Set the base docker image for this experiment Set the base docker image for this experiment
If provided, this value will be used by trains-agent to execute this experiment If provided, this value will be used by trains-agent to execute this experiment
@ -766,10 +820,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=execution) self._edit(execution=execution)
def get_base_docker(self): def get_base_docker(self):
# type: () -> str
"""Get the base Docker command (image) that is set for this experiment.""" """Get the base Docker command (image) that is set for this experiment."""
return self._get_task_property('execution.docker_cmd', raise_on_error=False, log_on_error=False) return self._get_task_property('execution.docker_cmd', raise_on_error=False, log_on_error=False)
def set_artifacts(self, artifacts_list=None): def set_artifacts(self, artifacts_list=None):
# type: (Sequence[tasks.Artifact]) -> ()
""" """
List of artifacts (tasks.Artifact) to update the task List of artifacts (tasks.Artifact) to update the task
@ -788,15 +844,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(execution=execution) self._edit(execution=execution)
def _set_model_design(self, design=None): def _set_model_design(self, design=None):
# type: (str) -> ()
with self._edit_lock: with self._edit_lock:
self.reload() self.reload()
execution = self.data.execution execution = self.data.execution
if design is not None: if design is not None:
# noinspection PyProtectedMember
execution.model_desc = Model._wrap_design(design) execution.model_desc = Model._wrap_design(design)
self._edit(execution=execution) self._edit(execution=execution)
def get_labels_enumeration(self): def get_labels_enumeration(self):
# type: () -> Mapping[str, int]
""" """
Get the label enumeration dictionary label enumeration dictionary of string (label) to integer (value) pairs. Get the label enumeration dictionary label enumeration dictionary of string (label) to integer (value) pairs.
@ -807,32 +866,39 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self.data.execution.model_labels return self.data.execution.model_labels
def get_model_design(self): def get_model_design(self):
# type: () -> str
""" """
Get the model configuration as blob of text. Get the model configuration as blob of text.
:return: :return:
""" """
design = self._get_task_property("execution.model_desc", default={}, raise_on_error=False, log_on_error=False) design = self._get_task_property("execution.model_desc", default={}, raise_on_error=False, log_on_error=False)
# noinspection PyProtectedMember
return Model._unwrap_design(design) return Model._unwrap_design(design)
def set_output_model_id(self, model_id): def set_output_model_id(self, model_id):
# type: (str) -> ()
self.data.output.model = str(model_id) self.data.output.model = str(model_id)
self._edit(output=self.data.output) self._edit(output=self.data.output)
def get_random_seed(self): def get_random_seed(self):
# type: () -> int
# fixed seed for the time being # fixed seed for the time being
return 1337 return 1337
def set_random_seed(self, random_seed): def set_random_seed(self, random_seed):
# type: (int) -> ()
# fixed seed for the time being # fixed seed for the time being
pass pass
def set_project(self, project_id): def set_project(self, project_id):
# type: (str) -> ()
assert isinstance(project_id, six.string_types) assert isinstance(project_id, six.string_types)
self._set_task_property("project", project_id) self._set_task_property("project", project_id)
self._edit(project=project_id) self._edit(project=project_id)
def get_project_name(self): def get_project_name(self):
# type: () -> Optional[str]
if self.project is None: if self.project is None:
return None return None
@ -846,9 +912,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._project_name[1] return self._project_name[1]
def get_tags(self): def get_tags(self):
# type: () -> Sequence[str]
return self._get_task_property("tags") return self._get_task_property("tags")
def set_system_tags(self, tags): def set_system_tags(self, tags):
# type: (Sequence[str]) -> ()
assert isinstance(tags, (list, tuple)) assert isinstance(tags, (list, tuple))
if Session.check_min_api_version('2.3'): if Session.check_min_api_version('2.3'):
self._set_task_property("system_tags", tags) self._set_task_property("system_tags", tags)
@ -858,9 +926,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(tags=self.data.tags) self._edit(tags=self.data.tags)
def get_system_tags(self): def get_system_tags(self):
# type: () -> Sequence[str]
return self._get_task_property("system_tags" if Session.check_min_api_version('2.3') else "tags") return self._get_task_property("system_tags" if Session.check_min_api_version('2.3') else "tags")
def set_tags(self, tags): def set_tags(self, tags):
# type: (Sequence[str]) -> ()
assert isinstance(tags, (list, tuple)) assert isinstance(tags, (list, tuple))
if not Session.check_min_api_version('2.3'): if not Session.check_min_api_version('2.3'):
# not supported # not supported
@ -869,6 +939,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(tags=self.data.tags) self._edit(tags=self.data.tags)
def set_name(self, name): def set_name(self, name):
# type: (str) -> ()
""" """
Set the Task name. Set the Task name.
@ -879,6 +950,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(name=self.data.name) self._edit(name=self.data.name)
def set_comment(self, comment): def set_comment(self, comment):
# type: (str) -> ()
""" """
Set a comment / description for the Task. Set a comment / description for the Task.
@ -889,6 +961,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._edit(comment=comment) self._edit(comment=comment)
def set_initial_iteration(self, offset=0): def set_initial_iteration(self, offset=0):
# type: (int) -> int
""" """
Set the initial iteration offset. The default value is ``0``. This method is useful when continuing training Set the initial iteration offset. The default value is ``0``. This method is useful when continuing training
from previous checkpoints. from previous checkpoints.
@ -913,6 +986,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._initial_iteration_offset return self._initial_iteration_offset
def get_initial_iteration(self): def get_initial_iteration(self):
# type: () -> int
""" """
Get the initial iteration offset. The default value is ``0``. This method is useful when continuing training Get the initial iteration offset. The default value is ``0``. This method is useful when continuing training
from previous checkpoints. from previous checkpoints.
@ -924,6 +998,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return self._initial_iteration_offset return self._initial_iteration_offset
def get_status(self): def get_status(self):
# type: () -> str
""" """
Return The task status without refreshing the entire Task object object (only the status property) Return The task status without refreshing the entire Task object object (only the status property)
@ -937,7 +1012,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._data.status = status self._data.status = status
return str(status) return str(status)
def get_reported_scalars(self, max_samples=0, x_axis='iter'): def get_reported_scalars(
self,
max_samples=0, # type: int
x_axis='iter' # type: Union['iter', 'timestamp', 'iso_time']
):
# type: (...) -> Mapping[str, Mapping[str, Mapping[str, Sequence[float]]]]
""" """
Return a nested dictionary for the scalar graphs, Return a nested dictionary for the scalar graphs,
where the first key is the graph title and the second is the series name. where the first key is the graph title and the second is the series name.
@ -971,6 +1051,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return response.response_data return response.response_data
def get_reported_console_output(self, number_of_reports=1): def get_reported_console_output(self, number_of_reports=1):
# type: (int) -> Sequence[str]
""" """
Return a list of console outputs reported by the Task. Return a list of console outputs reported by the Task.
Returned console outputs are retrieved from the most updated console outputs. Returned console outputs are retrieved from the most updated console outputs.
@ -994,6 +1075,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@classmethod @classmethod
def add_requirements(cls, package_name, package_version=None): def add_requirements(cls, package_name, package_version=None):
# type: (str, Optional[str]) -> ()
""" """
Force package in requirements list. If version is not specified, use the installed package version if found. Force package in requirements list. If version is not specified, use the installed package version if found.
:param str package_name: Package name to add to the "Installed Packages" section of the task :param str package_name: Package name to add to the "Installed Packages" section of the task
@ -1002,11 +1084,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
cls._force_requirements[package_name] = package_version cls._force_requirements[package_name] = package_version
def _get_models(self, model_type='output'): def _get_models(self, model_type='output'):
# type: (Union['output', 'input']) -> Sequence[Model]
model_type = model_type.lower().strip() model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input' assert model_type == 'output' or model_type == 'input'
if model_type == 'input': if model_type == 'input':
regex = '((?i)(Using model id: )(\w+)?)' regex = r'((?i)(Using model id: )(\w+)?)'
compiled = re.compile(regex) compiled = re.compile(regex)
ids = [i[-1] for i in re.findall(compiled, self.comment)] + ( ids = [i[-1] for i in re.findall(compiled, self.comment)] + (
[self.input_model_id] if self.input_model_id else []) [self.input_model_id] if self.input_model_id else [])
@ -1016,11 +1099,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
in_model = [] in_model = []
for i in ids: for i in ids:
m = TrainsModel(model_id=i) m = TrainsModel(model_id=i)
# noinspection PyBroadException
try: try:
# make sure the model is is valid # make sure the model is is valid
# noinspection PyProtectedMember
m._get_model_data() m._get_model_data()
in_model.append(m) in_model.append(m)
except: except Exception:
pass pass
return in_model return in_model
else: else:
@ -1040,11 +1125,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return [TrainsModel(model_id=i) for i in ids] return [TrainsModel(model_id=i) for i in ids]
def _get_default_report_storage_uri(self): def _get_default_report_storage_uri(self):
# type: () -> str
if not self._files_server: if not self._files_server:
self._files_server = Session.get_files_server_host() self._files_server = Session.get_files_server_host()
return self._files_server return self._files_server
def _get_status(self): def _get_status(self):
# type: () -> (Optional[str], Optional[str])
# noinspection PyBroadException
try: try:
all_tasks = self.send( all_tasks = self.send(
tasks.GetAllRequest(id=[self.id], only_fields=['status', 'status_message']), tasks.GetAllRequest(id=[self.id], only_fields=['status', 'status_message']),
@ -1054,6 +1142,8 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return None, None return None, None
def _reload_last_iteration(self): def _reload_last_iteration(self):
# type: () -> ()
# noinspection PyBroadException
try: try:
all_tasks = self.send( all_tasks = self.send(
tasks.GetAllRequest(id=[self.id], only_fields=['last_iteration']), tasks.GetAllRequest(id=[self.id], only_fields=['last_iteration']),
@ -1063,6 +1153,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return None return None
def _clear_task(self, system_tags=None, comment=None): def _clear_task(self, system_tags=None, comment=None):
# type: (Optional[Sequence[str]], Optional[str]) -> ()
self._data.script = tasks.Script( self._data.script = tasks.Script(
binary='', repository='', tag='', branch='', version_num='', entry_point='', binary='', repository='', tag='', branch='', version_num='', entry_point='',
working_dir='', requirements={}, diff='', working_dir='', requirements={}, diff='',
@ -1087,14 +1178,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@classmethod @classmethod
def _get_api_server(cls): def _get_api_server(cls):
# type: () -> ()
return Session.get_api_server_host() return Session.get_api_server_host()
def _get_app_server(self): def _get_app_server(self):
# type: () -> str
if not self._app_server: if not self._app_server:
self._app_server = Session.get_app_server_host() self._app_server = Session.get_app_server_host()
return self._app_server return self._app_server
def _edit(self, **kwargs): def _edit(self, **kwargs):
# type: (**Any) -> Any
with self._edit_lock: with self._edit_lock:
# Since we ae using forced update, make sure he task status is valid # Since we ae using forced update, make sure he task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status status = self._data.status if self._data and self._reload_skip_flag else self.data.status
@ -1109,9 +1203,11 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
return res return res
def _update_requirements(self, requirements): def _update_requirements(self, requirements):
# type: (Union[dict, str]) -> ()
if not isinstance(requirements, dict): if not isinstance(requirements, dict):
requirements = {'pip': requirements} requirements = {'pip': requirements}
# protection, Old API might not support it # protection, Old API might not support it
# noinspection PyBroadException
try: try:
self.data.script.requirements = requirements self.data.script.requirements = requirements
self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements)) self.send(tasks.SetRequirementsRequest(task=self.id, requirements=requirements))
@ -1119,35 +1215,38 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
pass pass
def _update_script(self, script): def _update_script(self, script):
# type: (dict) -> ()
self.data.script = script self.data.script = script
self._edit(script=script) self._edit(script=script)
@classmethod @classmethod
def _clone_task(cls, cloned_task_id, name=None, comment=None, execution_overrides=None, def _clone_task(
tags=None, parent=None, project=None, log=None, session=None): cls,
cloned_task_id, # type: str
name=None, # type: Optional[str]
comment=None, # type: Optional[str]
execution_overrides=None, # type: Optional[dict]
tags=None, # type: Optional[Sequence[str]]
parent=None, # type: Optional[str]
project=None, # type: Optional[str]
log=None, # type: Optional[logging.Logger]
session=None, # type: Optional[Session]
):
# type: (...) -> str
""" """
Clone a task Clone a task
:param cloned_task_id: Task ID for the task to be cloned :param str cloned_task_id: Task ID for the task to be cloned
:type cloned_task_id: str :param str name: New for the new task
:param name: New for the new task :param str comment: Optional comment for the new task
:type name: str :param dict execution_overrides: Task execution overrides. Applied over the cloned task's execution
:param comment: Optional comment for the new task
:type comment: str
:param execution_overrides: Task execution overrides. Applied over the cloned task's execution
section, useful for overriding values in the cloned task. section, useful for overriding values in the cloned task.
:type execution_overrides: dict :param list tags: Optional updated model tags
:param tags: Optional updated model tags :param str parent: Optional parent Task ID of the new task.
:type tags: [str] :param str project: Optional project ID of the new task.
:param parent: Optional parent Task ID of the new task.
:type parent: str
:param project: Optional project ID of the new task.
If None, the new task will inherit the cloned task's project. If None, the new task will inherit the cloned task's project.
:type project: str :param logging.Logger log: Log object used by the infrastructure.
:param log: Log object used by the infrastructure. :param Session session: Session object used for sending requests to the API
:type log: logging.Logger
:param session: Session object used for sending requests to the API
:type session: Session
:return: The new tasks's ID :return: The new tasks's ID
""" """
@ -1189,14 +1288,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@classmethod @classmethod
def get_all(cls, session=None, log=None, **kwargs): def get_all(cls, session=None, log=None, **kwargs):
# type: (Optional[Session], Optional[logging.Logger], **Any) -> Any
""" """
List all the Tasks based on specific projection. List all the Tasks based on specific projection.
:param session: The session object used for sending requests to the API. :param Session session: The session object used for sending requests to the API.
:type session: Session :param logging.Logger log: The Log object.
:param log: The Log object. :param kwargs: Keyword args passed to the GetAllRequest
:type log: logging.Logger (see :class:`.backend_api.services.v2_5.tasks.GetAllRequest`)
:param kwargs: Keyword args passed to the GetAllRequest (see :class:`.backend_api.services.v2_5.tasks.GetAllRequest`)
For example: For example:
@ -1215,12 +1314,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@classmethod @classmethod
def get_by_name(cls, task_name): def get_by_name(cls, task_name):
# type: (str) -> Task
res = cls._send(cls._get_default_session(), tasks.GetAllRequest(name=exact_match_regex(task_name))) res = cls._send(cls._get_default_session(), tasks.GetAllRequest(name=exact_match_regex(task_name)))
task = get_single_result(entity='task', query=task_name, results=res.response.tasks) task = get_single_result(entity='task', query=task_name, results=res.response.tasks)
return cls(task_id=task.id) return cls(task_id=task.id)
def _get_all_events(self, max_events=100): def _get_all_events(self, max_events=100):
# type: (int) -> Any
""" """
Get a list of all reported events. Get a list of all reported events.
@ -1255,6 +1356,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@property @property
def _edit_lock(self): def _edit_lock(self):
# type: () -> ()
if self.__edit_lock: if self.__edit_lock:
return self.__edit_lock return self.__edit_lock
if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2: if not PROC_MASTER_ID_ENV_VAR.get() or len(PROC_MASTER_ID_ENV_VAR.get().split(':')) < 2:
@ -1262,6 +1364,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id): elif PROC_MASTER_ID_ENV_VAR.get().split(':')[1] == str(self.id):
# remove previous file lock instance, just in case. # remove previous file lock instance, just in case.
filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id)) filename = os.path.join(gettempdir(), 'trains_{}.lock'.format(self.id))
# noinspection PyBroadException
try: try:
os.unlink(filename) os.unlink(filename)
except Exception: except Exception:
@ -1274,22 +1377,26 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@_edit_lock.setter @_edit_lock.setter
def _edit_lock(self, value): def _edit_lock(self, value):
# type: (RLock) -> ()
self.__edit_lock = value self.__edit_lock = value
@classmethod @classmethod
def __update_master_pid_task(cls, pid=None, task=None): def __update_master_pid_task(cls, pid=None, task=None):
# type: (Optional[int], Union[str, Task]) -> ()
pid = pid or os.getpid() pid = pid or os.getpid()
if not task: if not task:
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':') PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':')
elif isinstance(task, str): elif isinstance(task, str):
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + task) PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + task)
else: else:
# noinspection PyUnresolvedReferences
PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id)) PROC_MASTER_ID_ENV_VAR.set(str(pid) + ':' + str(task.id))
# make sure we refresh the edit lock next time we need it, # make sure we refresh the edit lock next time we need it,
task._edit_lock = None task._edit_lock = None
@classmethod @classmethod
def __get_master_id_task_id(cls): def __get_master_id_task_id(cls):
# type: () -> Optional[str]
master_task_id = PROC_MASTER_ID_ENV_VAR.get().split(':') master_task_id = PROC_MASTER_ID_ENV_VAR.get().split(':')
# we could not find a task ID, revert to old stub behaviour # we could not find a task ID, revert to old stub behaviour
if len(master_task_id) < 2 or not master_task_id[1]: if len(master_task_id) < 2 or not master_task_id[1]:
@ -1298,6 +1405,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
@classmethod @classmethod
def __is_subprocess(cls): def __is_subprocess(cls):
# type: () -> bool
# notice this class function is called from Task.ExitHooks, do not rename/move it. # notice this class function is called from Task.ExitHooks, do not rename/move it.
is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \ is_subprocess = PROC_MASTER_ID_ENV_VAR.get() and \
PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid()) PROC_MASTER_ID_ENV_VAR.get().split(':')[0] != str(os.getpid())