mirror of https://github.com/clearml/clearml
synced 2025-06-22 17:45:43 +00:00
Merge branch 'master' of https://github.com/allegroai/clearml
This commit is contained in:
commit 999c6543b8
@@ -76,7 +76,7 @@ Instrumenting these components is the **ClearML-server**, see [Self-Hosting](htt
* Full source control info including non-committed local changes
* Execution environment (including specific packages & versions)
* Hyper-parameters
* ArgParser/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
* [`argparse`](https://docs.python.org/3/library/argparse.html)/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
* Explicit parameters dictionary
* Tensorflow Defines (absl-py)
* [Hydra](https://github.com/facebookresearch/hydra) configuration and overrides
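The bullet pair above reflects the doc change from ArgParser to the linked argparse reference; every item in this list is captured automatically once Task.init() is called. A minimal sketch of that auto-logging (project and task names below are illustrative, not part of this commit):

# Minimal sketch: ClearML records the current argparse values as hyper-parameters
# once Task.init() has been called (names below are illustrative).
import argparse
from clearml import Task

task = Task.init(project_name="examples", task_name="argparse capture")

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--epochs", type=int, default=10)
args = parser.parse_args()  # the parsed values appear under the task's hyper-parameters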
@@ -1615,7 +1615,7 @@ class PipelineController(object):

def _has_stored_configuration(self):
"""
Return True if we are running remotely and we have stored configuration on the Task
Return True if we are running remotely, and we have stored configuration on the Task
"""
if self._auto_connect_task and self._task and not self._task.running_locally() and self._task.is_main_task():
stored_config = self._task.get_configuration_object(self._config_section)
@@ -1698,9 +1698,10 @@ class PipelineController(object):
conformed_monitors = [
pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors
]
# verify pair of pairs
# verify the pair of pairs
if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and
isinstance(x[1][0], str) and isinstance(x[1][1], str) for x in conformed_monitors):
isinstance(x[1][0], str) and isinstance(x[1][1], str)
for x in conformed_monitors):
raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
else:
# verify a list of tuples
@@ -1711,8 +1712,10 @@ class PipelineController(object):
conformed_monitors = [
pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors
]
# verify pair of pairs
if not all(isinstance(x[0], str) and isinstance(x[1], str) for x in conformed_monitors):
# verify the pair of pairs
if not all(isinstance(x[0], str) and
isinstance(x[1], str)
for x in conformed_monitors):
raise ValueError(
"{} should be a list of tuples, found: {}".format(monitor_type, monitors))
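The two hunks above only re-wrap the same validation: a monitor entry may be a single ('title', 'series') pair or a pair of such pairs (source, target). A standalone, hedged sketch of that normalization (the helper name is illustrative, not the pipeline API):

# Hedged sketch of the normalization performed above (not the ClearML API itself).
def conform_monitors(monitors):
    # a plain ("title", "series") entry is mirrored into (source, target)
    conformed = [
        pair if isinstance(pair[0], (list, tuple)) else (pair, pair)
        for pair in monitors
    ]
    if not all(
        isinstance(x[0][0], str) and isinstance(x[0][1], str)
        and isinstance(x[1][0], str) and isinstance(x[1][1], str)
        for x in conformed
    ):
        raise ValueError("monitors should be a list of tuples, found: {}".format(monitors))
    return conformed

print(conform_monitors([("loss", "train"), (("acc", "val"), ("accuracy", "validation"))]))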
@@ -1782,7 +1785,7 @@ class PipelineController(object):
# type: (...) -> bool
"""
Create a Task from a function, including wrapping the function input arguments
into the hyper-parameter section as kwargs, and storing function results as named artifacts
into the hyperparameter section as kwargs, and storing function results as named artifacts

Example:

@@ -1874,7 +1877,7 @@ class PipelineController(object):
:param continue_on_fail: (default False). If True, failed step will not cause the pipeline to stop
(or marked as failed). Notice, that steps that are connected (or indirectly connected)
to the failed step will be skipped.
:param pre_execute_callback: Callback function, called when the step (Task) is created
:param pre_execute_callback: Callback function, called when the step (Task) is created,
and before it is sent for execution. Allows a user to modify the Task before launch.
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
`parameters` are the configuration arguments passed to the ClearmlJob.
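The docstring above describes the callback contract. A minimal hedged sketch of such a callback (argument names follow the docstring; the body is illustrative, and returning False is assumed to skip launching the step):

# Illustrative sketch of the pre_execute_callback contract described above:
# it receives the pipeline, the node about to be launched, and the ClearmlJob
# parameter overrides; node.job / node.job.task expose the Task before launch.
def step_created_callback(pipeline, node, parameters):
    print("about to launch step '{}' with parameters: {}".format(node.name, parameters))
    return True  # returning False is assumed to skip the step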
@@ -2302,7 +2305,8 @@ class PipelineController(object):

:param node_params: list of node parameters
:param visited: list of nodes
:return: Table as List of List of strings (cell)

:return: Table as a List of a List of strings (cell)
"""
task_link_template = self._task.get_output_log_web_page() \
.replace('/{}/'.format(self._task.project), '/{project}/') \
@@ -2778,6 +2782,7 @@ class PipelineController(object):
return the pipeline components target folder name/id

:param return_project_id: if False (default), return target folder name. If True, return project id

:return: project id/name (None if not valid)
"""
if not self._target_project:
@@ -3110,7 +3115,7 @@ class PipelineDecorator(PipelineController):
:param str target_project: If provided, all pipeline steps are cloned into the target project
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
was specifically defined with "continue_on_fail=True".
If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
and mark the pipeline as failed.
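The abort_on_failure semantics above interact with per-step continue_on_fail. A hedged sketch of a step that opts in (pipeline, project and task names are made up for illustration):

# Illustrative sketch: a step allowed to fail without stopping unrelated steps
# (names below are made up).
from clearml import PipelineController

pipe = PipelineController(name="demo pipeline", project="examples", version="1.0")
pipe.add_step(
    name="optional_report",
    base_task_project="examples",
    base_task_name="report task",
    continue_on_fail=True,  # a failure here does not abort steps that do not depend on it
)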
@ -3540,7 +3545,7 @@ class PipelineDecorator(PipelineController):
|
||||
|
||||
:param _func: wrapper function
|
||||
:param return_values: Provide a list of names for all the results.
|
||||
Notice! If not provided no results will be stored as artifacts.
|
||||
Notice! If not provided, no results will be stored as artifacts.
|
||||
:param name: Optional, set the name of the pipeline component task.
|
||||
If not provided, the wrapped function name is used as the pipeline component name
|
||||
:param cache: If True, before launching the new step,
|
||||
@ -3623,8 +3628,8 @@ class PipelineDecorator(PipelineController):
|
||||
# allow up to 5 retries (total of 6 runs)
|
||||
return retries < 5
|
||||
|
||||
:param pre_execute_callback: Callback function, called when the step (Task) is created
|
||||
and before it is sent for execution. Allows a user to modify the Task before launch.
|
||||
:param pre_execute_callback: Callback function, called when the step (Task) is created,
|
||||
and before it is sent for execution. Allows a user to modify the Task before launch.
|
||||
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
|
||||
`parameters` are the configuration arguments passed to the ClearmlJob.
|
||||
|
||||
@ -3644,7 +3649,7 @@ class PipelineDecorator(PipelineController):
|
||||
pass
|
||||
|
||||
:param post_execute_callback: Callback function, called when a step (Task) is completed
|
||||
and other jobs are executed. Allows a user to modify the Task status after completion.
|
||||
and other jobs are going to be executed. Allows a user to modify the Task status after completion.
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
@ -3794,7 +3799,7 @@ class PipelineDecorator(PipelineController):
|
||||
if target_queue:
|
||||
PipelineDecorator.set_default_execution_queue(target_queue)
|
||||
else:
|
||||
# if we are are not running from a queue, we are probably in debug mode
|
||||
# if we are not running from a queue, we are probably in debug mode
|
||||
a_pipeline._clearml_job_class = LocalClearmlJob
|
||||
a_pipeline._default_execution_queue = 'mock'
|
||||
|
||||
@ -3957,7 +3962,7 @@ class PipelineDecorator(PipelineController):
|
||||
:param str target_project: If provided, all pipeline steps are cloned into the target project
|
||||
:param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
|
||||
to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
|
||||
will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
|
||||
will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
|
||||
was specifically defined with "continue_on_fail=True".
|
||||
If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
|
||||
and mark the pipeline as failed.
|
||||
|
@@ -384,9 +384,10 @@ class BaseJob(object):
section_overrides=None,
params_override=None,
configurations_override=None,
explicit_docker_image=None
explicit_docker_image=None,
account_for_artifacts_hashes=True
):
# type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str]) -> Optional[str]
# type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str], bool) -> Optional[str]
"""
Create Hash (str) representing the state of the Task

@@ -397,6 +398,8 @@ class BaseJob(object):
:param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem)
:param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image
was explicitly changed
:param account_for_artifacts_hashes: Calculate the hash of the task by accounting for the hashes of the
artifacts in `kwargs_artifacts` (as opposed to the task ID/artifact name stored in this section)

:return: str hash of the Task configuration
"""
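The new account_for_artifacts_hashes flag makes the cache key depend on artifact content rather than on the producing task's ID. A hedged, self-contained sketch of that idea (the helper name and field layout are illustrative, not the BaseJob implementation):

# Illustrative only: build a stable hash from a parameters dict, replacing
# "<task_id>.<artifact>" references with the artifact content hash so that
# identical artifacts produced by different tasks still hit the cache.
import hashlib
import json

def hash_params(params, artifact_hashes):
    params = dict(params)
    for key, value in params.items():
        if key.startswith("kwargs_artifacts/") and value in artifact_hashes:
            params[key] = artifact_hashes[value]  # content hash, not task id
    blob = json.dumps(params, sort_keys=True).encode("utf-8")
    return hashlib.md5(blob).hexdigest()

print(hash_params(
    {"kwargs_artifacts/data": "abc123.dataset", "General/lr": "0.1"},
    {"abc123.dataset": "9f2c5d..."},
))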
@ -416,7 +419,23 @@ class BaseJob(object):
|
||||
# we need to ignore `requirements` section because ir might be changing from run to run
|
||||
script.pop("requirements", None)
|
||||
|
||||
hyper_params = task.get_parameters() if params_override is None else params_override
|
||||
hyper_params = deepcopy(task.get_parameters() if params_override is None else params_override)
|
||||
if account_for_artifacts_hashes:
|
||||
hyper_params_to_change = {}
|
||||
task_cache = {}
|
||||
for key, value in hyper_params.items():
|
||||
if key.startswith("kwargs_artifacts/"):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# key format is <task_id>.<artifact_name>
|
||||
task_id, artifact = value.split(".", 1)
|
||||
task_ = task_cache.setdefault(task_id, Task.get_task(task_id))
|
||||
# set the value of the hyper parameter to the hash of the artifact
|
||||
# because the task ID might differ, but the artifact might be the same
|
||||
hyper_params_to_change[key] = task_.artifacts[artifact].hash
|
||||
except Exception:
|
||||
pass
|
||||
hyper_params.update(hyper_params_to_change)
|
||||
configs = task.get_configuration_objects() if configurations_override is None else configurations_override
|
||||
# currently we do not add the docker image to the hash (only args and setup script),
|
||||
# because default docker image will cause the step to change
|
||||
@ -585,6 +604,14 @@ class ClearmlJob(BaseJob):
|
||||
if allow_caching:
|
||||
# look for a cached copy of the Task
|
||||
# get parameters + task_overrides + as dict and hash it.
|
||||
task_hash_legacy = self._create_task_hash(
|
||||
base_temp_task,
|
||||
section_overrides=sections,
|
||||
params_override=task_params,
|
||||
configurations_override=configuration_overrides or None,
|
||||
explicit_docker_image=kwargs.get("explicit_docker_image"),
|
||||
account_for_artifacts_hashes=False
|
||||
)
|
||||
task_hash = self._create_task_hash(
|
||||
base_temp_task,
|
||||
section_overrides=sections,
|
||||
@ -592,7 +619,7 @@ class ClearmlJob(BaseJob):
|
||||
configurations_override=configuration_overrides or None,
|
||||
explicit_docker_image=kwargs.get("explicit_docker_image")
|
||||
)
|
||||
task = self._get_cached_task(task_hash)
|
||||
task = self._get_cached_task(task_hash_legacy) or self._get_cached_task(task_hash)
|
||||
# if we found a task, just use
|
||||
if task:
|
||||
if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:
|
||||
|
@@ -4590,7 +4590,7 @@ class UpdateResponse(Response):

class ValidateDeleteRequest(Request):
"""
Validates that the project existis and can be deleted
Validates that the project exists and can be deleted

:param project: Project ID
:type project: str

@@ -4616,7 +4616,7 @@ class UpdateResponse(Response):

class ValidateDeleteRequest(Request):
"""
Validates that the project existis and can be deleted
Validates that the project exists and can be deleted

:param project: Project ID
:type project: str
@@ -156,6 +156,19 @@ class Session(TokenManager):

self._connect()

@classmethod
def add_client(cls, client, value, first=True):
# noinspection PyBroadException
try:
if not any(True for c in cls._client if c[0] == client):
if first:
cls._client.insert(0, (client, value))
else:
cls._client.append((client, value))
cls.client = ", ".join("{}-{}".format(*x) for x in cls._client)
except Exception:
pass

def _connect(self):
if self._offline_mode:
return
@@ -219,8 +232,7 @@ class Session(TokenManager):
if not api_version:
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
if token_dict.get('server_version'):
if not any(True for c in Session._client if c[0] == 'clearml-server'):
Session._client.append(('clearml-server', token_dict.get('server_version'), ))
self.add_client('clearml-server', token_dict.get('server_version'))

Session.max_api_version = Session.api_version = str(api_version)
Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")
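The new Session.add_client classmethod registers an extra client string in the session's client identifier; later hunks in this commit call it from the clearml-data, clearml-param-search and clearml-task CLIs. Usage as those hunks show it:

# Usage as the CLI hunks later in this commit show it: register the calling
# tool and its version so the session's client string includes it.
import clearml.backend_api.session
from clearml.version import __version__

clearml.backend_api.session.Session.add_client("clearml-data", __version__)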
@@ -644,6 +644,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task

def _reload_field(self, field):
# type: (str) -> Any
""" Reload the task specific field, dot separated for nesting"""
with self._edit_lock:
if self._offline_mode:
task_object = self._reload()
else:
res = self.send(tasks.GetAllRequest(id=[self.id], only_fields=[field], search_hidden=True))
task_object = res.response.tasks[0]

for p in field.split("."):
task_object = getattr(task_object, p, None)
if task_object is None:
break

return task_object

def reset(self, set_started_on_success=True, force=False):
# type: (bool, bool) -> ()
"""
@ -752,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
def publish(self, ignore_errors=True):
|
||||
# type: (bool) -> ()
|
||||
""" The signal that this task will be published """
|
||||
if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)):
|
||||
if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
|
||||
raise ValueError("Can't publish, Task is not stopped")
|
||||
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
|
||||
assert isinstance(resp.response, tasks.PublishResponse)
|
||||
@ -792,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
try:
|
||||
res = self.send(tasks.GetByIdRequest(self.task_id))
|
||||
task = res.response.task
|
||||
if task.status == Task.TaskStatusEnum.published:
|
||||
if task.status == self.TaskStatusEnum.published:
|
||||
if raise_on_error:
|
||||
raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
|
||||
self.log.error("Cannot delete published task {}".format(self.task_id))
|
||||
@ -2408,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
:return: list of tuples (status, status_message, task_id)
|
||||
"""
|
||||
if cls._offline_mode:
|
||||
return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids]
|
||||
return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@ -2547,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
# Since we ae using forced update, make sure he task status is valid
|
||||
status = self._data.status if self._data and self._reload_skip_flag else self.data.status
|
||||
if not kwargs.pop("force", False) and \
|
||||
status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
|
||||
status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
|
||||
# the exception being name/comment that we can always change.
|
||||
if kwargs and all(
|
||||
k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
|
||||
@ -3000,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
def _get_task_status(cls, task_id):
|
||||
# type: (str) -> (Optional[str], Optional[str])
|
||||
if cls._offline_mode:
|
||||
return tasks.TaskStatusEnum.created, 'offline'
|
||||
return cls.TaskStatusEnum.created, 'offline'
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
|
@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
_checkpoint_filename = {}
|
||||
__patched = None
|
||||
__patched_lightning = None
|
||||
__patched_pytorch_lightning = None
|
||||
__patched_mmcv = None
|
||||
|
||||
@staticmethod
|
||||
@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
return
|
||||
PatchPyTorchModelIO._patch_model_io()
|
||||
PatchPyTorchModelIO._patch_lightning_io()
|
||||
PatchPyTorchModelIO._patch_pytorch_lightning_io()
|
||||
PatchPyTorchModelIO._patch_mmcv()
|
||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||
PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
|
||||
PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
if PatchPyTorchModelIO.__patched_lightning:
|
||||
return
|
||||
|
||||
if 'pytorch_lightning' not in sys.modules:
|
||||
if 'lightning' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
|
||||
) # noqa
|
||||
|
||||
lightning.pytorch.trainer.Trainer.restore = _patched_call(
|
||||
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import lightning # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
|
||||
PatchPyTorchModelIO._save,
|
||||
) # noqa
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
|
||||
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
|
||||
PatchPyTorchModelIO._load_from_obj,
|
||||
) # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _patch_pytorch_lightning_io():
|
||||
if PatchPyTorchModelIO.__patched_pytorch_lightning:
|
||||
return
|
||||
|
||||
if 'pytorch_lightning' not in sys.modules:
|
||||
return
|
||||
|
||||
PatchPyTorchModelIO.__patched_pytorch_lightning = True
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import pytorch_lightning # noqa
|
||||
|
@ -3,6 +3,7 @@ from logging import getLogger
|
||||
from .frameworks import _patched_call # noqa
|
||||
from .import_bind import PostImportHookPatching
|
||||
from ..utilities.networking import get_private_ip
|
||||
from ..config import running_remotely
|
||||
|
||||
|
||||
class PatchGradio:
|
||||
@ -11,7 +12,7 @@ class PatchGradio:
|
||||
|
||||
_default_gradio_address = "0.0.0.0"
|
||||
_default_gradio_port = 7860
|
||||
_root_path_format = "/service/{}"
|
||||
_root_path_format = "/service/{}/"
|
||||
__server_config_warning = set()
|
||||
|
||||
@classmethod
|
||||
@ -32,42 +33,55 @@ class PatchGradio:
|
||||
try:
|
||||
import gradio
|
||||
|
||||
gradio.networking.start_server = _patched_call(
|
||||
gradio.networking.start_server, PatchGradio._patched_start_server
|
||||
)
|
||||
gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
|
||||
gradio.routes.App.get_blocks = _patched_call(gradio.routes.App.get_blocks, PatchGradio._patched_get_blocks)
|
||||
gradio.blocks.Blocks.launch = _patched_call(gradio.blocks.Blocks.launch, PatchGradio._patched_launch)
|
||||
except Exception:
|
||||
pass
|
||||
cls.__patched = True
|
||||
|
||||
@staticmethod
|
||||
def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
|
||||
def _patched_get_blocks(original_fn, *args, **kwargs):
|
||||
blocks = original_fn(*args, **kwargs)
|
||||
if not PatchGradio._current_task or not running_remotely():
|
||||
return blocks
|
||||
blocks.config["root"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||
blocks.root = blocks.config["root"]
|
||||
return blocks
|
||||
|
||||
@staticmethod
|
||||
def _patched_launch(original_fn, *args, **kwargs):
|
||||
if not PatchGradio._current_task:
|
||||
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||
return original_fn(*args, **kwargs)
|
||||
PatchGradio.__warn_on_server_config(
|
||||
kwargs.get("server_name"),
|
||||
kwargs.get("server_port"),
|
||||
kwargs.get("root_path")
|
||||
)
|
||||
if not running_remotely():
|
||||
return original_fn(*args, **kwargs)
|
||||
# noinspection PyProtectedMember
|
||||
PatchGradio._current_task._set_runtime_properties(
|
||||
{"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
|
||||
)
|
||||
PatchGradio._current_task.set_system_tags(["external_service"])
|
||||
PatchGradio.__warn_on_server_config(server_name, server_port)
|
||||
server_name = PatchGradio._default_gradio_address
|
||||
server_port = PatchGradio._default_gradio_port
|
||||
return original_fn(self, server_name, server_port, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _patched_init(original_fn, *args, **kwargs):
|
||||
if not PatchGradio._current_task:
|
||||
return original_fn(*args, **kwargs)
|
||||
PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
|
||||
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||
kwargs["root_path_in_servers"] = False
|
||||
kwargs["server_name"] = PatchGradio._default_gradio_address
|
||||
kwargs["server_port"] = PatchGradio._default_gradio_port
|
||||
return original_fn(*args, **kwargs)
|
||||
kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return original_fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
del kwargs["root_path"]
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def __warn_on_server_config(cls, server_name, server_port):
|
||||
if server_name is None and server_port is None:
|
||||
def __warn_on_server_config(cls, server_name, server_port, root_path):
|
||||
if (server_name is None or server_name == PatchGradio._default_gradio_address) and \
|
||||
(server_port is None and server_port == PatchGradio._default_gradio_port):
|
||||
return
|
||||
if (server_name, server_port, root_path) in cls.__server_config_warning:
|
||||
return
|
||||
cls.__server_config_warning.add((server_name, server_port, root_path))
|
||||
if server_name is not None and server_port is not None:
|
||||
server_config = "{}:{}".format(server_name, server_port)
|
||||
what_to_ignore = "name and port"
|
||||
@ -77,11 +91,14 @@ class PatchGradio:
|
||||
else:
|
||||
server_config = str(server_port)
|
||||
what_to_ignore = "port"
|
||||
if server_config in cls.__server_config_warning:
|
||||
return
|
||||
cls.__server_config_warning.add(server_config)
|
||||
getLogger().warning(
|
||||
"ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
|
||||
"ClearML only supports '{}:{}' as the Gradio server. Ignoring {} '{}' in remote execution".format(
|
||||
PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
|
||||
)
|
||||
)
|
||||
if root_path is not None:
|
||||
getLogger().warning(
|
||||
"ClearML will override root_path '{}' to '{}' in remote execution".format(
|
||||
root_path, PatchGradio._root_path_format.format(PatchGradio._current_task.id)
|
||||
)
|
||||
)
|
||||
|
@ -209,7 +209,10 @@ class PatchJsonArgParse(object):
|
||||
with change_to_path_dir(path):
|
||||
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
||||
if subcommand:
|
||||
parsed_cfg = {subcommand + PatchJsonArgParse._commands_sep + k: v for k, v in parsed_cfg.items()}
|
||||
parsed_cfg = {
|
||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
|
||||
for k, v in parsed_cfg.items()
|
||||
}
|
||||
return parsed_cfg
|
||||
|
||||
@staticmethod
|
||||
|
@@ -237,6 +237,12 @@ def parse_known_host(parsed_host):
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
elif parsed_host.port is None:
print('Web app hosted on standard port using ' + parsed_host.scheme + ' protocol.')
print('Assuming files and api ports are unchanged and use the same (' + parsed_host.scheme + ') protocol')
api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
files_host = parsed_host.scheme + "://" + parsed_host.netloc+ ':8081' + parsed_host.path
else:
print("Warning! Could not parse host name")
api_host = ''
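The branch added above covers hosts served on the standard port: the web UI keeps the given netloc, while the API and file servers are assumed on ports 8008 and 8081. A hedged illustration of the resulting URLs (hostname is an example):

# Illustration only: expected mapping for a host with no explicit port,
# matching the new branch above (hostname is an example).
from urllib.parse import urlparse

parsed = urlparse("https://clearml.example.com")
web_host = parsed.scheme + "://" + parsed.netloc + parsed.path
api_host = parsed.scheme + "://" + parsed.netloc + ":8008" + parsed.path
files_host = parsed.scheme + "://" + parsed.netloc + ":8081" + parsed.path
print(web_host, api_host, files_host)
# https://clearml.example.com https://clearml.example.com:8008 https://clearml.example.com:8081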
@ -7,7 +7,11 @@ from typing import Sequence
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
import clearml.backend_api.session
|
||||
from clearml.datasets import Dataset
|
||||
from clearml.version import __version__
|
||||
|
||||
clearml.backend_api.session.Session.add_client("clearml-data", __version__)
|
||||
|
||||
|
||||
def check_null_id(args):
|
||||
|
@ -1,7 +1,9 @@
|
||||
import sys
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser, RawTextHelpFormatter
|
||||
|
||||
import clearml.backend_api.session
|
||||
from clearml import Task
|
||||
from clearml.automation import (
|
||||
DiscreteParameterRange,
|
||||
UniformIntegerParameterRange,
|
||||
@ -11,8 +13,11 @@ from clearml.automation import (
|
||||
RandomSearch,
|
||||
GridSearch,
|
||||
)
|
||||
from clearml import Task
|
||||
from clearml.backend_interface.task.populate import CreateAndPopulate
|
||||
from clearml.version import __version__
|
||||
|
||||
clearml.backend_api.session.Session.add_client("clearml-param-search", __version__)
|
||||
|
||||
|
||||
try:
|
||||
from clearml.automation.optuna import OptimizerOptuna # noqa
|
||||
|
@ -3,9 +3,12 @@ from argparse import ArgumentParser
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
import clearml.backend_api.session
|
||||
from clearml import Task
|
||||
from clearml.version import __version__
|
||||
from clearml.backend_interface.task.populate import CreateAndPopulate
|
||||
from clearml.version import __version__
|
||||
|
||||
clearml.backend_api.session.Session.add_client("clearml-task", __version__)
|
||||
|
||||
|
||||
def setup_parser(parser):
|
||||
|
@ -122,7 +122,7 @@ class Dataset(object):
|
||||
__hyperparams_section = "Datasets"
|
||||
__datasets_runtime_prop = "datasets"
|
||||
__orig_datasets_runtime_prop_prefix = "orig_datasets"
|
||||
__preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
|
||||
__preview_media_max_file_size = deferred_config("dataset.preview.media.max_file_size", 5 * 1024 * 1024, transform=int)
|
||||
__preview_tabular_table_count = deferred_config("dataset.preview.tabular.table_count", 10, transform=int)
|
||||
__preview_tabular_row_count = deferred_config("dataset.preview.tabular.row_count", 10, transform=int)
|
||||
__preview_media_image_count = deferred_config("dataset.preview.media.image_count", 10, transform=int)
|
||||
@ -1877,6 +1877,7 @@ class Dataset(object):
|
||||
ids=None, # type: Optional[Sequence[str]]
|
||||
only_completed=True, # type: bool
|
||||
recursive_project_search=True, # type: bool
|
||||
include_archived=True, # type: bool
|
||||
):
|
||||
# type: (...) -> List[dict]
|
||||
"""
|
||||
@ -1890,9 +1891,16 @@ class Dataset(object):
|
||||
:param recursive_project_search: If True and the `dataset_project` argument is set,
|
||||
search inside subprojects as well.
|
||||
If False, don't search inside subprojects (except for the special `.datasets` subproject)
|
||||
:param include_archived: If True, include archived datasets as well.
|
||||
:return: List of dictionaries with dataset information
|
||||
Example: [{'name': name, 'project': project name, 'id': dataset_id, 'created': date_created},]
|
||||
"""
|
||||
# if include_archived is False, we need to add the system tag __$not:archived to filter out archived datasets
|
||||
if not include_archived:
|
||||
system_tags = ["__$all", cls.__tag, "__$not", "archived"]
|
||||
else:
|
||||
system_tags = [cls.__tag]
|
||||
|
||||
if dataset_project:
|
||||
if not recursive_project_search:
|
||||
dataset_projects = [
|
||||
@ -1903,12 +1911,13 @@ class Dataset(object):
|
||||
dataset_projects = [exact_match_regex(dataset_project), "^{}/.*".format(re.escape(dataset_project))]
|
||||
else:
|
||||
dataset_projects = None
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
datasets = Task._query_tasks(
|
||||
task_ids=ids or None,
|
||||
project_name=dataset_projects,
|
||||
task_name=partial_name,
|
||||
system_tags=[cls.__tag],
|
||||
system_tags=system_tags,
|
||||
type=[str(Task.TaskTypes.data_processing)],
|
||||
tags=tags or None,
|
||||
status=["stopped", "published", "completed", "closed"] if only_completed else None,
|
||||
@ -2278,14 +2287,15 @@ class Dataset(object):
|
||||
ds = Dataset.get(dependency)
|
||||
links.update(ds._dataset_link_entries)
|
||||
links.update(self._dataset_link_entries)
|
||||
def _download_link(link,target_path):
|
||||
|
||||
def _download_link(link, target_path):
|
||||
if os.path.exists(target_path):
|
||||
LoggerRoot.get_base_logger().info(
|
||||
"{} already exists. Skipping downloading {}".format(
|
||||
target_path, link
|
||||
)
|
||||
LoggerRoot.get_base_logger().info(
|
||||
"{} already exists. Skipping downloading {}".format(
|
||||
target_path, link
|
||||
)
|
||||
return
|
||||
)
|
||||
return
|
||||
ok = False
|
||||
error = None
|
||||
try:
|
||||
@ -2310,16 +2320,12 @@ class Dataset(object):
|
||||
if not max_workers:
|
||||
for relative_path, link in links.items():
|
||||
target_path = os.path.join(target_folder, relative_path)
|
||||
_download_link(link,target_path)
|
||||
_download_link(link, target_path)
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
for relative_path, link in links.items():
|
||||
target_path = os.path.join(target_folder, relative_path)
|
||||
pool.submit(_download_link,link,target_path)
|
||||
|
||||
|
||||
|
||||
|
||||
pool.submit(_download_link, link, target_path)
|
||||
|
||||
def _extract_dataset_archive(
|
||||
self,
|
||||
@ -2720,7 +2726,7 @@ class Dataset(object):
|
||||
dataset._task.mark_completed()
|
||||
|
||||
return id
|
||||
|
||||
|
||||
def _log_dataset_page(self):
|
||||
if bool(Session.check_min_api_server_version(self.__min_api_version)):
|
||||
self._task.get_logger().report_text(
|
||||
@ -2732,6 +2738,7 @@ class Dataset(object):
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _build_dependency_chunk_lookup(self):
|
||||
# type: () -> Dict[str, int]
|
||||
"""
|
||||
|
@@ -199,6 +199,10 @@ class Logger(object):
"""
For explicit reporting, plot a vector as (default stacked) histogram.

.. note::
This method is the same as :meth:`Logger.report_histogram`.
This method is deprecated, use :meth:`Logger.report_histogram` instead.

For example:

.. code-block:: py
@@ -224,6 +228,11 @@ class Logger(object):
See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
"""
warnings.warn(
":meth:`Logger.report_vector` is deprecated;"
"use :meth:`Logger.report_histogram` instead.",
DeprecationWarning
)
self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
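Since report_vector now emits a DeprecationWarning, new code would call report_histogram directly. A small hedged example (project, titles and values below are made up):

# Hedged example of the non-deprecated call shown above
# (project name, titles and values are made up).
from clearml import Task

task = Task.init(project_name="examples", task_name="histogram report")
logger = task.get_logger()
logger.report_histogram(
    title="class distribution",
    series="train",
    values=[120, 80, 45],
    iteration=0,
    xlabels=["cat", "dog", "bird"],
)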
@ -439,7 +448,16 @@ class Logger(object):
|
||||
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
|
||||
See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
|
||||
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
|
||||
|
||||
.. note::
|
||||
This method is the same as :meth:`Logger.report_scatter2d` with :param:`mode='lines'`.
|
||||
This method is deprecated, use :meth:`Logger.report_scatter2d` instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
":meth:`Logger.report_line_plot` is deprecated;"
|
||||
"use :meth:`Logger.report_scatter2d` instead, e.g., with :param:`mode='lines'`.",
|
||||
DeprecationWarning
|
||||
)
|
||||
|
||||
# noinspection PyArgumentList
|
||||
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
|
||||
@ -710,6 +728,7 @@ class Logger(object):
|
||||
|
||||
.. note::
|
||||
This method is the same as :meth:`Logger.report_confusion_matrix`.
|
||||
This method is deprecated, use :meth:`Logger.report_confusion_matrix` instead.
|
||||
|
||||
:param str title: The title (metric) of the plot.
|
||||
:param str series: The series name (variant) of the reported confusion matrix.
|
||||
@ -719,11 +738,16 @@ class Logger(object):
|
||||
:param str yaxis: The y-axis title. (Optional)
|
||||
:param list(str) xlabels: Labels for each column of the matrix. (Optional)
|
||||
:param list(str) ylabels: Labels for each row of the matrix. (Optional)
|
||||
:param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner
|
||||
:param bool yaxis_reversed: If False, 0,0 is in the bottom left corner. If True, 0,0 is in the top left corner
|
||||
:param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
|
||||
See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
|
||||
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
|
||||
"""
|
||||
warnings.warn(
|
||||
":meth:`Logger.report_matrix` is deprecated;"
|
||||
"use :meth:`Logger.report_confusion_matrix` instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
self._touch_title_series(title, series)
|
||||
return self.report_confusion_matrix(title, series, matrix, iteration or 0,
|
||||
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,
|
||||
|
@ -1783,7 +1783,7 @@ class StorageHelper(object):
|
||||
Supports both local and remote files (currently local files, network-mapped files, HTTP/S and Amazon S3)
|
||||
"""
|
||||
_temp_download_suffix = '.partially'
|
||||
_quotable_uri_schemes = set(_HttpDriver.schemes) | set([_GoogleCloudStorageDriver.scheme])
|
||||
_quotable_uri_schemes = set(_HttpDriver.schemes)
|
||||
|
||||
@classmethod
|
||||
def _get_logger(cls):
|
||||
@ -2414,7 +2414,7 @@ class StorageHelper(object):
|
||||
|
||||
if self.scheme in StorageHelper._quotable_uri_schemes: # TODO: fix-driver-schema
|
||||
# quote link
|
||||
result_dest_path = quote_url(result_dest_path)
|
||||
result_dest_path = quote_url(result_dest_path, StorageHelper._quotable_uri_schemes)
|
||||
|
||||
return result_dest_path
|
||||
|
||||
@ -2436,7 +2436,7 @@ class StorageHelper(object):
|
||||
|
||||
# quote link
|
||||
def callback(result):
|
||||
return a_cb(quote_url(result_path) if result else result)
|
||||
return a_cb(quote_url(result_path, StorageHelper._quotable_uri_schemes) if result else result)
|
||||
# replace callback with wrapper
|
||||
cb = callback
|
||||
|
||||
@ -2463,7 +2463,7 @@ class StorageHelper(object):
|
||||
retries=retries,
|
||||
return_canonized=return_canonized)
|
||||
if res:
|
||||
result_path = quote_url(result_path)
|
||||
result_path = quote_url(result_path, StorageHelper._quotable_uri_schemes)
|
||||
return result_path
|
||||
|
||||
def list(self, prefix=None, with_metadata=False):
|
||||
|
@@ -42,9 +42,9 @@ def get_config_object_matcher(**patterns):
return _matcher


def quote_url(url):
def quote_url(url, valid_schemes=("http", "https")):
parsed = urlparse(url)
if parsed.scheme not in ("http", "https", "gs"):
if parsed.scheme not in valid_schemes:
return url
parsed = parsed._replace(path=quote(parsed.path))
return urlunparse(parsed)
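With the signature change above, quoting is applied only to schemes passed in valid_schemes (the previous hard-coded set also included gs). A hedged illustration of the new default behaviour (URLs are examples; the imports are added here to make the snippet standalone):

# Illustration of the new default: only http/https paths are percent-quoted,
# other schemes are returned untouched (URLs are examples).
from urllib.parse import urlparse, urlunparse, quote

def quote_url(url, valid_schemes=("http", "https")):
    parsed = urlparse(url)
    if parsed.scheme not in valid_schemes:
        return url
    parsed = parsed._replace(path=quote(parsed.path))
    return urlunparse(parsed)

print(quote_url("https://files.example.com/artifacts/my model.bin"))  # space becomes %20
print(quote_url("s3://bucket/artifacts/my model.bin"))                # unchanged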
@@ -1009,10 +1009,16 @@ class Task(_Task):
If None is passed, returns all tasks within the project
:param list tags: Filter based on the requested list of tags (strings)
To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
To include All tags (instead of the default All behaviour) use "__$all" as the first string, example:
The default behaviour is to join all tags with a logical "OR" operator.
To join all tags with a logical "AND" operator instead, use "__$all" as the first string, for example:
["__$all", "best", "experiment", "ever"]
To combine All tags and exclude a list of tags use "__$not" before the excluded tags, example:
["__$all", "best", "experiment", "ever", "__$not", "internal", "test"]
To join all tags with AND, but exclude a tag use "__$not" before the excluded tag, for example:
["__$all", "best", "experiment", "ever", "__$not", "internal", "__$not", "test"]
The "OR" and "AND" operators apply to all tags that follow them until another operator is specified.
The NOT operator applies only to the immediately following tag.
For example, ["__$all", "a", "b", "c", "__$or", "d", "__$not", "e", "__$and", "__$or", "f", "g"]
means ("a" AND "b" AND "c" AND ("d" OR NOT "e") AND ("f" OR "g")).
See https://clear.ml/docs/latest/docs/clearml_sdk/task_sdk/#tag-filters for more information.
:param list additional_return_fields: Optional, if not provided return a list of Task IDs.
If provided return dict per Task with the additional requested fields.
Example: ``returned_fields=['last_updated', 'user', 'script.repository']`` will return a list of dict:
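The expanded docstring above spells out the tag-filter operator syntax. A hedged usage sketch of the same syntax (project and tag names are illustrative):

# Hedged sketch of the tag filter syntax documented above
# (project and tag names are illustrative).
from clearml import Task

# all listed tags joined with AND, excluding anything tagged "internal"
task_ids = Task.query_tasks(
    project_name="examples",
    tags=["__$all", "best", "experiment", "__$not", "internal"],
)
print(task_ids)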
@ -3449,8 +3455,8 @@ class Task(_Task):
|
||||
task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
|
||||
task_artifacts = task.data.execution.artifacts \
|
||||
if hasattr(task.data.execution, 'artifacts') else None
|
||||
if ((str(task._status) in (
|
||||
str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
|
||||
if ((task._status in (
|
||||
cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
|
||||
or task.output_models_id or (cls.archived_tag in task_tags)
|
||||
or (cls._development_tag not in task_tags)
|
||||
or task_artifacts):
|
||||
@ -3621,7 +3627,10 @@ class Task(_Task):
|
||||
# at least until we support multiple input models
|
||||
# notice that we do not check the task's input model because we allow task reuse and overwrite
|
||||
# add into comment that we are using this model
|
||||
comment = self.comment or ''
|
||||
|
||||
# refresh comment
|
||||
comment = self._reload_field("comment") or self.comment or ''
|
||||
|
||||
if not comment.endswith('\n'):
|
||||
comment += '\n'
|
||||
comment += 'Using model id: {}'.format(model.id)
|
||||
@ -3673,14 +3682,14 @@ class Task(_Task):
|
||||
# noinspection PyProtectedMember
|
||||
task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)
|
||||
|
||||
def _refresh_args_dict(task, config_dict):
|
||||
def _refresh_args_dict(task, config_proxy_dict):
|
||||
# reread from task including newly added keys
|
||||
# noinspection PyProtectedMember
|
||||
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
|
||||
a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_proxy_dict), prefix=name)
|
||||
# noinspection PyProtectedMember
|
||||
nested_dict = config_dict._to_dict()
|
||||
config_dict.clear()
|
||||
config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
|
||||
nested_dict = config_proxy_dict._to_dict()
|
||||
config_proxy_dict.clear()
|
||||
config_proxy_dict._do_update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
|
||||
|
||||
def _check_keys(dict_, warning_sent=False):
|
||||
if warning_sent:
|
||||
@ -4606,15 +4615,15 @@ class Task(_Task):
|
||||
return False
|
||||
|
||||
stopped_statuses = (
|
||||
str(tasks.TaskStatusEnum.stopped),
|
||||
str(tasks.TaskStatusEnum.published),
|
||||
str(tasks.TaskStatusEnum.publishing),
|
||||
str(tasks.TaskStatusEnum.closed),
|
||||
str(tasks.TaskStatusEnum.failed),
|
||||
str(tasks.TaskStatusEnum.completed),
|
||||
cls.TaskStatusEnum.stopped,
|
||||
cls.TaskStatusEnum.published,
|
||||
cls.TaskStatusEnum.publishing,
|
||||
cls.TaskStatusEnum.closed,
|
||||
cls.TaskStatusEnum.failed,
|
||||
cls.TaskStatusEnum.completed,
|
||||
)
|
||||
|
||||
if str(task.status) not in stopped_statuses:
|
||||
if task.status not in stopped_statuses:
|
||||
cls._send(
|
||||
cls._get_default_session(),
|
||||
tasks.StoppedRequest(
|
||||
|
@@ -61,6 +61,23 @@ def config_dict_to_text(config):
try:
# noinspection PyBroadException
try:
def raise_on_special_key(config_):
if not isinstance(config_, dict):
return
special_chars = "$}[]:=+#`^?!@*&."
for key in config_.keys():
if not isinstance(key, str):
continue
if any(key_char in special_chars for key_char in key):
raise ValueError(
"Configuration dictionary keys cannot contain any of the following characters: {}".format(
special_chars
)
)
for val in config_.values():
raise_on_special_key(val)
# will fall back to json+pyhocon
raise_on_special_key(config)
text = HOCONConverter.to_hocon(ConfigFactory.from_dict(hocon_quote_key(config)))
except Exception:
# fallback json+pyhocon
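raise_on_special_key above forces the HOCON writer to fall back to the json+pyhocon path whenever any (possibly nested) key contains a character such as '.' or '$'. A standalone hedged sketch of that guard (function name and dicts are illustrative):

# Standalone sketch of the guard above: keys containing HOCON-special
# characters would trigger the fallback path (dict contents are made up).
def has_special_key(config):
    special_chars = "$}[]:=+#`^?!@*&."
    if not isinstance(config, dict):
        return False
    for key, val in config.items():
        if isinstance(key, str) and any(c in special_chars for c in key):
            return True
        if has_special_key(val):
            return True
    return False

print(has_special_key({"model": {"lr": 0.1}}))    # False -> plain HOCON
print(has_special_key({"metrics.val/acc": 0.9}))  # True  -> json+pyhocon fallback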
@ -39,10 +39,14 @@ class ProxyDictPostWrite(dict):
|
||||
return a_dict
|
||||
|
||||
def update(self, E=None, **F):
|
||||
res = self._do_update(E, **F)
|
||||
self._set_callback()
|
||||
return res
|
||||
|
||||
def _do_update(self, E=None, **F):
|
||||
res = super(ProxyDictPostWrite, self).update(
|
||||
ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else
|
||||
ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
|
||||
self._set_callback()
|
||||
return res
|
||||
|
||||
|
||||
|
@@ -1 +1 @@
__version__ = '1.11.1rc1'
__version__ = '1.11.2rc0'
@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
from clearml import Task, Logger
|
||||
|
||||
|
||||
@ -51,7 +54,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
|
||||
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
|
||||
|
||||
def test(args, model, device, test_loader, epoch):
|
||||
@ -127,8 +130,9 @@ def main():
|
||||
task.execute_remotely(queue_name="default")
|
||||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
test(args, model, device, test_loader, epoch)
|
||||
if (args.save_model):
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||
|
||||
if args.save_model:
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -8,6 +8,6 @@ trainer:
|
||||
filename: best
|
||||
save_last: False
|
||||
save_top_k: 1
|
||||
monitor: loss
|
||||
monitor: train_loss
|
||||
mode: min
|
||||
max_epochs: 10
|
||||
|
@ -78,6 +78,7 @@ class ImageClassifier(LightningModule):
|
||||
x, y = batch
|
||||
logits = self.forward(x)
|
||||
loss = F.nll_loss(logits, y.long())
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
|
@ -8,5 +8,6 @@ trainer:
|
||||
filename: best
|
||||
save_last: False
|
||||
save_top_k: 1
|
||||
monitor: loss
|
||||
mode: min
|
||||
monitor: train_loss
|
||||
mode: min
|
||||
max_epochs: 3
|
||||
|
@ -156,7 +156,7 @@ def main(_):
|
||||
test(FLAGS, model, device, test_loader, epoch)
|
||||
|
||||
if FLAGS.save_model:
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,6 +1,7 @@
|
||||
# ClearML - Example of Pytorch mnist training integration
|
||||
#
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
@ -47,7 +48,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
|
||||
"train", "loss", iteration=(epoch * len(train_loader) + batch_idx), value=loss.item())
|
||||
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
|
||||
epoch, batch_idx * len(data), len(train_loader.dataset),
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
100. * batch_idx / len(train_loader), loss.item()))
|
||||
|
||||
|
||||
def test(args, model, device, test_loader, epoch):
|
||||
@ -128,7 +129,7 @@ def main():
|
||||
train(args, model, device, train_loader, optimizer, epoch)
|
||||
test(args, model, device, test_loader, epoch)
|
||||
|
||||
if (args.save_model):
|
||||
if args.save_model:
|
||||
torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
|
||||
|
||||
|
||||
|
@@ -7,3 +7,4 @@ tensorboardX
torch>=1.1.0
torchvision>=0.3.0
tqdm
protobuf>=4.21.1