This commit is contained in:
revital 2023-07-11 07:41:32 +03:00
commit 999c6543b8
29 changed files with 329 additions and 112 deletions

View File

@ -76,7 +76,7 @@ Instrumenting these components is the **ClearML-server**, see [Self-Hosting](htt
* Full source control info including non-committed local changes
* Execution environment (including specific packages & versions)
* Hyper-parameters
* ArgParser/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
* [`argparse`](https://docs.python.org/3/library/argparse.html)/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
* Explicit parameters dictionary
* Tensorflow Defines (absl-py)
* [Hydra](https://github.com/facebookresearch/hydra) configuration and overrides

View File

@ -1615,7 +1615,7 @@ class PipelineController(object):
def _has_stored_configuration(self):
"""
Return True if we are running remotely and we have stored configuration on the Task
Return True if we are running remotely, and we have stored configuration on the Task
"""
if self._auto_connect_task and self._task and not self._task.running_locally() and self._task.is_main_task():
stored_config = self._task.get_configuration_object(self._config_section)
@ -1698,9 +1698,10 @@ class PipelineController(object):
conformed_monitors = [
pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors
]
# verify pair of pairs
# verify the pair of pairs
if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and
isinstance(x[1][0], str) and isinstance(x[1][1], str) for x in conformed_monitors):
isinstance(x[1][0], str) and isinstance(x[1][1], str)
for x in conformed_monitors):
raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
else:
# verify a list of tuples
@ -1711,8 +1712,10 @@ class PipelineController(object):
conformed_monitors = [
pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors
]
# verify pair of pairs
if not all(isinstance(x[0], str) and isinstance(x[1], str) for x in conformed_monitors):
# verify the pair of pairs
if not all(isinstance(x[0], str) and
isinstance(x[1], str)
for x in conformed_monitors):
raise ValueError(
"{} should be a list of tuples, found: {}".format(monitor_type, monitors))
@ -1782,7 +1785,7 @@ class PipelineController(object):
# type: (...) -> bool
"""
Create a Task from a function, including wrapping the function input arguments
into the hyper-parameter section as kwargs, and storing function results as named artifacts
into the hyperparameter section as kwargs, and storing function results as named artifacts
Example:
@ -1874,7 +1877,7 @@ class PipelineController(object):
:param continue_on_fail: (default False). If True, failed step will not cause the pipeline to stop
(or marked as failed). Notice, that steps that are connected (or indirectly connected)
to the failed step will be skipped.
:param pre_execute_callback: Callback function, called when the step (Task) is created
:param pre_execute_callback: Callback function, called when the step (Task) is created,
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the ClearmlJob.
@ -2302,7 +2305,8 @@ class PipelineController(object):
:param node_params: list of node parameters
:param visited: list of nodes
:return: Table as List of List of strings (cell)
:return: Table as a List of a List of strings (cell)
"""
task_link_template = self._task.get_output_log_web_page() \
.replace('/{}/'.format(self._task.project), '/{project}/') \
@ -2778,6 +2782,7 @@ class PipelineController(object):
return the pipeline components target folder name/id
:param return_project_id: if False (default), return target folder name. If True, return project id
:return: project id/name (None if not valid)
"""
if not self._target_project:
@ -3110,7 +3115,7 @@ class PipelineDecorator(PipelineController):
:param str target_project: If provided, all pipeline steps are cloned into the target project
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
was specifically defined with "continue_on_fail=True".
If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
and mark the pipeline as failed.
@ -3540,7 +3545,7 @@ class PipelineDecorator(PipelineController):
:param _func: wrapper function
:param return_values: Provide a list of names for all the results.
Notice! If not provided no results will be stored as artifacts.
Notice! If not provided, no results will be stored as artifacts.
:param name: Optional, set the name of the pipeline component task.
If not provided, the wrapped function name is used as the pipeline component name
:param cache: If True, before launching the new step,
@ -3623,8 +3628,8 @@ class PipelineDecorator(PipelineController):
# allow up to 5 retries (total of 6 runs)
return retries < 5
:param pre_execute_callback: Callback function, called when the step (Task) is created
and before it is sent for execution. Allows a user to modify the Task before launch.
:param pre_execute_callback: Callback function, called when the step (Task) is created,
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the ClearmlJob.
@ -3644,7 +3649,7 @@ class PipelineDecorator(PipelineController):
pass
:param post_execute_callback: Callback function, called when a step (Task) is completed
and other jobs are executed. Allows a user to modify the Task status after completion.
and other jobs are going to be executed. Allows a user to modify the Task status after completion.
.. code-block:: py
@ -3794,7 +3799,7 @@ class PipelineDecorator(PipelineController):
if target_queue:
PipelineDecorator.set_default_execution_queue(target_queue)
else:
# if we are are not running from a queue, we are probably in debug mode
# if we are not running from a queue, we are probably in debug mode
a_pipeline._clearml_job_class = LocalClearmlJob
a_pipeline._default_execution_queue = 'mock'
@ -3957,7 +3962,7 @@ class PipelineDecorator(PipelineController):
:param str target_project: If provided, all pipeline steps are cloned into the target project
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
was specifically defined with "continue_on_fail=True".
If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
and mark the pipeline as failed.

View File

@ -384,9 +384,10 @@ class BaseJob(object):
section_overrides=None,
params_override=None,
configurations_override=None,
explicit_docker_image=None
explicit_docker_image=None,
account_for_artifacts_hashes=True
):
# type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str]) -> Optional[str]
# type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str], bool) -> Optional[str]
"""
Create Hash (str) representing the state of the Task
@ -397,6 +398,8 @@ class BaseJob(object):
:param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem)
:param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image
was explicitly changed
:param account_for_artifacts_hashes: Calculate the hash of the task by accounting for the hashes of the
artifacts in `kwargs_artifacts` (as opposed of the task ID/artifact name stored in this section)
:return: str hash of the Task configuration
"""
@ -416,7 +419,23 @@ class BaseJob(object):
# we need to ignore `requirements` section because ir might be changing from run to run
script.pop("requirements", None)
hyper_params = task.get_parameters() if params_override is None else params_override
hyper_params = deepcopy(task.get_parameters() if params_override is None else params_override)
if account_for_artifacts_hashes:
hyper_params_to_change = {}
task_cache = {}
for key, value in hyper_params.items():
if key.startswith("kwargs_artifacts/"):
# noinspection PyBroadException
try:
# key format is <task_id>.<artifact_name>
task_id, artifact = value.split(".", 1)
task_ = task_cache.setdefault(task_id, Task.get_task(task_id))
# set the value of the hyper parameter to the hash of the artifact
# because the task ID might differ, but the artifact might be the same
hyper_params_to_change[key] = task_.artifacts[artifact].hash
except Exception:
pass
hyper_params.update(hyper_params_to_change)
configs = task.get_configuration_objects() if configurations_override is None else configurations_override
# currently we do not add the docker image to the hash (only args and setup script),
# because default docker image will cause the step to change
@ -585,6 +604,14 @@ class ClearmlJob(BaseJob):
if allow_caching:
# look for a cached copy of the Task
# get parameters + task_overrides + as dict and hash it.
task_hash_legacy = self._create_task_hash(
base_temp_task,
section_overrides=sections,
params_override=task_params,
configurations_override=configuration_overrides or None,
explicit_docker_image=kwargs.get("explicit_docker_image"),
account_for_artifacts_hashes=False
)
task_hash = self._create_task_hash(
base_temp_task,
section_overrides=sections,
@ -592,7 +619,7 @@ class ClearmlJob(BaseJob):
configurations_override=configuration_overrides or None,
explicit_docker_image=kwargs.get("explicit_docker_image")
)
task = self._get_cached_task(task_hash)
task = self._get_cached_task(task_hash_legacy) or self._get_cached_task(task_hash)
# if we found a task, just use
if task:
if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:

View File

@ -4590,7 +4590,7 @@ class UpdateResponse(Response):
class ValidateDeleteRequest(Request):
"""
Validates that the project existis and can be deleted
Validates that the project exists and can be deleted
:param project: Project ID
:type project: str

