This commit is contained in:
revital 2023-07-11 07:41:32 +03:00
commit 999c6543b8
29 changed files with 329 additions and 112 deletions

View File

@ -76,7 +76,7 @@ Instrumenting these components is the **ClearML-server**, see [Self-Hosting](htt
* Full source control info including non-committed local changes * Full source control info including non-committed local changes
* Execution environment (including specific packages & versions) * Execution environment (including specific packages & versions)
* Hyper-parameters * Hyper-parameters
* ArgParser/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values * [`argparse`](https://docs.python.org/3/library/argparse.html)/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
* Explicit parameters dictionary * Explicit parameters dictionary
* Tensorflow Defines (absl-py) * Tensorflow Defines (absl-py)
* [Hydra](https://github.com/facebookresearch/hydra) configuration and overrides * [Hydra](https://github.com/facebookresearch/hydra) configuration and overrides

View File

@ -1615,7 +1615,7 @@ class PipelineController(object):
def _has_stored_configuration(self): def _has_stored_configuration(self):
""" """
Return True if we are running remotely and we have stored configuration on the Task Return True if we are running remotely, and we have stored configuration on the Task
""" """
if self._auto_connect_task and self._task and not self._task.running_locally() and self._task.is_main_task(): if self._auto_connect_task and self._task and not self._task.running_locally() and self._task.is_main_task():
stored_config = self._task.get_configuration_object(self._config_section) stored_config = self._task.get_configuration_object(self._config_section)
@ -1698,9 +1698,10 @@ class PipelineController(object):
conformed_monitors = [ conformed_monitors = [
pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors
] ]
# verify pair of pairs # verify the pair of pairs
if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and
isinstance(x[1][0], str) and isinstance(x[1][1], str) for x in conformed_monitors): isinstance(x[1][0], str) and isinstance(x[1][1], str)
for x in conformed_monitors):
raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors)) raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
else: else:
# verify a list of tuples # verify a list of tuples
@ -1711,8 +1712,10 @@ class PipelineController(object):
conformed_monitors = [ conformed_monitors = [
pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors
] ]
# verify pair of pairs # verify the pair of pairs
if not all(isinstance(x[0], str) and isinstance(x[1], str) for x in conformed_monitors): if not all(isinstance(x[0], str) and
isinstance(x[1], str)
for x in conformed_monitors):
raise ValueError( raise ValueError(
"{} should be a list of tuples, found: {}".format(monitor_type, monitors)) "{} should be a list of tuples, found: {}".format(monitor_type, monitors))
@ -1782,7 +1785,7 @@ class PipelineController(object):
# type: (...) -> bool # type: (...) -> bool
""" """
Create a Task from a function, including wrapping the function input arguments Create a Task from a function, including wrapping the function input arguments
into the hyper-parameter section as kwargs, and storing function results as named artifacts into the hyperparameter section as kwargs, and storing function results as named artifacts
Example: Example:
@ -1874,7 +1877,7 @@ class PipelineController(object):
:param continue_on_fail: (default False). If True, failed step will not cause the pipeline to stop :param continue_on_fail: (default False). If True, failed step will not cause the pipeline to stop
(or marked as failed). Notice, that steps that are connected (or indirectly connected) (or marked as failed). Notice, that steps that are connected (or indirectly connected)
to the failed step will be skipped. to the failed step will be skipped.
:param pre_execute_callback: Callback function, called when the step (Task) is created :param pre_execute_callback: Callback function, called when the step (Task) is created,
and before it is sent for execution. Allows a user to modify the Task before launch. and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object. Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the ClearmlJob. `parameters` are the configuration arguments passed to the ClearmlJob.
@ -2302,7 +2305,8 @@ class PipelineController(object):
:param node_params: list of node parameters :param node_params: list of node parameters
:param visited: list of nodes :param visited: list of nodes
:return: Table as List of List of strings (cell)
:return: Table as a List of a List of strings (cell)
""" """
task_link_template = self._task.get_output_log_web_page() \ task_link_template = self._task.get_output_log_web_page() \
.replace('/{}/'.format(self._task.project), '/{project}/') \ .replace('/{}/'.format(self._task.project), '/{project}/') \
@ -2778,6 +2782,7 @@ class PipelineController(object):
return the pipeline components target folder name/id return the pipeline components target folder name/id
:param return_project_id: if False (default), return target folder name. If True, return project id :param return_project_id: if False (default), return target folder name. If True, return project id
:return: project id/name (None if not valid) :return: project id/name (None if not valid)
""" """
if not self._target_project: if not self._target_project:
@ -3110,7 +3115,7 @@ class PipelineDecorator(PipelineController):
:param str target_project: If provided, all pipeline steps are cloned into the target project :param str target_project: If provided, all pipeline steps are cloned into the target project
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline :param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step, to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
was specifically defined with "continue_on_fail=True". was specifically defined with "continue_on_fail=True".
If True, any failed step will cause the pipeline to immediately abort, stop all running steps, If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
and mark the pipeline as failed. and mark the pipeline as failed.
@ -3540,7 +3545,7 @@ class PipelineDecorator(PipelineController):
:param _func: wrapper function :param _func: wrapper function
:param return_values: Provide a list of names for all the results. :param return_values: Provide a list of names for all the results.
Notice! If not provided no results will be stored as artifacts. Notice! If not provided, no results will be stored as artifacts.
:param name: Optional, set the name of the pipeline component task. :param name: Optional, set the name of the pipeline component task.
If not provided, the wrapped function name is used as the pipeline component name If not provided, the wrapped function name is used as the pipeline component name
:param cache: If True, before launching the new step, :param cache: If True, before launching the new step,
@ -3623,7 +3628,7 @@ class PipelineDecorator(PipelineController):
# allow up to 5 retries (total of 6 runs) # allow up to 5 retries (total of 6 runs)
return retries < 5 return retries < 5
:param pre_execute_callback: Callback function, called when the step (Task) is created :param pre_execute_callback: Callback function, called when the step (Task) is created,
and before it is sent for execution. Allows a user to modify the Task before launch. and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object. Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the ClearmlJob. `parameters` are the configuration arguments passed to the ClearmlJob.
@ -3644,7 +3649,7 @@ class PipelineDecorator(PipelineController):
pass pass
:param post_execute_callback: Callback function, called when a step (Task) is completed :param post_execute_callback: Callback function, called when a step (Task) is completed
and other jobs are executed. Allows a user to modify the Task status after completion. and other jobs are going to be executed. Allows a user to modify the Task status after completion.
.. code-block:: py .. code-block:: py
@ -3794,7 +3799,7 @@ class PipelineDecorator(PipelineController):
if target_queue: if target_queue:
PipelineDecorator.set_default_execution_queue(target_queue) PipelineDecorator.set_default_execution_queue(target_queue)
else: else:
# if we are are not running from a queue, we are probably in debug mode # if we are not running from a queue, we are probably in debug mode
a_pipeline._clearml_job_class = LocalClearmlJob a_pipeline._clearml_job_class = LocalClearmlJob
a_pipeline._default_execution_queue = 'mock' a_pipeline._default_execution_queue = 'mock'
@ -3957,7 +3962,7 @@ class PipelineDecorator(PipelineController):
:param str target_project: If provided, all pipeline steps are cloned into the target project :param str target_project: If provided, all pipeline steps are cloned into the target project
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline :param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step, to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
was specifically defined with "continue_on_fail=True". was specifically defined with "continue_on_fail=True".
If True, any failed step will cause the pipeline to immediately abort, stop all running steps, If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
and mark the pipeline as failed. and mark the pipeline as failed.

View File

@ -384,9 +384,10 @@ class BaseJob(object):
section_overrides=None, section_overrides=None,
params_override=None, params_override=None,
configurations_override=None, configurations_override=None,
explicit_docker_image=None explicit_docker_image=None,
account_for_artifacts_hashes=True
): ):
# type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str]) -> Optional[str] # type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str], bool) -> Optional[str]
""" """
Create Hash (str) representing the state of the Task Create Hash (str) representing the state of the Task
@ -397,6 +398,8 @@ class BaseJob(object):
:param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem) :param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem)
:param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image :param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image
was explicitly changed was explicitly changed
:param account_for_artifacts_hashes: Calculate the hash of the task by accounting for the hashes of the
artifacts in `kwargs_artifacts` (as opposed of the task ID/artifact name stored in this section)
:return: str hash of the Task configuration :return: str hash of the Task configuration
""" """
@ -416,7 +419,23 @@ class BaseJob(object):
# we need to ignore `requirements` section because ir might be changing from run to run # we need to ignore `requirements` section because ir might be changing from run to run
script.pop("requirements", None) script.pop("requirements", None)
hyper_params = task.get_parameters() if params_override is None else params_override hyper_params = deepcopy(task.get_parameters() if params_override is None else params_override)
if account_for_artifacts_hashes:
hyper_params_to_change = {}
task_cache = {}
for key, value in hyper_params.items():
if key.startswith("kwargs_artifacts/"):
# noinspection PyBroadException
try:
# key format is <task_id>.<artifact_name>
task_id, artifact = value.split(".", 1)
task_ = task_cache.setdefault(task_id, Task.get_task(task_id))
# set the value of the hyper parameter to the hash of the artifact
# because the task ID might differ, but the artifact might be the same
hyper_params_to_change[key] = task_.artifacts[artifact].hash
except Exception:
pass
hyper_params.update(hyper_params_to_change)
configs = task.get_configuration_objects() if configurations_override is None else configurations_override configs = task.get_configuration_objects() if configurations_override is None else configurations_override
# currently we do not add the docker image to the hash (only args and setup script), # currently we do not add the docker image to the hash (only args and setup script),
# because default docker image will cause the step to change # because default docker image will cause the step to change
@ -585,6 +604,14 @@ class ClearmlJob(BaseJob):
if allow_caching: if allow_caching:
# look for a cached copy of the Task # look for a cached copy of the Task
# get parameters + task_overrides + as dict and hash it. # get parameters + task_overrides + as dict and hash it.
task_hash_legacy = self._create_task_hash(
base_temp_task,
section_overrides=sections,
params_override=task_params,
configurations_override=configuration_overrides or None,
explicit_docker_image=kwargs.get("explicit_docker_image"),
account_for_artifacts_hashes=False
)
task_hash = self._create_task_hash( task_hash = self._create_task_hash(
base_temp_task, base_temp_task,
section_overrides=sections, section_overrides=sections,
@ -592,7 +619,7 @@ class ClearmlJob(BaseJob):
configurations_override=configuration_overrides or None, configurations_override=configuration_overrides or None,
explicit_docker_image=kwargs.get("explicit_docker_image") explicit_docker_image=kwargs.get("explicit_docker_image")
) )
task = self._get_cached_task(task_hash) task = self._get_cached_task(task_hash_legacy) or self._get_cached_task(task_hash)
# if we found a task, just use # if we found a task, just use
if task: if task:
if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created: if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:

View File

@ -4590,7 +4590,7 @@ class UpdateResponse(Response):
class ValidateDeleteRequest(Request): class ValidateDeleteRequest(Request):
""" """
Validates that the project existis and can be deleted Validates that the project exists and can be deleted
:param project: Project ID :param project: Project ID
:type project: str :type project: str

View File

@ -4616,7 +4616,7 @@ class UpdateResponse(Response):
class ValidateDeleteRequest(Request): class ValidateDeleteRequest(Request):
""" """
Validates that the project existis and can be deleted Validates that the project exists and can be deleted
:param project: Project ID :param project: Project ID
:type project: str :type project: str

View File

@ -156,6 +156,19 @@ class Session(TokenManager):
self._connect() self._connect()
@classmethod
def add_client(cls, client, value, first=True):
# noinspection PyBroadException
try:
if not any(True for c in cls._client if c[0] == client):
if first:
cls._client.insert(0, (client, value))
else:
cls._client.append((client, value))
cls.client = ", ".join("{}-{}".format(*x) for x in cls._client)
except Exception:
pass
def _connect(self): def _connect(self):
if self._offline_mode: if self._offline_mode:
return return
@ -219,8 +232,7 @@ class Session(TokenManager):
if not api_version: if not api_version:
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
if token_dict.get('server_version'): if token_dict.get('server_version'):
if not any(True for c in Session._client if c[0] == 'clearml-server'): self.add_client('clearml-server', token_dict.get('server_version'))
Session._client.append(('clearml-server', token_dict.get('server_version'), ))
Session.max_api_version = Session.api_version = str(api_version) Session.max_api_version = Session.api_version = str(api_version)
Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic") Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")

View File

@ -644,6 +644,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
res = self.send(tasks.GetByIdRequest(task=self.id)) res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task return res.response.task
def _reload_field(self, field):
# type: (str) -> Any
""" Reload the task specific field, dot seperated for nesting"""
with self._edit_lock:
if self._offline_mode:
task_object = self._reload()
else:
res = self.send(tasks.GetAllRequest(id=[self.id], only_fields=[field], search_hidden=True))
task_object = res.response.tasks[0]
for p in field.split("."):
task_object = getattr(task_object, p, None)
if task_object is None:
break
return task_object
def reset(self, set_started_on_success=True, force=False): def reset(self, set_started_on_success=True, force=False):
# type: (bool, bool) -> () # type: (bool, bool) -> ()
""" """
@ -752,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def publish(self, ignore_errors=True): def publish(self, ignore_errors=True):
# type: (bool) -> () # type: (bool) -> ()
""" The signal that this task will be published """ """ The signal that this task will be published """
if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)): if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
raise ValueError("Can't publish, Task is not stopped") raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors) resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse) assert isinstance(resp.response, tasks.PublishResponse)
@ -792,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
try: try:
res = self.send(tasks.GetByIdRequest(self.task_id)) res = self.send(tasks.GetByIdRequest(self.task_id))
task = res.response.task task = res.response.task
if task.status == Task.TaskStatusEnum.published: if task.status == self.TaskStatusEnum.published:
if raise_on_error: if raise_on_error:
raise self.DeleteError("Cannot delete published task {}".format(self.task_id)) raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
self.log.error("Cannot delete published task {}".format(self.task_id)) self.log.error("Cannot delete published task {}".format(self.task_id))
@ -2408,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:return: list of tuples (status, status_message, task_id) :return: list of tuples (status, status_message, task_id)
""" """
if cls._offline_mode: if cls._offline_mode:
return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids] return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -2547,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# Since we ae using forced update, make sure he task status is valid # Since we ae using forced update, make sure he task status is valid
status = self._data.status if self._data and self._reload_skip_flag else self.data.status status = self._data.status if self._data and self._reload_skip_flag else self.data.status
if not kwargs.pop("force", False) and \ if not kwargs.pop("force", False) and \
status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress): status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
# the exception being name/comment that we can always change. # the exception being name/comment that we can always change.
if kwargs and all( if kwargs and all(
k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys() k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
@ -3000,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def _get_task_status(cls, task_id): def _get_task_status(cls, task_id):
# type: (str) -> (Optional[str], Optional[str]) # type: (str) -> (Optional[str], Optional[str])
if cls._offline_mode: if cls._offline_mode:
return tasks.TaskStatusEnum.created, 'offline' return cls.TaskStatusEnum.created, 'offline'
# noinspection PyBroadException # noinspection PyBroadException
try: try:

View File

@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
_checkpoint_filename = {} _checkpoint_filename = {}
__patched = None __patched = None
__patched_lightning = None __patched_lightning = None
__patched_pytorch_lightning = None
__patched_mmcv = None __patched_mmcv = None
@staticmethod @staticmethod
@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
return return
PatchPyTorchModelIO._patch_model_io() PatchPyTorchModelIO._patch_model_io()
PatchPyTorchModelIO._patch_lightning_io() PatchPyTorchModelIO._patch_lightning_io()
PatchPyTorchModelIO._patch_pytorch_lightning_io()
PatchPyTorchModelIO._patch_mmcv() PatchPyTorchModelIO._patch_mmcv()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io) PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io) PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
@staticmethod @staticmethod
def _patch_model_io(): def _patch_model_io():
@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
if PatchPyTorchModelIO.__patched_lightning: if PatchPyTorchModelIO.__patched_lightning:
return return
if 'pytorch_lightning' not in sys.modules: if 'lightning' not in sys.modules:
return return
PatchPyTorchModelIO.__patched_lightning = True PatchPyTorchModelIO.__patched_lightning = True
# noinspection PyBroadException
try:
import lightning # noqa
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
) # noqa
lightning.pytorch.trainer.Trainer.restore = _patched_call(
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
) # noqa
except ImportError:
pass
except Exception:
pass
# noinspection PyBroadException
try:
import lightning # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save,
) # noqa
# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj,
) # noqa
except ImportError:
pass
except Exception:
pass
@staticmethod
def _patch_pytorch_lightning_io():
if PatchPyTorchModelIO.__patched_pytorch_lightning:
return
if 'pytorch_lightning' not in sys.modules:
return
PatchPyTorchModelIO.__patched_pytorch_lightning = True
# noinspection PyBroadException # noinspection PyBroadException
try: try:
import pytorch_lightning # noqa import pytorch_lightning # noqa