View File

@ -4616,7 +4616,7 @@ class UpdateResponse(Response):
class ValidateDeleteRequest(Request):
"""
Validates that the project existis and can be deleted
Validates that the project exists and can be deleted
:param project: Project ID
:type project: str

View File

@ -156,6 +156,19 @@ class Session(TokenManager):
self._connect()
@classmethod
def add_client(cls, client, value, first=True):
# noinspection PyBroadException
try:
if not any(True for c in cls._client if c[0] == client):
if first:
cls._client.insert(0, (client, value))
else:
cls._client.append((client, value))
cls.client = ", ".join("{}-{}".format(*x) for x in cls._client)
except Exception:
pass
def _connect(self):
if self._offline_mode:
return
@ -219,8 +232,7 @@ class Session(TokenManager):
if not api_version:
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
if token_dict.get('server_version'):
if not any(True for c in Session._client if c[0] == 'clearml-server'):
Session._client.append(('clearml-server', token_dict.get('server_version'), ))
self.add_client('clearml-server', token_dict.get('server_version'))
Session.max_api_version = Session.api_version = str(api_version)
Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")

View File

@ -644,6 +644,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task
def _reload_field(self, field):
# type: (str) -> Any
""" Reload the task specific field, dot seperated for nesting"""
with self._edit_lock:
if self._offline_mode:
task_object = self._reload()
else:
res = self.send(tasks.GetAllRequest(id=[self.id], only_fields=[field], search_hidden=True))
task_object = res.response.tasks[0]
for p in field.split("."):
task_object = getattr(task_object, p, None)
if task_object is None:
break
return task_object
def reset(self, set_started_on_success=True, force=False):
# type: (bool, bool) -> ()
"""
@ -752,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True):
# type: (bool) -> ()
""" The signal that this task will be published """
if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)):
if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse)
@ -792,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
try:
res = self.send(tasks.GetByIdRequest(self.task_id))
task = res.response.task
if task.status == Task.TaskStatusEnum.published:
if task.status == self.TaskStatusEnum.published:
if raise_on_error:
raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
self.log.error("Cannot delete published task {}".format(self.task_id))
@ -2408,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: list of tuples (status, status_message, task_id)
"""
if cls._offline_mode:
return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids]
return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
# noinspection PyBroadException
try:
@ -2547,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we ae using forced update, make sure he task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if not kwargs.pop("force", False) and \
status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
# the exception being name/comment that we can always change.
if kwargs and all(
k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
@ -3000,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_task_status(cls, task_id):
# type: (str) -> (Optional[str], Optional[str])
if cls._offline_mode:
return tasks.TaskStatusEnum.created, 'offline'
return cls.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException
try:

View File

@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
_checkpoint_filename = {}
__patched = None
__patched_lightning = None
__patched_pytorch_lightning = None
__patched_mmcv = None
@staticmethod
@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return
PatchPyTorchModelIO._patch_model_io()
PatchPyTorchModelIO._patch_lightning_io()
PatchPyTorchModelIO._patch_pytorch_lightning_io()
PatchPyTorchModelIO._patch_mmcv()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
@staticmethod
def _patch_model_io():
@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if PatchPyTorchModelIO.__patched_lightning:
return
if 'pytorch_lightning' not in sys.modules:
if 'lightning' not in sys.modules:
return
PatchPyTorchModelIO.__patched_lightning = True
# noinspection PyBroadException
try:
import lightning # noqa
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
) # noqa
lightning.pytorch.trainer.Trainer.restore = _patched_call(
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
) # noqa
except ImportError:
pass
except Exception:
pass
# noinspection PyBroadException
try:
import lightning # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save,
) # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj,
) # noqa
except ImportError:
pass
except Exception:
pass
@staticmethod
def _patch_pytorch_lightning_io():
if PatchPyTorchModelIO.__patched_pytorch_lightning:
return
if 'pytorch_lightning' not in sys.modules:
return
PatchPyTorchModelIO.__patched_pytorch_lightning = True
# noinspection PyBroadException
try:
import pytorch_lightning # noqa

View File

@ -3,6 +3,7 @@ from logging import getLogger
from .frameworks import _patched_call # noqa
from .import_bind import PostImportHookPatching
from ..utilities.networking import get_private_ip
from ..config import running_remotely
class PatchGradio:
@ -11,7 +12,7 @@ class PatchGradio:
_default_gradio_address = "0.0.0.0"
_default_gradio_port = 7860
_root_path_format = "/service/{}"
_root_path_format = "/service/{}/"
__server_config_warning = set()
@classmethod
@ -32,42 +33,55 @@ class PatchGradio:
try:
import gradio
gradio.networking.start_server = _patched_call(
gradio.networking.start_server, PatchGradio._patched_start_server
)
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
gradio.routes.App.get_blocks = _patched_call(gradio.routes.App.get_blocks, PatchGradio._patched_get_blocks)
gradio.blocks.Blocks.launch = _patched_call(gradio.blocks.Blocks.launch, PatchGradio._patched_launch)
except Exception:
pass
cls.__patched = True
@staticmethod
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
def _patched_get_blocks(original_fn, *args, **kwargs):
blocks = original_fn(*args, **kwargs)
if not PatchGradio._current_task or not running_remotely():
return blocks
blocks.config["root"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
blocks.root = blocks.config["root"]
return blocks
@staticmethod
def _patched_launch(original_fn, *args, **kwargs):
if not PatchGradio._current_task:
return original_fn(self, server_name, server_port, *args, **kwargs)
return original_fn(*args, **kwargs)
PatchGradio.__warn_on_server_config(
kwargs.get("server_name"),
kwargs.get("server_port"),
kwargs.get("root_path")
)
if not running_remotely():
return original_fn(*args, **kwargs)
# noinspection PyProtectedMember
PatchGradio._current_task._set_runtime_properties(
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
)
PatchGradio._current_task.set_system_tags(["external_service"])
PatchGradio.__warn_on_server_config(server_name, server_port)
server_name = PatchGradio._default_gradio_address
server_port = PatchGradio._default_gradio_port
return original_fn(self, server_name, server_port, *args, **kwargs)
@staticmethod
def _patched_init(original_fn, *args, **kwargs):
if not PatchGradio._current_task:
return original_fn(*args, **kwargs)
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
kwargs["root_path_in_servers"] = False
kwargs["server_name"] = PatchGradio._default_gradio_address
kwargs["server_port"] = PatchGradio._default_gradio_port
return original_fn(*args, **kwargs)
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
# noinspection PyBroadException
try:
return original_fn(*args, **kwargs)
except Exception as e:
del kwargs["root_path"]
return original_fn(*args, **kwargs)
@classmethod
def __warn_on_server_config(cls, server_name, server_port):
if server_name is None and server_port is None:
def __warn_on_server_config(cls, server_name, server_port, root_path):
if (server_name is None or server_name == PatchGradio._default_gradio_address) and \
(server_port is None and server_port == PatchGradio._default_gradio_port):
return
if (server_name, server_port, root_path) in cls.__server_config_warning:
return
cls.__server_config_warning.add((server_name, server_port, root_path))
if server_name is not None and server_port is not None:
server_config = "{}:{}".format(server_name, server_port)
what_to_ignore = "name and port"
@ -77,11 +91,14 @@ class PatchGradio:
else:
server_config = str(server_port)
what_to_ignore = "port"
if server_config in cls.__server_config_warning:
return
cls.__server_config_warning.add(server_config)
getLogger().warning(
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
"ClearML only supports '{}:{}' as the Gradio server. Ignoring {} '{}' in remote execution".format(
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
)
)
if root_path is not None:
getLogger().warning(
"ClearML will override root_path '{}' to '{}' in remote execution".format(
root_path, PatchGradio._root_path_format.format(PatchGradio._current_task.id)
)
)

View File

@ -209,7 +209,10 @@ class PatchJsonArgParse(object):
with change_to_path_dir(path):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {subcommand + PatchJsonArgParse._commands_sep + k: v for k, v in parsed_cfg.items()}
parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg
@staticmethod

View File

@ -237,6 +237,12 @@ def parse_known_host(parsed_host):
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.port is None:
print('Web app hosted on standard port using ' + parsed_host.scheme + ' protocol.')
print('Assuming files and api ports are unchanged and use the same (' + parsed_host.scheme + ') protocol')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc+ ':8081' + parsed_host.path
else:
print("Warning! Could not parse host name")
api_host = ''

View File

@ -7,7 +7,11 @@ from typing import Sequence
from pathlib2 import Path
import clearml.backend_api.session
from clearml.datasets import Dataset
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-data", __version__)
def check_null_id(args):

View File

@ -1,7 +1,9 @@
import sys
import json
import sys
from argparse import ArgumentParser, RawTextHelpFormatter
import clearml.backend_api.session
from clearml import Task
from clearml.automation import (
DiscreteParameterRange,
UniformIntegerParameterRange,
@ -11,8 +13,11 @@ from clearml.automation import (
RandomSearch,
GridSearch,
)
from clearml import Task
from clearml.backend_interface.task.populate import CreateAndPopulate
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-param-search", __version__)
try:
from clearml.automation.optuna import OptimizerOptuna # noqa

View File

@ -3,9 +3,12 @@ from argparse import ArgumentParser
from pathlib2 import Path
import clearml.backend_api.session
from clearml import Task
from clearml.version import __version__
from clearml.backend_interface.task.populate import CreateAndPopulate
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-task", __version__)
def setup_parser(parser):

View File

@ -122,7 +122,7 @@ class Dataset(object):
__hyperparams_section = "Datasets"
__datasets_runtime_prop = "datasets"
__orig_datasets_runtime_prop_prefix = "orig_datasets"
__preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
__preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
__preview_tabular_table_count = deferred_config("dataset.preview.tabular.table_count", 10, transform=int)
__preview_tabular_row_count = deferred_config("dataset.preview.tabular.row_count", 10, transform=int)
__preview_media_image_count = deferred_config("dataset.preview.media.image_count", 10, transform=int)
@ -1877,6 +1877,7 @@ class Dataset(object):
ids=None, # type: Optional[Sequence[str]]
only_completed=True, # type: bool
recursive_project_search=True, # type: bool
include_archived=True, # type: bool
):
# type: (...) -> List[dict]
"""
@ -1890,9 +1891,16 @@ class Dataset(object):
:param recursive_project_search: If True and the `dataset_project` argument is set,
search inside subprojects as well.
If False, don't search inside subprojects (except for the special `.datasets` subproject)
:param include_archived: If True, include archived datasets as well.
:return: List of dictionaries with dataset information
Example: [{'name': name, 'project': project name, 'id': dataset_id, 'created': date_created},]
"""
# if include_archived is False, we need to add the system tag __$not:archived to filter out archived datasets
if not include_archived:
system_tags = ["__$all", cls.__tag, "__$not", "archived"]
else:
system_tags = [cls.__tag]
if dataset_project:
if not recursive_project_search:
dataset_projects = [
@ -1903,12 +1911,13 @@ class Dataset(object):
dataset_projects = [exact_match_regex(dataset_project), "^{}/.*".format(re.escape(dataset_project))]
else:
dataset_projects = None
# noinspection PyProtectedMember
datasets = Task._query_tasks(
task_ids=ids or None,
project_name=dataset_projects,
task_name=partial_name,
system_tags=[cls.__tag],
system_tags=system_tags,
type=[str(Task.TaskTypes.data_processing)],
tags=tags or None,
status=["stopped", "published", "completed", "closed"] if only_completed else None,
@ -2278,14 +2287,15 @@ class Dataset(object):
ds = Dataset.get(dependency)
links.update(ds._dataset_link_entries)
links.update(self._dataset_link_entries)
def _download_link(link,target_path):
def _download_link(link, target_path):
if os.path.exists(target_path):
LoggerRoot.get_base_logger().info(
"{} already exists. Skipping downloading {}".format(
target_path, link
)
LoggerRoot.get_base_logger().info(
"{} already exists. Skipping downloading {}".format(
target_path, link
)
return
)
return
ok = False
error = None
try:
@ -2310,16 +2320,12 @@ class Dataset(object):
if not max_workers:
for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path)
_download_link(link,target_path)
_download_link(link, target_path)
else:
with ThreadPoolExecutor(max_workers=max_workers) as pool:
for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path)
pool.submit(_download_link,link,target_path)
pool.submit(_download_link, link, target_path)
def _extract_dataset_archive(
self,
@ -2732,6 +2738,7 @@ class Dataset(object):
)
)
)
def _build_dependency_chunk_lookup(self):
# type: () -> Dict[str, int]
"""