View File

@ -3,6 +3,7 @@ from logging import getLogger
from .frameworks import _patched_call # noqa from .frameworks import _patched_call # noqa
from .import_bind import PostImportHookPatching from .import_bind import PostImportHookPatching
from ..utilities.networking import get_private_ip from ..utilities.networking import get_private_ip
from ..config import running_remotely
class PatchGradio: class PatchGradio:
@ -11,7 +12,7 @@ class PatchGradio:
_default_gradio_address = "0.0.0.0" _default_gradio_address = "0.0.0.0"
_default_gradio_port = 7860 _default_gradio_port = 7860
_root_path_format = "/service/{}" _root_path_format = "/service/{}/"
__server_config_warning = set() __server_config_warning = set()
@classmethod @classmethod
@ -32,42 +33,55 @@ class PatchGradio:
try: try:
import gradio import gradio
gradio.networking.start_server = _patched_call( gradio.routes.App.get_blocks = _patched_call(gradio.routes.App.get_blocks, PatchGradio._patched_get_blocks)
gradio.networking.start_server, PatchGradio._patched_start_server gradio.blocks.Blocks.launch = _patched_call(gradio.blocks.Blocks.launch, PatchGradio._patched_launch)
)
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
except Exception: except Exception:
pass pass
cls.__patched = True cls.__patched = True
@staticmethod @staticmethod
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs): def _patched_get_blocks(original_fn, *args, **kwargs):
blocks = original_fn(*args, **kwargs)
if not PatchGradio._current_task or not running_remotely():
return blocks
blocks.config["root"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
blocks.root = blocks.config["root"]
return blocks
@staticmethod
def _patched_launch(original_fn, *args, **kwargs):
if not PatchGradio._current_task: if not PatchGradio._current_task:
return original_fn(self, server_name, server_port, *args, **kwargs) return original_fn(*args, **kwargs)
PatchGradio.__warn_on_server_config(
kwargs.get("server_name"),
kwargs.get("server_port"),
kwargs.get("root_path")
)
if not running_remotely():
return original_fn(*args, **kwargs)
# noinspection PyProtectedMember
PatchGradio._current_task._set_runtime_properties( PatchGradio._current_task._set_runtime_properties(
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port} {"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
) )
PatchGradio._current_task.set_system_tags(["external_service"]) PatchGradio._current_task.set_system_tags(["external_service"])
PatchGradio.__warn_on_server_config(server_name, server_port)
server_name = PatchGradio._default_gradio_address
server_port = PatchGradio._default_gradio_port
return original_fn(self, server_name, server_port, *args, **kwargs)
@staticmethod
def _patched_init(original_fn, *args, **kwargs):
if not PatchGradio._current_task:
return original_fn(*args, **kwargs)
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
kwargs["root_path_in_servers"] = False
kwargs["server_name"] = PatchGradio._default_gradio_address kwargs["server_name"] = PatchGradio._default_gradio_address
kwargs["server_port"] = PatchGradio._default_gradio_port kwargs["server_port"] = PatchGradio._default_gradio_port
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
# noinspection PyBroadException
try:
return original_fn(*args, **kwargs)
except Exception as e:
del kwargs["root_path"]
return original_fn(*args, **kwargs) return original_fn(*args, **kwargs)
@classmethod @classmethod
def __warn_on_server_config(cls, server_name, server_port): def __warn_on_server_config(cls, server_name, server_port, root_path):
if server_name is None and server_port is None: if (server_name is None or server_name == PatchGradio._default_gradio_address) and \
(server_port is None and server_port == PatchGradio._default_gradio_port):
return return
if (server_name, server_port, root_path) in cls.__server_config_warning:
return
cls.__server_config_warning.add((server_name, server_port, root_path))
if server_name is not None and server_port is not None: if server_name is not None and server_port is not None:
server_config = "{}:{}".format(server_name, server_port) server_config = "{}:{}".format(server_name, server_port)
what_to_ignore = "name and port" what_to_ignore = "name and port"
@ -77,11 +91,14 @@ class PatchGradio:
else: else:
server_config = str(server_port) server_config = str(server_port)
what_to_ignore = "port" what_to_ignore = "port"
if server_config in cls.__server_config_warning:
return
cls.__server_config_warning.add(server_config)
getLogger().warning( getLogger().warning(
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format( "ClearML only supports '{}:{}' as the Gradio server. Ignoring {} '{}' in remote execution".format(
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
) )
) )
if root_path is not None:
getLogger().warning(
"ClearML will override root_path '{}' to '{}' in remote execution".format(
root_path, PatchGradio._root_path_format.format(PatchGradio._current_task.id)
)
)

View File

@ -209,7 +209,10 @@ class PatchJsonArgParse(object):
with change_to_path_dir(path): with change_to_path_dir(path):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False) parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand: if subcommand:
parsed_cfg = {subcommand + PatchJsonArgParse._commands_sep + k: v for k, v in parsed_cfg.items()} parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg return parsed_cfg
@staticmethod @staticmethod

View File

@ -237,6 +237,12 @@ def parse_known_host(parsed_host):
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.port is None:
print('Web app hosted on standard port using ' + parsed_host.scheme + ' protocol.')
print('Assuming files and api ports are unchanged and use the same (' + parsed_host.scheme + ') protocol')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc+ ':8081' + parsed_host.path
else: else:
print("Warning! Could not parse host name") print("Warning! Could not parse host name")
api_host = '' api_host = ''

View File

@ -7,7 +7,11 @@ from typing import Sequence
from pathlib2 import Path from pathlib2 import Path
import clearml.backend_api.session
from clearml.datasets import Dataset from clearml.datasets import Dataset
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-data", __version__)
def check_null_id(args): def check_null_id(args):

View File

@ -1,7 +1,9 @@
import sys
import json import json
import sys
from argparse import ArgumentParser, RawTextHelpFormatter from argparse import ArgumentParser, RawTextHelpFormatter
import clearml.backend_api.session
from clearml import Task
from clearml.automation import ( from clearml.automation import (
DiscreteParameterRange, DiscreteParameterRange,
UniformIntegerParameterRange, UniformIntegerParameterRange,
@ -11,8 +13,11 @@ from clearml.automation import (
RandomSearch, RandomSearch,
GridSearch, GridSearch,
) )
from clearml import Task
from clearml.backend_interface.task.populate import CreateAndPopulate from clearml.backend_interface.task.populate import CreateAndPopulate
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-param-search", __version__)
try: try:
from clearml.automation.optuna import OptimizerOptuna # noqa from clearml.automation.optuna import OptimizerOptuna # noqa

View File

@ -3,9 +3,12 @@ from argparse import ArgumentParser
from pathlib2 import Path from pathlib2 import Path
import clearml.backend_api.session
from clearml import Task from clearml import Task
from clearml.version import __version__
from clearml.backend_interface.task.populate import CreateAndPopulate from clearml.backend_interface.task.populate import CreateAndPopulate
from clearml.version import __version__
clearml.backend_api.session.Session.add_client("clearml-task", __version__)
def setup_parser(parser): def setup_parser(parser):

View File

@ -1877,6 +1877,7 @@ class Dataset(object):
ids=None, # type: Optional[Sequence[str]] ids=None, # type: Optional[Sequence[str]]
only_completed=True, # type: bool only_completed=True, # type: bool
recursive_project_search=True, # type: bool recursive_project_search=True, # type: bool
include_archived=True, # type: bool
): ):
# type: (...) -> List[dict] # type: (...) -> List[dict]
""" """
@ -1890,9 +1891,16 @@ class Dataset(object):
:param recursive_project_search: If True and the `dataset_project` argument is set, :param recursive_project_search: If True and the `dataset_project` argument is set,
search inside subprojects as well. search inside subprojects as well.
If False, don't search inside subprojects (except for the special `.datasets` subproject) If False, don't search inside subprojects (except for the special `.datasets` subproject)
:param include_archived: If True, include archived datasets as well.
:return: List of dictionaries with dataset information :return: List of dictionaries with dataset information
Example: [{'name': name, 'project': project name, 'id': dataset_id, 'created': date_created},] Example: [{'name': name, 'project': project name, 'id': dataset_id, 'created': date_created},]
""" """
# if include_archived is False, we need to add the system tag __$not:archived to filter out archived datasets
if not include_archived:
system_tags = ["__$all", cls.__tag, "__$not", "archived"]
else:
system_tags = [cls.__tag]
if dataset_project: if dataset_project:
if not recursive_project_search: if not recursive_project_search:
dataset_projects = [ dataset_projects = [
@ -1903,12 +1911,13 @@ class Dataset(object):
dataset_projects = [exact_match_regex(dataset_project), "^{}/.*".format(re.escape(dataset_project))] dataset_projects = [exact_match_regex(dataset_project), "^{}/.*".format(re.escape(dataset_project))]
else: else:
dataset_projects = None dataset_projects = None
# noinspection PyProtectedMember # noinspection PyProtectedMember
datasets = Task._query_tasks( datasets = Task._query_tasks(
task_ids=ids or None, task_ids=ids or None,
project_name=dataset_projects, project_name=dataset_projects,
task_name=partial_name, task_name=partial_name,
system_tags=[cls.__tag], system_tags=system_tags,
type=[str(Task.TaskTypes.data_processing)], type=[str(Task.TaskTypes.data_processing)],
tags=tags or None, tags=tags or None,
status=["stopped", "published", "completed", "closed"] if only_completed else None, status=["stopped", "published", "completed", "closed"] if only_completed else None,
@ -2278,7 +2287,8 @@ class Dataset(object):
ds = Dataset.get(dependency) ds = Dataset.get(dependency)
links.update(ds._dataset_link_entries) links.update(ds._dataset_link_entries)
links.update(self._dataset_link_entries) links.update(self._dataset_link_entries)
def _download_link(link,target_path):
def _download_link(link, target_path):
if os.path.exists(target_path): if os.path.exists(target_path):
LoggerRoot.get_base_logger().info( LoggerRoot.get_base_logger().info(
"{} already exists. Skipping downloading {}".format( "{} already exists. Skipping downloading {}".format(
@ -2310,16 +2320,12 @@ class Dataset(object):
if not max_workers: if not max_workers:
for relative_path, link in links.items(): for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path) target_path = os.path.join(target_folder, relative_path)
_download_link(link,target_path) _download_link(link, target_path)
else: else:
with ThreadPoolExecutor(max_workers=max_workers) as pool: with ThreadPoolExecutor(max_workers=max_workers) as pool:
for relative_path, link in links.items(): for relative_path, link in links.items():
target_path = os.path.join(target_folder, relative_path) target_path = os.path.join(target_folder, relative_path)
pool.submit(_download_link,link,target_path) pool.submit(_download_link, link, target_path)
def _extract_dataset_archive( def _extract_dataset_archive(
self, self,
@ -2732,6 +2738,7 @@ class Dataset(object):
) )
) )
) )
def _build_dependency_chunk_lookup(self): def _build_dependency_chunk_lookup(self):
# type: () -> Dict[str, int] # type: () -> Dict[str, int]
""" """

View File

@ -199,6 +199,10 @@ class Logger(object):
""" """
For explicit reporting, plot a vector as (default stacked) histogram. For explicit reporting, plot a vector as (default stacked) histogram.
.. note::
This method is the same as :meth:`Logger.report_histogram`.
This method is deprecated, use :meth:`Logger.report_histogram` instead.
For example: For example:
.. code-block:: py .. code-block:: py
@ -224,6 +228,11 @@ class Logger(object):
See full details on the supported configuration: https://plotly.com/javascript/reference/layout/ See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'} example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
""" """
warnings.warn(
":meth:`Logger.report_vector` is deprecated;"
"use :meth:`Logger.report_histogram` instead.",
DeprecationWarning
)
self._touch_title_series(title, series) self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels, return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout) xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
@ -439,7 +448,16 @@ class Logger(object):
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/ See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}} example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
.. note::
This method is the same as :meth:`Logger.report_scatter2d` with :param:`mode='lines'`.
This method is deprecated, use :meth:`Logger.report_scatter2d` instead.
""" """
warnings.warn(
":meth:`Logger.report_line_plot` is deprecated;"
"use :meth:`Logger.report_scatter2d` instead, e.g., with :param:`mode='lines'`.",
DeprecationWarning
)
# noinspection PyArgumentList # noinspection PyArgumentList
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series] series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
@ -710,6 +728,7 @@ class Logger(object):
.. note:: .. note::
This method is the same as :meth:`Logger.report_confusion_matrix`. This method is the same as :meth:`Logger.report_confusion_matrix`.
This method is deprecated, use :meth:`Logger.report_confusion_matrix` instead.
:param str title: The title (metric) of the plot. :param str title: The title (metric) of the plot.
:param str series: The series name (variant) of the reported confusion matrix. :param str series: The series name (variant) of the reported confusion matrix.
@ -719,11 +738,16 @@ class Logger(object):
:param str yaxis: The y-axis title. (Optional) :param str yaxis: The y-axis title. (Optional)
:param list(str) xlabels: Labels for each column of the matrix. (Optional) :param list(str) xlabels: Labels for each column of the matrix. (Optional)
:param list(str) ylabels: Labels for each row of the matrix. (Optional) :param list(str) ylabels: Labels for each row of the matrix. (Optional)
:param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner :param bool yaxis_reversed: If False, 0,0 is in the bottom left corner. If True, 0,0 is in the top left corner
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/ See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}} example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
""" """
warnings.warn(
":meth:`Logger.report_matrix` is deprecated;"
"use :meth:`Logger.report_confusion_matrix` instead.",
DeprecationWarning
)
self._touch_title_series(title, series) self._touch_title_series(title, series)
return self.report_confusion_matrix(title, series, matrix, iteration or 0, return self.report_confusion_matrix(title, series, matrix, iteration or 0,
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels, xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,

View File

@ -1783,7 +1783,7 @@ class StorageHelper(object):
Supports both local and remote files (currently local files, network-mapped files, HTTP/S and Amazon S3) Supports both local and remote files (currently local files, network-mapped files, HTTP/S and Amazon S3)
""" """
_temp_download_suffix = '.partially' _temp_download_suffix = '.partially'
_quotable_uri_schemes = set(_HttpDriver.schemes) | set([_GoogleCloudStorageDriver.scheme]) _quotable_uri_schemes = set(_HttpDriver.schemes)
@classmethod @classmethod
def _get_logger(cls): def _get_logger(cls):
@ -2414,7 +2414,7 @@ class StorageHelper(object):
if self.scheme in StorageHelper._quotable_uri_schemes: # TODO: fix-driver-schema if self.scheme in StorageHelper._quotable_uri_schemes: # TODO: fix-driver-schema
# quote link # quote link
result_dest_path = quote_url(result_dest_path) result_dest_path = quote_url(result_dest_path, StorageHelper._quotable_uri_schemes)
return result_dest_path return result_dest_path
@ -2436,7 +2436,7 @@ class StorageHelper(object):
# quote link # quote link
def callback(result): def callback(result):
return a_cb(quote_url(result_path) if result else result) return a_cb(quote_url(result_path, StorageHelper._quotable_uri_schemes) if result else result)
# replace callback with wrapper # replace callback with wrapper
cb = callback cb = callback
@ -2463,7 +2463,7 @@ class StorageHelper(object):
retries=retries, retries=retries,
return_canonized=return_canonized) return_canonized=return_canonized)
if res: if res:
result_path = quote_url(result_path) result_path = quote_url(result_path, StorageHelper._quotable_uri_schemes)
return result_path return result_path
def list(self, prefix=None, with_metadata=False): def list(self, prefix=None, with_metadata=False):