View File

@ -199,6 +199,10 @@ class Logger(object):
"""
For explicit reporting, plot a vector as (default stacked) histogram.
.. note::
This method is the same as :meth:`Logger.report_histogram`.
This method is deprecated, use :meth:`Logger.report_histogram` instead.
For example:
.. code-block:: py
@ -224,6 +228,11 @@ class Logger(object):
See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
"""
warnings.warn(
":meth:`Logger.report_vector` is deprecated;"
"use :meth:`Logger.report_histogram` instead.",
DeprecationWarning
)
self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
@ -439,7 +448,16 @@ class Logger(object):
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
.. note::
This method is the same as :meth:`Logger.report_scatter2d` with :param:`mode='lines'`.
This method is deprecated, use :meth:`Logger.report_scatter2d` instead.
"""
warnings.warn(
":meth:`Logger.report_line_plot` is deprecated;"
"use :meth:`Logger.report_scatter2d` instead, e.g., with :param:`mode='lines'`.",
DeprecationWarning
)
# noinspection PyArgumentList
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
@ -710,6 +728,7 @@ class Logger(object):
.. note::
This method is the same as :meth:`Logger.report_confusion_matrix`.
This method is deprecated, use :meth:`Logger.report_confusion_matrix` instead.
:param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix.
@ -719,11 +738,16 @@ class Logger(object):
:param str yaxis: The y-axis title. (Optional)
:param list(str) xlabels: Labels for each column of the matrix. (Optional)
:param list(str) ylabels: Labels for each row of the matrix. (Optional)
:param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner
:param bool yaxis_reversed: If False, 0,0 is in the bottom left corner. If True, 0,0 is in the top left corner
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
"""
warnings.warn(
":meth:`Logger.report_matrix` is deprecated;"
"use :meth:`Logger.report_confusion_matrix` instead.",
DeprecationWarning
)
self._touch_title_series(title, series)
return self.report_confusion_matrix(title, series, matrix, iteration or 0,
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,

View File

@ -1783,7 +1783,7 @@ class StorageHelper(object):
Supports both local and remote files (currently local files, network-mapped files, HTTP/S and Amazon S3)
"""
_temp_download_suffix = '.partially'
_quotable_uri_schemes = set(_HttpDriver.schemes) | set([_GoogleCloudStorageDriver.scheme])
_quotable_uri_schemes = set(_HttpDriver.schemes)
@classmethod
def _get_logger(cls):
@ -2414,7 +2414,7 @@ class StorageHelper(object):
if self.scheme in StorageHelper._quotable_uri_schemes: # TODO: fix-driver-schema
# quote link
result_dest_path = quote_url(result_dest_path)
result_dest_path = quote_url(result_dest_path, StorageHelper._quotable_uri_schemes)
return result_dest_path
@ -2436,7 +2436,7 @@ class StorageHelper(object):
# quote link
def callback(result):
return a_cb(quote_url(result_path) if result else result)
return a_cb(quote_url(result_path, StorageHelper._quotable_uri_schemes) if result else result)
# replace callback with wrapper
cb = callback
@ -2463,7 +2463,7 @@ class StorageHelper(object):
retries=retries,
return_canonized=return_canonized)
if res:
result_path = quote_url(result_path)
result_path = quote_url(result_path, StorageHelper._quotable_uri_schemes)
return result_path
def list(self, prefix=None, with_metadata=False):

View File

@ -42,9 +42,9 @@ def get_config_object_matcher(**patterns):
return _matcher
def quote_url(url):
def quote_url(url, valid_schemes=("http", "https")):
parsed = urlparse(url)
if parsed.scheme not in ("http", "https", "gs"):
if parsed.scheme not in valid_schemes:
return url
parsed = parsed._replace(path=quote(parsed.path))
return urlunparse(parsed)

View File

@ -1009,10 +1009,16 @@ class Task(_Task):
If None is passed, returns all tasks within the project
:param list tags: Filter based on the requested list of tags (strings)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
To include All tags (instead of the default All behaviour) use "__$all" as the first string, example:
The default behaviour is to join all tags with a logical "OR" operator.
To join all tags with a logical "AND" operator instead, use "__$all" as the first string, for example:
["__$all", "best", "experiment", "ever"]
To combine All tags and exclude a list of tags use "__$not" before the excluded tags, example:
["__$all", "best", "experiment", "ever", "__$not", "internal", "test"]
To join all tags with AND, but exclude a tag use "__$not" before the excluded tag, for example:
["__$all", "best", "experiment", "ever", "__$not", "internal", "__$not", "test"]
The "OR" and "AND" operators apply to all tags that follow them until another operator is specified.
The NOT operator applies only to the immediately following tag.
For example, ["__$all", "a", "b", "c", "__$or", "d", "__$not", "e", "__$and", "__$or" "f", "g"]
means ("a" AND "b" AND "c" AND ("d" OR NOT "e") AND ("f" OR "g")).
See https://clear.ml/docs/latest/docs/clearml_sdk/task_sdk/#tag-filters for more information.
:param list additional_return_fields: Optional, if not provided return a list of Task IDs.
If provided return dict per Task with the additional requested fields.
Example: ``returned_fields=['last_updated', 'user', 'script.repository']`` will return a list of dict:
@ -3449,8 +3455,8 @@ class Task(_Task):
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \
if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
if ((task._status in (
cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
or task.output_models_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags)
or task_artifacts):
@ -3621,7 +3627,10 @@ class Task(_Task):
# at least until we support multiple input models
# notice that we do not check the task's input model because we allow task reuse and overwrite
# add into comment that we are using this model
comment = self.comment or ''
# refresh comment
comment = self._reload_field("comment") or self.comment or ''
if not comment.endswith('\n'):
comment += '\n'
comment += 'Using model id: {}'.format(model.id)
@ -3673,14 +3682,14 @@ class Task(_Task):
# noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
def _refresh_args_dict(task, config_dict):
def _refresh_args_dict(task, config_proxy_dict):
# reread from task including newly added keys
# noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_proxy_dict), prefix=name)
# noinspection PyProtectedMember
nested_dict = config_dict._to_dict()
config_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
nested_dict = config_proxy_dict._to_dict()
config_proxy_dict.clear()
config_proxy_dict._do_update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
def _check_keys(dict_, warning_sent=False):
if warning_sent:
@ -4606,15 +4615,15 @@ class Task(_Task):
return False
stopped_statuses = (
str(tasks.TaskStatusEnum.stopped),
str(tasks.TaskStatusEnum.published),
str(tasks.TaskStatusEnum.publishing),
str(tasks.TaskStatusEnum.closed),
str(tasks.TaskStatusEnum.failed),
str(tasks.TaskStatusEnum.completed),
cls.TaskStatusEnum.stopped,
cls.TaskStatusEnum.published,
cls.TaskStatusEnum.publishing,
cls.TaskStatusEnum.closed,
cls.TaskStatusEnum.failed,
cls.TaskStatusEnum.completed,
)
if str(task.status) not in stopped_statuses:
if task.status not in stopped_statuses:
cls._send(
cls._get_default_session(),
tasks.StoppedRequest(

View File

@ -61,6 +61,23 @@ def config_dict_to_text(config):
try:
# noinspection PyBroadException
try:
def raise_on_special_key(config_):
if not isinstance(config_, dict):
return
special_chars = "$}[]:=+#`^?!@*&."
for key in config_.keys():
if not isinstance(key, str):
continue
if any(key_char in special_chars for key_char in key):
raise ValueError(
"Configuration dictionary keys cannot contain any of the following characters: {}".format(
special_chars
)
)
for val in config_.values():
raise_on_special_key(val)
# will fall back to json+pyhocon
raise_on_special_key(config)
text = HOCONConverter.to_hocon(ConfigFactory.from_dict(hocon_quote_key(config)))
except Exception:
# fallback json+pyhocon

View File

@ -39,10 +39,14 @@ class ProxyDictPostWrite(dict):
return a_dict
def update(self, E=None, **F):
res = self._do_update(E, **F)
self._set_callback()
return res
def _do_update(self, E=None, **F):
res = super(ProxyDictPostWrite, self).update(
ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else
ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
self._set_callback()
return res

View File

@ -1 +1 @@
__version__ = '1.11.1rc1'
__version__ = '1.11.2rc0'

View File

@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
"""
from __future__ import print_function
import argparse
import os
from tempfile import gettempdir
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from clearml import Task, Logger
@ -51,7 +54,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader, epoch):
@ -127,8 +130,9 @@ def main():
task.execute_remotely(queue_name="default")
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch)
if (args.save_model):
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
if __name__ == '__main__':

View File

@ -8,6 +8,6 @@ trainer:
filename: best
save_last: False
save_top_k: 1
monitor: loss
monitor: train_loss
mode: min
max_epochs: 10

View File

@ -78,6 +78,7 @@ class ImageClassifier(LightningModule):
x, y = batch
logits = self.forward(x)
loss = F.nll_loss(logits, y.long())
self.log("train_loss", loss)
return loss
def test_step(self, batch, batch_idx):

View File

@ -8,5 +8,6 @@ trainer:
filename: best
save_last: False
save_top_k: 1
monitor: loss
monitor: train_loss
mode: min
max_epochs: 3

View File

@ -156,7 +156,7 @@ def main(_):
test(FLAGS, model, device, test_loader, epoch)
if FLAGS.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
# ClearML - Example of Pytorch mnist training integration
#
from __future__ import print_function
import argparse
import os
from tempfile import gettempdir
@ -47,7 +48,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader, epoch):
@ -128,7 +129,7 @@ def main():
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch)
if (args.save_model):
if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))

View File

@ -7,3 +7,4 @@ tensorboardX
torch>=1.1.0
torchvision>=0.3.0
tqdm
protobuf>=4.21.1