View File

@ -42,9 +42,9 @@ def get_config_object_matcher(**patterns):
return _matcher return _matcher
def quote_url(url): def quote_url(url, valid_schemes=("http", "https")):
parsed = urlparse(url) parsed = urlparse(url)
if parsed.scheme not in ("http", "https", "gs"): if parsed.scheme not in valid_schemes:
return url return url
parsed = parsed._replace(path=quote(parsed.path)) parsed = parsed._replace(path=quote(parsed.path))
return urlunparse(parsed) return urlunparse(parsed)

View File

@ -1009,10 +1009,16 @@ class Task(_Task):
If None is passed, returns all tasks within the project If None is passed, returns all tasks within the project
:param list tags: Filter based on the requested list of tags (strings) :param list tags: Filter based on the requested list of tags (strings)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"] To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
To include All tags (instead of the default All behaviour) use "__$all" as the first string, example: The default behaviour is to join all tags with a logical "OR" operator.
To join all tags with a logical "AND" operator instead, use "__$all" as the first string, for example:
["__$all", "best", "experiment", "ever"] ["__$all", "best", "experiment", "ever"]
To combine All tags and exclude a list of tags use "__$not" before the excluded tags, example: To join all tags with AND, but exclude a tag use "__$not" before the excluded tag, for example:
["__$all", "best", "experiment", "ever", "__$not", "internal", "test"] ["__$all", "best", "experiment", "ever", "__$not", "internal", "__$not", "test"]
The "OR" and "AND" operators apply to all tags that follow them until another operator is specified.
The NOT operator applies only to the immediately following tag.
For example, ["__$all", "a", "b", "c", "__$or", "d", "__$not", "e", "__$and", "__$or" "f", "g"]
means ("a" AND "b" AND "c" AND ("d" OR NOT "e") AND ("f" OR "g")).
See https://clear.ml/docs/latest/docs/clearml_sdk/task_sdk/#tag-filters for more information.
:param list additional_return_fields: Optional, if not provided return a list of Task IDs. :param list additional_return_fields: Optional, if not provided return a list of Task IDs.
If provided return dict per Task with the additional requested fields. If provided return dict per Task with the additional requested fields.
Example: ``returned_fields=['last_updated', 'user', 'script.repository']`` will return a list of dict: Example: ``returned_fields=['last_updated', 'user', 'script.repository']`` will return a list of dict:
@ -3449,8 +3455,8 @@ class Task(_Task):
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
task_artifacts = task.data.execution.artifacts \ task_artifacts = task.data.execution.artifacts \
if hasattr(task.data.execution, 'artifacts') else None if hasattr(task.data.execution, 'artifacts') else None
if ((str(task._status) in ( if ((task._status in (
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed))) cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
or task.output_models_id or (cls.archived_tag in task_tags) or task.output_models_id or (cls.archived_tag in task_tags)
or (cls._development_tag not in task_tags) or (cls._development_tag not in task_tags)
or task_artifacts): or task_artifacts):
@ -3621,7 +3627,10 @@ class Task(_Task):
# at least until we support multiple input models # at least until we support multiple input models
# notice that we do not check the task's input model because we allow task reuse and overwrite # notice that we do not check the task's input model because we allow task reuse and overwrite
# add into comment that we are using this model # add into comment that we are using this model
comment = self.comment or ''
# refresh comment
comment = self._reload_field("comment") or self.comment or ''
if not comment.endswith('\n'): if not comment.endswith('\n'):
comment += '\n' comment += '\n'
comment += 'Using model id: {}'.format(model.id) comment += 'Using model id: {}'.format(model.id)
@ -3673,14 +3682,14 @@ class Task(_Task):
# noinspection PyProtectedMember # noinspection PyProtectedMember
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name) task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
def _refresh_args_dict(task, config_dict): def _refresh_args_dict(task, config_proxy_dict):
# reread from task including newly added keys # reread from task including newly added keys
# noinspection PyProtectedMember # noinspection PyProtectedMember
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name) a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_proxy_dict), prefix=name)
# noinspection PyProtectedMember # noinspection PyProtectedMember
nested_dict = config_dict._to_dict() nested_dict = config_proxy_dict._to_dict()
config_dict.clear() config_proxy_dict.clear()
config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict)) config_proxy_dict._do_update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
def _check_keys(dict_, warning_sent=False): def _check_keys(dict_, warning_sent=False):
if warning_sent: if warning_sent:
@ -4606,15 +4615,15 @@ class Task(_Task):
return False return False
stopped_statuses = ( stopped_statuses = (
str(tasks.TaskStatusEnum.stopped), cls.TaskStatusEnum.stopped,
str(tasks.TaskStatusEnum.published), cls.TaskStatusEnum.published,
str(tasks.TaskStatusEnum.publishing), cls.TaskStatusEnum.publishing,
str(tasks.TaskStatusEnum.closed), cls.TaskStatusEnum.closed,
str(tasks.TaskStatusEnum.failed), cls.TaskStatusEnum.failed,
str(tasks.TaskStatusEnum.completed), cls.TaskStatusEnum.completed,
) )
if str(task.status) not in stopped_statuses: if task.status not in stopped_statuses:
cls._send( cls._send(
cls._get_default_session(), cls._get_default_session(),
tasks.StoppedRequest( tasks.StoppedRequest(

View File

@ -61,6 +61,23 @@ def config_dict_to_text(config):
try: try:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
def raise_on_special_key(config_):
if not isinstance(config_, dict):
return
special_chars = "$}[]:=+#`^?!@*&."
for key in config_.keys():
if not isinstance(key, str):
continue
if any(key_char in special_chars for key_char in key):
raise ValueError(
"Configuration dictionary keys cannot contain any of the following characters: {}".format(
special_chars
)
)
for val in config_.values():
raise_on_special_key(val)
# will fall back to json+pyhocon
raise_on_special_key(config)
text = HOCONConverter.to_hocon(ConfigFactory.from_dict(hocon_quote_key(config))) text = HOCONConverter.to_hocon(ConfigFactory.from_dict(hocon_quote_key(config)))
except Exception: except Exception:
# fallback json+pyhocon # fallback json+pyhocon

View File

@ -39,10 +39,14 @@ class ProxyDictPostWrite(dict):
return a_dict return a_dict
def update(self, E=None, **F): def update(self, E=None, **F):
res = self._do_update(E, **F)
self._set_callback()
return res
def _do_update(self, E=None, **F):
res = super(ProxyDictPostWrite, self).update( res = super(ProxyDictPostWrite, self).update(
ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else
ProxyDictPostWrite(self._update_obj, self._set_callback, **F)) ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
self._set_callback()
return res return res

View File

@ -1 +1 @@
__version__ = '1.11.1rc1' __version__ = '1.11.2rc0'

View File

@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
""" """
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os import os
from tempfile import gettempdir from tempfile import gettempdir
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torchvision import datasets, transforms from torchvision import datasets, transforms
from clearml import Task, Logger from clearml import Task, Logger
@ -127,8 +130,9 @@ def main():
task.execute_remotely(queue_name="default") task.execute_remotely(queue_name="default")
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch) test(args, model, device, test_loader, epoch)
if (args.save_model):
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -8,6 +8,6 @@ trainer:
filename: best filename: best
save_last: False save_last: False
save_top_k: 1 save_top_k: 1
monitor: loss monitor: train_loss
mode: min mode: min
max_epochs: 10 max_epochs: 10

View File

@ -78,6 +78,7 @@ class ImageClassifier(LightningModule):
x, y = batch x, y = batch
logits = self.forward(x) logits = self.forward(x)
loss = F.nll_loss(logits, y.long()) loss = F.nll_loss(logits, y.long())
self.log("train_loss", loss)
return loss return loss
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):

View File

@ -8,5 +8,6 @@ trainer:
filename: best filename: best
save_last: False save_last: False
save_top_k: 1 save_top_k: 1
monitor: loss monitor: train_loss
mode: min mode: min
max_epochs: 3

View File

@ -156,7 +156,7 @@ def main(_):
test(FLAGS, model, device, test_loader, epoch) test(FLAGS, model, device, test_loader, epoch)
if FLAGS.save_model: if FLAGS.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
# ClearML - Example of Pytorch mnist training integration # ClearML - Example of Pytorch mnist training integration
# #
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os import os
from tempfile import gettempdir from tempfile import gettempdir
@ -128,7 +129,7 @@ def main():
train(args, model, device, train_loader, optimizer, epoch) train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader, epoch) test(args, model, device, test_loader, epoch)
if (args.save_model): if args.save_model:
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt")) torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))

View File

@ -7,3 +7,4 @@ tensorboardX
torch>=1.1.0 torch>=1.1.0
torchvision>=0.3.0 torchvision>=0.3.0
tqdm tqdm
protobuf>=4.21.1