Mirror of https://github.com/clearml/clearml (synced 2025-06-26 18:16:07 +00:00)

Commit 999c6543b8: Merge branch 'master' of https://github.com/allegroai/clearml

Changed paths:
- README.md
- clearml/automation
- clearml/backend_api
- clearml/backend_interface/task
- clearml/binding
- clearml/cli
- clearml/datasets
- clearml/logger.py
- clearml/storage
- clearml/task.py
- clearml/utilities
- clearml/version.py
- examples
@@ -76,7 +76,7 @@ Instrumenting these components is the **ClearML-server**, see [Self-Hosting](htt
 * Full source control info including non-committed local changes
 * Execution environment (including specific packages & versions)
 * Hyper-parameters
-  * ArgParser/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
+  * [`argparse`](https://docs.python.org/3/library/argparse.html)/[Click](https://github.com/pallets/click/)/[PythonFire](https://github.com/google/python-fire) for command line parameters with currently used values
   * Explicit parameters dictionary
   * Tensorflow Defines (absl-py)
   * [Hydra](https://github.com/facebookresearch/hydra) configuration and overrides
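Context for the README change above: once a ClearML Task is initialized, argparse-parsed values are captured automatically. A minimal sketch of that standard usage (not part of this commit; assumes a configured ClearML environment):

    from argparse import ArgumentParser
    from clearml import Task

    # Task.init hooks argparse, so parsed values are logged as hyperparameters
    task = Task.init(project_name="examples", task_name="argparse capture demo")

    parser = ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.01)
    args = parser.parse_args()  # captured values show up under the task's hyperparameters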
@@ -1615,7 +1615,7 @@ class PipelineController(object):
 
     def _has_stored_configuration(self):
         """
-        Return True if we are running remotely and we have stored configuration on the Task
+        Return True if we are running remotely, and we have stored configuration on the Task
         """
         if self._auto_connect_task and self._task and not self._task.running_locally() and self._task.is_main_task():
             stored_config = self._task.get_configuration_object(self._config_section)
@@ -1698,9 +1698,10 @@ class PipelineController(object):
             conformed_monitors = [
                 pair if isinstance(pair[0], (list, tuple)) else (pair, pair) for pair in monitors
             ]
-            # verify pair of pairs
+            # verify the pair of pairs
             if not all(isinstance(x[0][0], str) and isinstance(x[0][1], str) and
-                       isinstance(x[1][0], str) and isinstance(x[1][1], str) for x in conformed_monitors):
+                       isinstance(x[1][0], str) and isinstance(x[1][1], str)
+                       for x in conformed_monitors):
                 raise ValueError("{} should be a list of tuples, found: {}".format(monitor_type, monitors))
         else:
             # verify a list of tuples
@@ -1711,8 +1712,10 @@ class PipelineController(object):
             conformed_monitors = [
                 pair if isinstance(pair, (list, tuple)) else (pair, pair) for pair in monitors
             ]
-            # verify pair of pairs
-            if not all(isinstance(x[0], str) and isinstance(x[1], str) for x in conformed_monitors):
+            # verify the pair of pairs
+            if not all(isinstance(x[0], str) and
+                       isinstance(x[1], str)
+                       for x in conformed_monitors):
                 raise ValueError(
                     "{} should be a list of tuples, found: {}".format(monitor_type, monitors))
 
@@ -1782,7 +1785,7 @@ class PipelineController(object):
         # type: (...) -> bool
         """
         Create a Task from a function, including wrapping the function input arguments
-        into the hyper-parameter section as kwargs, and storing function results as named artifacts
+        into the hyperparameter section as kwargs, and storing function results as named artifacts
 
         Example:
 
@@ -1874,7 +1877,7 @@ class PipelineController(object):
         :param continue_on_fail: (default False). If True, failed step will not cause the pipeline to stop
             (or marked as failed). Notice, that steps that are connected (or indirectly connected)
             to the failed step will be skipped.
-        :param pre_execute_callback: Callback function, called when the step (Task) is created
+        :param pre_execute_callback: Callback function, called when the step (Task) is created,
             and before it is sent for execution. Allows a user to modify the Task before launch.
             Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
             `parameters` are the configuration arguments passed to the ClearmlJob.
@@ -2302,7 +2305,8 @@ class PipelineController(object):
 
         :param node_params: list of node parameters
         :param visited: list of nodes
-        :return: Table as List of List of strings (cell)
+        :return: Table as a List of a List of strings (cell)
         """
         task_link_template = self._task.get_output_log_web_page() \
             .replace('/{}/'.format(self._task.project), '/{project}/') \
@@ -2778,6 +2782,7 @@ class PipelineController(object):
         return the pipeline components target folder name/id
 
         :param return_project_id: if False (default), return target folder name. If True, return project id
 
         :return: project id/name (None if not valid)
         """
         if not self._target_project:
@@ -3110,7 +3115,7 @@ class PipelineDecorator(PipelineController):
         :param str target_project: If provided, all pipeline steps are cloned into the target project
         :param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
             to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
-            will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
+            will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
             was specifically defined with "continue_on_fail=True".
             If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
             and mark the pipeline as failed.
@@ -3540,7 +3545,7 @@ class PipelineDecorator(PipelineController):
 
         :param _func: wrapper function
         :param return_values: Provide a list of names for all the results.
-            Notice! If not provided no results will be stored as artifacts.
+            Notice! If not provided, no results will be stored as artifacts.
         :param name: Optional, set the name of the pipeline component task.
             If not provided, the wrapped function name is used as the pipeline component name
         :param cache: If True, before launching the new step,
@@ -3623,7 +3628,7 @@ class PipelineDecorator(PipelineController):
                 # allow up to 5 retries (total of 6 runs)
                 return retries < 5
 
-        :param pre_execute_callback: Callback function, called when the step (Task) is created
+        :param pre_execute_callback: Callback function, called when the step (Task) is created,
             and before it is sent for execution. Allows a user to modify the Task before launch.
             Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
             `parameters` are the configuration arguments passed to the ClearmlJob.
@@ -3644,7 +3649,7 @@ class PipelineDecorator(PipelineController):
                 pass
 
         :param post_execute_callback: Callback function, called when a step (Task) is completed
-            and other jobs are executed. Allows a user to modify the Task status after completion.
+            and other jobs are going to be executed. Allows a user to modify the Task status after completion.
 
             .. code-block:: py
 
@@ -3794,7 +3799,7 @@ class PipelineDecorator(PipelineController):
             if target_queue:
                 PipelineDecorator.set_default_execution_queue(target_queue)
             else:
-                # if we are are not running from a queue, we are probably in debug mode
+                # if we are not running from a queue, we are probably in debug mode
                 a_pipeline._clearml_job_class = LocalClearmlJob
                 a_pipeline._default_execution_queue = 'mock'
 
@@ -3957,7 +3962,7 @@ class PipelineDecorator(PipelineController):
         :param str target_project: If provided, all pipeline steps are cloned into the target project
         :param bool abort_on_failure: If False (default), failed pipeline steps will not cause the pipeline
             to stop immediately, instead any step that is not connected (or indirectly connected) to the failed step,
-            will still be executed. Nonetheless the pipeline itself will be marked failed, unless the failed step
+            will still be executed. Nonetheless, the pipeline itself will be marked failed, unless the failed step
             was specifically defined with "continue_on_fail=True".
             If True, any failed step will cause the pipeline to immediately abort, stop all running steps,
             and mark the pipeline as failed.
@@ -384,9 +384,10 @@ class BaseJob(object):
             section_overrides=None,
             params_override=None,
             configurations_override=None,
-            explicit_docker_image=None
+            explicit_docker_image=None,
+            account_for_artifacts_hashes=True
     ):
-        # type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str]) -> Optional[str]
+        # type: (Task, Optional[dict], Optional[dict], Optional[dict], Optional[str], bool) -> Optional[str]
         """
         Create Hash (str) representing the state of the Task
 
@@ -397,6 +398,8 @@ class BaseJob(object):
         :param configurations_override: dictionary of configuration override objects (tasks.ConfigurationItem)
         :param explicit_docker_image: The explicit docker image. Used to invalidate the hash when the docker image
             was explicitly changed
+        :param account_for_artifacts_hashes: Calculate the hash of the task by accounting for the hashes of the
+            artifacts in `kwargs_artifacts` (as opposed to the task ID/artifact name stored in this section)
 
         :return: str hash of the Task configuration
         """
@@ -416,7 +419,23 @@ class BaseJob(object):
         # we need to ignore `requirements` section because it might be changing from run to run
         script.pop("requirements", None)
 
-        hyper_params = task.get_parameters() if params_override is None else params_override
+        hyper_params = deepcopy(task.get_parameters() if params_override is None else params_override)
+        if account_for_artifacts_hashes:
+            hyper_params_to_change = {}
+            task_cache = {}
+            for key, value in hyper_params.items():
+                if key.startswith("kwargs_artifacts/"):
+                    # noinspection PyBroadException
+                    try:
+                        # key format is <task_id>.<artifact_name>
+                        task_id, artifact = value.split(".", 1)
+                        task_ = task_cache.setdefault(task_id, Task.get_task(task_id))
+                        # set the value of the hyper parameter to the hash of the artifact
+                        # because the task ID might differ, but the artifact might be the same
+                        hyper_params_to_change[key] = task_.artifacts[artifact].hash
+                    except Exception:
+                        pass
+            hyper_params.update(hyper_params_to_change)
         configs = task.get_configuration_objects() if configurations_override is None else configurations_override
         # currently we do not add the docker image to the hash (only args and setup script),
         # because default docker image will cause the step to change
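The new `account_for_artifacts_hashes` path replaces every `kwargs_artifacts/*` value of the form `<task_id>.<artifact_name>` with the artifact's content hash, so two runs that consume the same artifact from different source tasks produce the same hash input. A standalone sketch of that normalization, with a plain dictionary standing in for `Task.get_task(...).artifacts[...].hash`:

    from copy import deepcopy

    # hypothetical artifact registry: (task_id, artifact_name) -> content hash
    ARTIFACT_HASHES = {("task-A", "dataset"): "sha1:1234", ("task-B", "dataset"): "sha1:1234"}

    def normalize_params(hyper_params):
        params = deepcopy(hyper_params)
        for key, value in params.items():
            if key.startswith("kwargs_artifacts/"):
                task_id, artifact = value.split(".", 1)
                # swap the task-specific reference for the content hash
                params[key] = ARTIFACT_HASHES.get((task_id, artifact), value)
        return params

    a = normalize_params({"kwargs_artifacts/data": "task-A.dataset"})
    b = normalize_params({"kwargs_artifacts/data": "task-B.dataset"})
    assert a == b  # identical content, so the cached step can be reused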
@@ -585,6 +604,14 @@ class ClearmlJob(BaseJob):
         if allow_caching:
             # look for a cached copy of the Task
             # get parameters + task_overrides + as dict and hash it.
+            task_hash_legacy = self._create_task_hash(
+                base_temp_task,
+                section_overrides=sections,
+                params_override=task_params,
+                configurations_override=configuration_overrides or None,
+                explicit_docker_image=kwargs.get("explicit_docker_image"),
+                account_for_artifacts_hashes=False
+            )
             task_hash = self._create_task_hash(
                 base_temp_task,
                 section_overrides=sections,
@@ -592,7 +619,7 @@ class ClearmlJob(BaseJob):
                 configurations_override=configuration_overrides or None,
                 explicit_docker_image=kwargs.get("explicit_docker_image")
             )
-            task = self._get_cached_task(task_hash)
+            task = self._get_cached_task(task_hash_legacy) or self._get_cached_task(task_hash)
             # if we found a task, just use
             if task:
                 if disable_clone_task and self.task and self.task.status == self.task.TaskStatusEnum.created:
@@ -4590,7 +4590,7 @@ class UpdateResponse(Response):
 
 class ValidateDeleteRequest(Request):
     """
-    Validates that the project existis and can be deleted
+    Validates that the project exists and can be deleted
 
     :param project: Project ID
     :type project: str
@@ -4616,7 +4616,7 @@ class UpdateResponse(Response):
 
 class ValidateDeleteRequest(Request):
     """
-    Validates that the project existis and can be deleted
+    Validates that the project exists and can be deleted
 
     :param project: Project ID
     :type project: str
@@ -156,6 +156,19 @@ class Session(TokenManager):
 
         self._connect()
 
+    @classmethod
+    def add_client(cls, client, value, first=True):
+        # noinspection PyBroadException
+        try:
+            if not any(True for c in cls._client if c[0] == client):
+                if first:
+                    cls._client.insert(0, (client, value))
+                else:
+                    cls._client.append((client, value))
+                cls.client = ", ".join("{}-{}".format(*x) for x in cls._client)
+        except Exception:
+            pass
+
     def _connect(self):
         if self._offline_mode:
             return
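`add_client` keeps an ordered list of (client, version) pairs and rebuilds the `client` identification string sent with API requests; the CLI entry points touched later in this commit (clearml-data, clearml-param-search, clearml-task) register themselves through it at import time. The resulting string, mirroring the join above (versions illustrative):

    _client = [("clearml-task", "1.11.0"), ("clearml", "1.11.0")]
    client = ", ".join("{}-{}".format(*x) for x in _client)
    print(client)  # clearml-task-1.11.0, clearml-1.11.0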
@@ -219,8 +232,7 @@ class Session(TokenManager):
             if not api_version:
                 api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
             if token_dict.get('server_version'):
-                if not any(True for c in Session._client if c[0] == 'clearml-server'):
-                    Session._client.append(('clearml-server', token_dict.get('server_version'), ))
+                self.add_client('clearml-server', token_dict.get('server_version'))
 
         Session.max_api_version = Session.api_version = str(api_version)
         Session.feature_set = str(token_dict.get('feature_set', self.feature_set) or "basic")
@@ -644,6 +644,23 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         res = self.send(tasks.GetByIdRequest(task=self.id))
         return res.response.task
 
+    def _reload_field(self, field):
+        # type: (str) -> Any
+        """ Reload the task specific field, dot separated for nesting """
+        with self._edit_lock:
+            if self._offline_mode:
+                task_object = self._reload()
+            else:
+                res = self.send(tasks.GetAllRequest(id=[self.id], only_fields=[field], search_hidden=True))
+                task_object = res.response.tasks[0]
+
+            for p in field.split("."):
+                task_object = getattr(task_object, p, None)
+                if task_object is None:
+                    break
+
+            return task_object
+
     def reset(self, set_started_on_success=True, force=False):
         # type: (bool, bool) -> ()
         """
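`_reload_field` pulls a single field from the backend (via `only_fields`) instead of reloading the whole task, then walks the dotted path with `getattr`. The traversal on its own, sketched with plain objects:

    from types import SimpleNamespace

    task_object = SimpleNamespace(script=SimpleNamespace(repository="https://github.com/clearml/clearml"))

    for p in "script.repository".split("."):
        task_object = getattr(task_object, p, None)
        if task_object is None:
            break

    print(task_object)  # https://github.com/clearml/clearml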
@@ -752,7 +769,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def publish(self, ignore_errors=True):
         # type: (bool) -> ()
         """ The signal that this task will be published """
-        if str(self.status) not in (str(tasks.TaskStatusEnum.stopped), str(tasks.TaskStatusEnum.completed)):
+        if self.status not in (self.TaskStatusEnum.stopped, self.TaskStatusEnum.completed):
             raise ValueError("Can't publish, Task is not stopped")
         resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
         assert isinstance(resp.response, tasks.PublishResponse)
@@ -792,7 +809,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         try:
             res = self.send(tasks.GetByIdRequest(self.task_id))
             task = res.response.task
-            if task.status == Task.TaskStatusEnum.published:
+            if task.status == self.TaskStatusEnum.published:
                 if raise_on_error:
                     raise self.DeleteError("Cannot delete published task {}".format(self.task_id))
                 self.log.error("Cannot delete published task {}".format(self.task_id))
@@ -2408,7 +2425,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :return: list of tuples (status, status_message, task_id)
         """
         if cls._offline_mode:
-            return [(tasks.TaskStatusEnum.created, "offline", i) for i in ids]
+            return [(cls.TaskStatusEnum.created, "offline", i) for i in ids]
 
         # noinspection PyBroadException
         try:
@@ -2547,7 +2564,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # Since we are using forced update, make sure the task status is valid
         status = self._data.status if self._data and self._reload_skip_flag else self.data.status
         if not kwargs.pop("force", False) and \
-                status not in (tasks.TaskStatusEnum.created, tasks.TaskStatusEnum.in_progress):
+                status not in (self.TaskStatusEnum.created, self.TaskStatusEnum.in_progress):
             # the exception being name/comment that we can always change.
             if kwargs and all(
                 k in ("name", "project", "comment", "tags", "system_tags", "runtime") for k in kwargs.keys()
@@ -3000,7 +3017,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
     def _get_task_status(cls, task_id):
         # type: (str) -> (Optional[str], Optional[str])
         if cls._offline_mode:
-            return tasks.TaskStatusEnum.created, 'offline'
+            return cls.TaskStatusEnum.created, 'offline'
 
         # noinspection PyBroadException
         try:
|
@ -17,6 +17,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
_checkpoint_filename = {}
|
_checkpoint_filename = {}
|
||||||
__patched = None
|
__patched = None
|
||||||
__patched_lightning = None
|
__patched_lightning = None
|
||||||
|
__patched_pytorch_lightning = None
|
||||||
__patched_mmcv = None
|
__patched_mmcv = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -26,9 +27,11 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             return
         PatchPyTorchModelIO._patch_model_io()
         PatchPyTorchModelIO._patch_lightning_io()
+        PatchPyTorchModelIO._patch_pytorch_lightning_io()
         PatchPyTorchModelIO._patch_mmcv()
         PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
-        PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_lightning_io)
+        PostImportHookPatching.add_on_import('lightning', PatchPyTorchModelIO._patch_lightning_io)
+        PostImportHookPatching.add_on_import('pytorch_lightning', PatchPyTorchModelIO._patch_pytorch_lightning_io)
 
     @staticmethod
     def _patch_model_io():
@@ -110,11 +113,57 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         if PatchPyTorchModelIO.__patched_lightning:
             return
 
-        if 'pytorch_lightning' not in sys.modules:
+        if 'lightning' not in sys.modules:
             return
 
         PatchPyTorchModelIO.__patched_lightning = True
 
+        # noinspection PyBroadException
+        try:
+            import lightning  # noqa
+
+            lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
+                lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
+            )  # noqa
+
+            lightning.pytorch.trainer.Trainer.restore = _patched_call(
+                lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
+            )  # noqa
+        except ImportError:
+            pass
+        except Exception:
+            pass
+
+        # noinspection PyBroadException
+        try:
+            import lightning  # noqa
+
+            # noinspection PyUnresolvedReferences
+            lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
+                lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
+                PatchPyTorchModelIO._save,
+            )  # noqa
+
+            # noinspection PyUnresolvedReferences
+            lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
+                lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
+                PatchPyTorchModelIO._load_from_obj,
+            )  # noqa
+        except ImportError:
+            pass
+        except Exception:
+            pass
+
+    @staticmethod
+    def _patch_pytorch_lightning_io():
+        if PatchPyTorchModelIO.__patched_pytorch_lightning:
+            return
+
+        if 'pytorch_lightning' not in sys.modules:
+            return
+
+        PatchPyTorchModelIO.__patched_pytorch_lightning = True
+
         # noinspection PyBroadException
         try:
             import pytorch_lightning  # noqa
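Both patches lean on clearml's `_patched_call` helper, which replaces a method with a wrapper that passes the original callable to the hook. A simplified stand-in showing the pattern (the real helper also preserves metadata and handles bound/unbound cases):

    def patched_call_sketch(original_fn, hook):
        def wrapper(*args, **kwargs):
            # the hook decides when, and whether, to invoke the original
            return hook(original_fn, *args, **kwargs)
        return wrapper

    class Trainer:
        def save_checkpoint(self, path):
            return "saved to {}".format(path)

    def save_hook(original_fn, self, path):
        print("checkpoint about to be written:", path)  # e.g. report it to ClearML
        return original_fn(self, path)

    Trainer.save_checkpoint = patched_call_sketch(Trainer.save_checkpoint, save_hook)
    print(Trainer().save_checkpoint("model.ckpt"))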
@@ -3,6 +3,7 @@ from logging import getLogger
 from .frameworks import _patched_call  # noqa
 from .import_bind import PostImportHookPatching
 from ..utilities.networking import get_private_ip
+from ..config import running_remotely
 
 
 class PatchGradio:
@@ -11,7 +12,7 @@ class PatchGradio:
 
     _default_gradio_address = "0.0.0.0"
     _default_gradio_port = 7860
-    _root_path_format = "/service/{}"
+    _root_path_format = "/service/{}/"
     __server_config_warning = set()
 
     @classmethod
@@ -32,42 +33,55 @@ class PatchGradio:
         try:
             import gradio
 
-            gradio.networking.start_server = _patched_call(
-                gradio.networking.start_server, PatchGradio._patched_start_server
-            )
-            gradio.routes.App.__init__ = _patched_call(gradio.routes.App.__init__, PatchGradio._patched_init)
+            gradio.routes.App.get_blocks = _patched_call(gradio.routes.App.get_blocks, PatchGradio._patched_get_blocks)
+            gradio.blocks.Blocks.launch = _patched_call(gradio.blocks.Blocks.launch, PatchGradio._patched_launch)
         except Exception:
             pass
         cls.__patched = True
 
     @staticmethod
-    def _patched_start_server(original_fn, self, server_name=None, server_port=None, *args, **kwargs):
+    def _patched_get_blocks(original_fn, *args, **kwargs):
+        blocks = original_fn(*args, **kwargs)
+        if not PatchGradio._current_task or not running_remotely():
+            return blocks
+        blocks.config["root"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
+        blocks.root = blocks.config["root"]
+        return blocks
+
+    @staticmethod
+    def _patched_launch(original_fn, *args, **kwargs):
         if not PatchGradio._current_task:
-            return original_fn(self, server_name, server_port, *args, **kwargs)
+            return original_fn(*args, **kwargs)
+        PatchGradio.__warn_on_server_config(
+            kwargs.get("server_name"),
+            kwargs.get("server_port"),
+            kwargs.get("root_path")
+        )
+        if not running_remotely():
+            return original_fn(*args, **kwargs)
+        # noinspection PyProtectedMember
         PatchGradio._current_task._set_runtime_properties(
             {"_SERVICE": "EXTERNAL", "_ADDRESS": get_private_ip(), "_PORT": PatchGradio._default_gradio_port}
         )
         PatchGradio._current_task.set_system_tags(["external_service"])
-        PatchGradio.__warn_on_server_config(server_name, server_port)
-        server_name = PatchGradio._default_gradio_address
-        server_port = PatchGradio._default_gradio_port
-        return original_fn(self, server_name, server_port, *args, **kwargs)
 
-    @staticmethod
-    def _patched_init(original_fn, *args, **kwargs):
-        if not PatchGradio._current_task:
-            return original_fn(*args, **kwargs)
-        PatchGradio.__warn_on_server_config(kwargs.get("server_name"), kwargs.get("server_port"))
-        kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
-        kwargs["root_path_in_servers"] = False
         kwargs["server_name"] = PatchGradio._default_gradio_address
         kwargs["server_port"] = PatchGradio._default_gradio_port
+        kwargs["root_path"] = PatchGradio._root_path_format.format(PatchGradio._current_task.id)
+        # noinspection PyBroadException
+        try:
+            return original_fn(*args, **kwargs)
+        except Exception as e:
+            del kwargs["root_path"]
         return original_fn(*args, **kwargs)
 
     @classmethod
-    def __warn_on_server_config(cls, server_name, server_port):
-        if server_name is None and server_port is None:
+    def __warn_on_server_config(cls, server_name, server_port, root_path):
+        if (server_name is None or server_name == PatchGradio._default_gradio_address) and \
+                (server_port is None and server_port == PatchGradio._default_gradio_port):
             return
+        if (server_name, server_port, root_path) in cls.__server_config_warning:
+            return
+        cls.__server_config_warning.add((server_name, server_port, root_path))
         if server_name is not None and server_port is not None:
             server_config = "{}:{}".format(server_name, server_port)
             what_to_ignore = "name and port"
@@ -77,11 +91,14 @@ class PatchGradio:
         else:
             server_config = str(server_port)
             what_to_ignore = "port"
-        if server_config in cls.__server_config_warning:
-            return
-        cls.__server_config_warning.add(server_config)
         getLogger().warning(
-            "ClearML only supports '{}:{}'as the Gradio server. Ignoring {} '{}'".format(
+            "ClearML only supports '{}:{}' as the Gradio server. Ignoring {} '{}' in remote execution".format(
                 PatchGradio._default_gradio_address, PatchGradio._default_gradio_port, what_to_ignore, server_config
             )
         )
+        if root_path is not None:
+            getLogger().warning(
+                "ClearML will override root_path '{}' to '{}' in remote execution".format(
+                    root_path, PatchGradio._root_path_format.format(PatchGradio._current_task.id)
+                )
+            )
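Under remote execution the patched `launch` pins Gradio to a known address/port and injects a `root_path` so the app is served behind the ClearML-server proxy at `/service/<task-id>/`. The kwargs rewrite, reduced to a pure function using the class constants above:

    _default_gradio_address = "0.0.0.0"
    _default_gradio_port = 7860
    _root_path_format = "/service/{}/"

    def remote_launch_kwargs(task_id, **kwargs):
        # user-supplied name/port are ignored remotely (a warning is emitted instead)
        kwargs["server_name"] = _default_gradio_address
        kwargs["server_port"] = _default_gradio_port
        kwargs["root_path"] = _root_path_format.format(task_id)
        return kwargs

    print(remote_launch_kwargs("abc123", server_port=1234))
    # {'server_port': 7860, 'server_name': '0.0.0.0', 'root_path': '/service/abc123/'}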
@@ -209,7 +209,10 @@ class PatchJsonArgParse(object):
             with change_to_path_dir(path):
                 parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
             if subcommand:
-                parsed_cfg = {subcommand + PatchJsonArgParse._commands_sep + k: v for k, v in parsed_cfg.items()}
+                parsed_cfg = {
+                    ((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
+                    for k, v in parsed_cfg.items()
+                }
             return parsed_cfg
 
     @staticmethod
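The rewrite stops the subcommand prefix from being applied to the reserved `config` and `subcommand` keys. The same key rewrite on a plain dict (assuming the separator is "/", standing in for `PatchJsonArgParse._commands_sep`):

    _commands_sep = "/"  # assumption: mirrors PatchJsonArgParse._commands_sep
    subcommand = "fit"
    parsed_cfg = {"config": "cfg.yaml", "subcommand": "fit", "trainer.max_epochs": 3}

    parsed_cfg = {
        ((subcommand + _commands_sep) if k not in ["config", "subcommand"] else "") + k: v
        for k, v in parsed_cfg.items()
    }
    print(parsed_cfg)
    # {'config': 'cfg.yaml', 'subcommand': 'fit', 'fit/trainer.max_epochs': 3}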
@@ -237,6 +237,12 @@ def parse_known_host(parsed_host):
         api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008', 1) + parsed_host.path
         web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
         files_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8081', 1) + parsed_host.path
+    elif parsed_host.port is None:
+        print('Web app hosted on standard port using ' + parsed_host.scheme + ' protocol.')
+        print('Assuming files and api ports are unchanged and use the same (' + parsed_host.scheme + ') protocol')
+        api_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8008' + parsed_host.path
+        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
+        files_host = parsed_host.scheme + "://" + parsed_host.netloc + ':8081' + parsed_host.path
     else:
         print("Warning! Could not parse host name")
         api_host = ''
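The new branch covers a server reached on the scheme's default port (no explicit :8080): the web app URL is kept as given, while the api and files services are assumed to sit on 8008 and 8081. For example, with the Python 3 stdlib:

    from urllib.parse import urlparse

    parsed_host = urlparse("https://clearml.example.com")
    assert parsed_host.port is None

    api_host = parsed_host.scheme + "://" + parsed_host.netloc + ":8008" + parsed_host.path
    files_host = parsed_host.scheme + "://" + parsed_host.netloc + ":8081" + parsed_host.path
    print(api_host)    # https://clearml.example.com:8008
    print(files_host)  # https://clearml.example.com:8081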
@@ -7,7 +7,11 @@ from typing import Sequence
 
 from pathlib2 import Path
 
+import clearml.backend_api.session
 from clearml.datasets import Dataset
+from clearml.version import __version__
+
+clearml.backend_api.session.Session.add_client("clearml-data", __version__)
 
 
 def check_null_id(args):
@@ -1,7 +1,9 @@
-import sys
 import json
+import sys
 from argparse import ArgumentParser, RawTextHelpFormatter
 
+import clearml.backend_api.session
+from clearml import Task
 from clearml.automation import (
     DiscreteParameterRange,
     UniformIntegerParameterRange,
@@ -11,8 +13,11 @@ from clearml.automation import (
     RandomSearch,
     GridSearch,
 )
-from clearml import Task
 from clearml.backend_interface.task.populate import CreateAndPopulate
+from clearml.version import __version__
+
+clearml.backend_api.session.Session.add_client("clearml-param-search", __version__)
 
 
 try:
     from clearml.automation.optuna import OptimizerOptuna  # noqa
@@ -3,9 +3,12 @@ from argparse import ArgumentParser
 
 from pathlib2 import Path
 
+import clearml.backend_api.session
 from clearml import Task
-from clearml.version import __version__
 from clearml.backend_interface.task.populate import CreateAndPopulate
+from clearml.version import __version__
+
+clearml.backend_api.session.Session.add_client("clearml-task", __version__)
 
 
 def setup_parser(parser):
@@ -1877,6 +1877,7 @@ class Dataset(object):
             ids=None,  # type: Optional[Sequence[str]]
             only_completed=True,  # type: bool
             recursive_project_search=True,  # type: bool
+            include_archived=True,  # type: bool
     ):
         # type: (...) -> List[dict]
         """
@@ -1890,9 +1891,16 @@ class Dataset(object):
         :param recursive_project_search: If True and the `dataset_project` argument is set,
             search inside subprojects as well.
             If False, don't search inside subprojects (except for the special `.datasets` subproject)
+        :param include_archived: If True, include archived datasets as well.
         :return: List of dictionaries with dataset information
             Example: [{'name': name, 'project': project name, 'id': dataset_id, 'created': date_created},]
         """
+        # if include_archived is False, we need to add the system tag __$not:archived to filter out archived datasets
+        if not include_archived:
+            system_tags = ["__$all", cls.__tag, "__$not", "archived"]
+        else:
+            system_tags = [cls.__tag]
+
         if dataset_project:
             if not recursive_project_search:
                 dataset_projects = [
@@ -1903,12 +1911,13 @@ class Dataset(object):
                 dataset_projects = [exact_match_regex(dataset_project), "^{}/.*".format(re.escape(dataset_project))]
         else:
             dataset_projects = None
 
         # noinspection PyProtectedMember
         datasets = Task._query_tasks(
             task_ids=ids or None,
             project_name=dataset_projects,
             task_name=partial_name,
-            system_tags=[cls.__tag],
+            system_tags=system_tags,
             type=[str(Task.TaskTypes.data_processing)],
             tags=tags or None,
             status=["stopped", "published", "completed", "closed"] if only_completed else None,
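With `include_archived=False` the query adds the `__$all`/`__$not` system-tag operators so archived datasets are filtered out server-side. A hypothetical call, assuming this method is the public `Dataset.list_datasets`:

    from clearml import Dataset

    datasets = Dataset.list_datasets(dataset_project="examples", include_archived=False)
    for d in datasets:
        print(d["id"], d["name"], d["project"])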
@@ -2278,7 +2287,8 @@ class Dataset(object):
             ds = Dataset.get(dependency)
             links.update(ds._dataset_link_entries)
         links.update(self._dataset_link_entries)
-        def _download_link(link,target_path):
+
+        def _download_link(link, target_path):
             if os.path.exists(target_path):
                 LoggerRoot.get_base_logger().info(
                     "{} already exists. Skipping downloading {}".format(
@@ -2310,16 +2320,12 @@ class Dataset(object):
         if not max_workers:
             for relative_path, link in links.items():
                 target_path = os.path.join(target_folder, relative_path)
-                _download_link(link,target_path)
+                _download_link(link, target_path)
         else:
             with ThreadPoolExecutor(max_workers=max_workers) as pool:
                 for relative_path, link in links.items():
                     target_path = os.path.join(target_folder, relative_path)
-                    pool.submit(_download_link,link,target_path)
-
-
-
-
+                    pool.submit(_download_link, link, target_path)
 
     def _extract_dataset_archive(
         self,
@@ -2732,6 +2738,7 @@ class Dataset(object):
                 )
             )
         )
+
     def _build_dependency_chunk_lookup(self):
         # type: () -> Dict[str, int]
         """
@@ -199,6 +199,10 @@ class Logger(object):
         """
         For explicit reporting, plot a vector as (default stacked) histogram.
 
+        .. note::
+           This method is the same as :meth:`Logger.report_histogram`.
+           This method is deprecated, use :meth:`Logger.report_histogram` instead.
+
         For example:
 
         .. code-block:: py
@@ -224,6 +228,11 @@ class Logger(object):
             See full details on the supported configuration: https://plotly.com/javascript/reference/layout/
             example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
         """
+        warnings.warn(
+            ":meth:`Logger.report_vector` is deprecated;"
+            "use :meth:`Logger.report_histogram` instead.",
+            DeprecationWarning
+        )
         self._touch_title_series(title, series)
         return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
                                      xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
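`report_vector` now warns before delegating, and migrating is a drop-in rename; per the docstrings above, the replacement accepts the same arguments:

    from clearml import Task

    logger = Task.init(project_name="examples", task_name="plots").get_logger()

    # deprecated: logger.report_vector("histogram", "v1", values=[1, 2, 3], iteration=0)
    logger.report_histogram("histogram", "v1", values=[1, 2, 3], iteration=0)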
@@ -439,7 +448,16 @@ class Logger(object):
         :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
             See full details on the supported configuration: https://plotly.com/javascript/reference/scatter/
             example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
 
+        .. note::
+           This method is the same as :meth:`Logger.report_scatter2d` with :param:`mode='lines'`.
+           This method is deprecated, use :meth:`Logger.report_scatter2d` instead.
         """
+        warnings.warn(
+            ":meth:`Logger.report_line_plot` is deprecated;"
+            "use :meth:`Logger.report_scatter2d` instead, e.g., with :param:`mode='lines'`.",
+            DeprecationWarning
+        )
+
         # noinspection PyArgumentList
         series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
@@ -710,6 +728,7 @@ class Logger(object):
 
         .. note::
            This method is the same as :meth:`Logger.report_confusion_matrix`.
+           This method is deprecated, use :meth:`Logger.report_confusion_matrix` instead.
 
         :param str title: The title (metric) of the plot.
         :param str series: The series name (variant) of the reported confusion matrix.
@@ -719,11 +738,16 @@ class Logger(object):
         :param str yaxis: The y-axis title. (Optional)
         :param list(str) xlabels: Labels for each column of the matrix. (Optional)
         :param list(str) ylabels: Labels for each row of the matrix. (Optional)
-        :param bool yaxis_reversed: If False, 0,0 is at the bottom left corner. If True, 0,0 is at the top left corner
+        :param bool yaxis_reversed: If False, 0,0 is in the bottom left corner. If True, 0,0 is in the top left corner
         :param dict extra_layout: optional dictionary for layout configuration, passed directly to plotly
             See full details on the supported configuration: https://plotly.com/javascript/reference/heatmap/
             example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
         """
+        warnings.warn(
+            ":meth:`Logger.report_matrix` is deprecated;"
+            "use :meth:`Logger.report_confusion_matrix` instead.",
+            DeprecationWarning
+        )
         self._touch_title_series(title, series)
         return self.report_confusion_matrix(title, series, matrix, iteration or 0,
                                             xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,
@@ -1783,7 +1783,7 @@ class StorageHelper(object):
     Supports both local and remote files (currently local files, network-mapped files, HTTP/S and Amazon S3)
     """
     _temp_download_suffix = '.partially'
-    _quotable_uri_schemes = set(_HttpDriver.schemes) | set([_GoogleCloudStorageDriver.scheme])
+    _quotable_uri_schemes = set(_HttpDriver.schemes)
 
     @classmethod
     def _get_logger(cls):
@@ -2414,7 +2414,7 @@ class StorageHelper(object):
 
         if self.scheme in StorageHelper._quotable_uri_schemes:  # TODO: fix-driver-schema
             # quote link
-            result_dest_path = quote_url(result_dest_path)
+            result_dest_path = quote_url(result_dest_path, StorageHelper._quotable_uri_schemes)
 
         return result_dest_path
 
@@ -2436,7 +2436,7 @@ class StorageHelper(object):
 
             # quote link
             def callback(result):
-                return a_cb(quote_url(result_path) if result else result)
+                return a_cb(quote_url(result_path, StorageHelper._quotable_uri_schemes) if result else result)
             # replace callback with wrapper
             cb = callback
 
@@ -2463,7 +2463,7 @@ class StorageHelper(object):
             retries=retries,
             return_canonized=return_canonized)
         if res:
-            result_path = quote_url(result_path)
+            result_path = quote_url(result_path, StorageHelper._quotable_uri_schemes)
         return result_path
 
     def list(self, prefix=None, with_metadata=False):
@@ -42,9 +42,9 @@ def get_config_object_matcher(**patterns):
     return _matcher
 
 
-def quote_url(url):
+def quote_url(url, valid_schemes=("http", "https")):
     parsed = urlparse(url)
-    if parsed.scheme not in ("http", "https", "gs"):
+    if parsed.scheme not in valid_schemes:
         return url
     parsed = parsed._replace(path=quote(parsed.path))
     return urlunparse(parsed)
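`quote_url` percent-encodes only the path component, and now only for schemes the caller whitelists (GS is dropped from the default set, matching the `_quotable_uri_schemes` change above). Behavior sketch using the function body as shown:

    from urllib.parse import urlparse, urlunparse, quote

    def quote_url(url, valid_schemes=("http", "https")):
        parsed = urlparse(url)
        if parsed.scheme not in valid_schemes:
            return url
        parsed = parsed._replace(path=quote(parsed.path))
        return urlunparse(parsed)

    print(quote_url("https://files.example.com/artifacts/my model.bin"))
    # https://files.example.com/artifacts/my%20model.bin
    print(quote_url("s3://bucket/my model.bin"))  # scheme not whitelisted, returned unchanged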
@@ -1009,10 +1009,16 @@ class Task(_Task):
             If None is passed, returns all tasks within the project
         :param list tags: Filter based on the requested list of tags (strings)
             To exclude a tag add "-" prefix to the tag. Example: ["best", "-debug"]
-            To include All tags (instead of the default All behaviour) use "__$all" as the first string, example:
+            The default behaviour is to join all tags with a logical "OR" operator.
+            To join all tags with a logical "AND" operator instead, use "__$all" as the first string, for example:
             ["__$all", "best", "experiment", "ever"]
-            To combine All tags and exclude a list of tags use "__$not" before the excluded tags, example:
-            ["__$all", "best", "experiment", "ever", "__$not", "internal", "test"]
+            To join all tags with AND, but exclude a tag use "__$not" before the excluded tag, for example:
+            ["__$all", "best", "experiment", "ever", "__$not", "internal", "__$not", "test"]
+            The "OR" and "AND" operators apply to all tags that follow them until another operator is specified.
+            The NOT operator applies only to the immediately following tag.
+            For example, ["__$all", "a", "b", "c", "__$or", "d", "__$not", "e", "__$and", "__$or", "f", "g"]
+            means ("a" AND "b" AND "c" AND ("d" OR NOT "e") AND ("f" OR "g")).
+            See https://clear.ml/docs/latest/docs/clearml_sdk/task_sdk/#tag-filters for more information.
         :param list additional_return_fields: Optional, if not provided return a list of Task IDs.
             If provided return dict per Task with the additional requested fields.
             Example: ``returned_fields=['last_updated', 'user', 'script.repository']`` will return a list of dict:
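Combining the operators documented above, a query for tasks tagged both 'best' and 'experiment' but not 'internal' would look like this (per the docstring semantics; a list of IDs is returned when `additional_return_fields` is not given):

    from clearml import Task

    task_ids = Task.query_tasks(
        project_name="examples",
        tags=["__$all", "best", "experiment", "__$not", "internal"],
    )
    print(task_ids)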
@@ -3449,8 +3455,8 @@ class Task(_Task):
         task_tags = task.data.system_tags if hasattr(task.data, 'system_tags') else task.data.tags
         task_artifacts = task.data.execution.artifacts \
             if hasattr(task.data.execution, 'artifacts') else None
-        if ((str(task._status) in (
-                str(tasks.TaskStatusEnum.published), str(tasks.TaskStatusEnum.closed)))
+        if ((task._status in (
+                cls.TaskStatusEnum.published, cls.TaskStatusEnum.closed))
                 or task.output_models_id or (cls.archived_tag in task_tags)
                 or (cls._development_tag not in task_tags)
                 or task_artifacts):
@ -3621,7 +3627,10 @@ class Task(_Task):
     # at least until we support multiple input models
     # notice that we do not check the task's input model because we allow task reuse and overwrite
     # add into comment that we are using this model
-    comment = self.comment or ''
+    # refresh comment
+    comment = self._reload_field("comment") or self.comment or ''
+
     if not comment.endswith('\n'):
         comment += '\n'
     comment += 'Using model id: {}'.format(model.id)
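The new `_reload_field("comment")` call re-reads the comment from the backend just before appending, so a stale local copy does not clobber edits made elsewhere. A sketch of the read-modify-append pattern with a hypothetical task client (`reload_field` and `comment` here are stand-ins, not the ClearML API):

```python
def append_comment_line(task, line):
    # re-fetch the latest value right before modifying, instead of trusting
    # the locally cached copy (hypothetical client, for illustration only)
    comment = task.reload_field("comment") or task.comment or ""
    if not comment.endswith("\n"):
        comment += "\n"
    task.comment = comment + line  # write back the merged value
```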
@ -3673,14 +3682,14 @@ class Task(_Task):
         # noinspection PyProtectedMember
         task._arguments.copy_from_dict(flatten_dictionary(config_dict), prefix=name)

-    def _refresh_args_dict(task, config_dict):
+    def _refresh_args_dict(task, config_proxy_dict):
         # reread from task including newly added keys
         # noinspection PyProtectedMember
-        a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_dict), prefix=name)
+        a_flat_dict = task._arguments.copy_to_dict(flatten_dictionary(config_proxy_dict), prefix=name)
         # noinspection PyProtectedMember
-        nested_dict = config_dict._to_dict()
-        config_dict.clear()
-        config_dict.update(nested_from_flat_dictionary(nested_dict, a_flat_dict))
+        nested_dict = config_proxy_dict._to_dict()
+        config_proxy_dict.clear()
+        config_proxy_dict._do_update(nested_from_flat_dictionary(nested_dict, a_flat_dict))

     def _check_keys(dict_, warning_sent=False):
         if warning_sent:
@ -4606,15 +4615,15 @@ class Task(_Task):
             return False

         stopped_statuses = (
-            str(tasks.TaskStatusEnum.stopped),
-            str(tasks.TaskStatusEnum.published),
-            str(tasks.TaskStatusEnum.publishing),
-            str(tasks.TaskStatusEnum.closed),
-            str(tasks.TaskStatusEnum.failed),
-            str(tasks.TaskStatusEnum.completed),
+            cls.TaskStatusEnum.stopped,
+            cls.TaskStatusEnum.published,
+            cls.TaskStatusEnum.publishing,
+            cls.TaskStatusEnum.closed,
+            cls.TaskStatusEnum.failed,
+            cls.TaskStatusEnum.completed,
         )

-        if str(task.status) not in stopped_statuses:
+        if task.status not in stopped_statuses:
             cls._send(
                 cls._get_default_session(),
                 tasks.StoppedRequest(
@ -61,6 +61,23 @@ def config_dict_to_text(config):
     try:
         # noinspection PyBroadException
         try:
+            def raise_on_special_key(config_):
+                if not isinstance(config_, dict):
+                    return
+                special_chars = "$}[]:=+#`^?!@*&."
+                for key in config_.keys():
+                    if not isinstance(key, str):
+                        continue
+                    if any(key_char in special_chars for key_char in key):
+                        raise ValueError(
+                            "Configuration dictionary keys cannot contain any of the following characters: {}".format(
+                                special_chars
+                            )
+                        )
+                for val in config_.values():
+                    raise_on_special_key(val)
+            # will fall back to json+pyhocon
+            raise_on_special_key(config)
             text = HOCONConverter.to_hocon(ConfigFactory.from_dict(hocon_quote_key(config)))
         except Exception:
             # fallback json+pyhocon
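The net effect of the new guard: dictionaries whose keys contain HOCON-significant characters are rejected before HOCON serialization, so the surrounding `except` falls through to the json+pyhocon path instead of silently mangling those keys. A stand-alone repro of the check (same shape as the hunk above, extracted here so it can be run on its own):

```python
def raise_on_special_key(config_):
    # recursively reject string keys containing HOCON-significant characters
    if not isinstance(config_, dict):
        return
    special_chars = "$}[]:=+#`^?!@*&."
    for key in config_:
        if isinstance(key, str) and any(ch in special_chars for ch in key):
            raise ValueError(
                "Configuration dictionary keys cannot contain any of "
                "the following characters: {}".format(special_chars)
            )
    for val in config_.values():
        raise_on_special_key(val)

raise_on_special_key({"ok": 1, "nested": {"plain": 2}})  # passes silently
try:
    raise_on_special_key({"nested": {"dotted.key": 3}})
except ValueError as e:
    print(e)  # the "." in the nested key trips the guard
```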
@ -39,10 +39,14 @@ class ProxyDictPostWrite(dict):
         return a_dict

     def update(self, E=None, **F):
+        res = self._do_update(E, **F)
+        self._set_callback()
+        return res
+
+    def _do_update(self, E=None, **F):
         res = super(ProxyDictPostWrite, self).update(
             ProxyDictPostWrite(self._update_obj, self._set_callback, E) if E is not None else
             ProxyDictPostWrite(self._update_obj, self._set_callback, **F))
-        self._set_callback()
         return res
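This split gives the proxy dict a mutation path that skips the write-back notification: `update()` remains the public, callback-firing API, while internal refreshes (such as `_refresh_args_dict` above, which now calls `_do_update`) can merge values without re-triggering a task update. A generic sketch of the idea, using an illustrative class rather than the ClearML one:

```python
class CallbackDict(dict):
    """dict that notifies an owner after public mutation."""

    def __init__(self, callback, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._callback = callback

    def update(self, *args, **kwargs):
        # public path: merge, then notify the owner
        self._do_update(*args, **kwargs)
        self._callback(self)

    def _do_update(self, *args, **kwargs):
        # internal path: merge only, no notification (avoids callback loops)
        super().update(*args, **kwargs)


d = CallbackDict(lambda d_: print("changed:", dict(d_)), a=1)
d.update(b=2)       # prints: changed: {'a': 1, 'b': 2}
d._do_update(c=3)   # merges silently
```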
@ -1 +1 @@
-__version__ = '1.11.1rc1'
+__version__ = '1.11.2rc0'
@ -7,14 +7,17 @@ make sure code doesn't crash, and then move to a stronger machine for the entire
 """

 from __future__ import print_function

 import argparse
 import os
 from tempfile import gettempdir

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms

 from clearml import Task, Logger
@ -127,8 +130,9 @@ def main():
     task.execute_remotely(queue_name="default")
     train(args, model, device, train_loader, optimizer, epoch)
     test(args, model, device, test_loader, epoch)
-    if (args.save_model):
-        torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
+    if args.save_model:
+        torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_remote.pt"))
+

 if __name__ == '__main__':
@ -8,6 +8,6 @@ trainer:
     filename: best
     save_last: False
     save_top_k: 1
-    monitor: loss
+    monitor: train_loss
     mode: min
     max_epochs: 10
@ -78,6 +78,7 @@ class ImageClassifier(LightningModule):
         x, y = batch
         logits = self.forward(x)
         loss = F.nll_loss(logits, y.long())
+        self.log("train_loss", loss)
         return loss

     def test_step(self, batch, batch_idx):
@ -8,5 +8,6 @@ trainer:
     filename: best
     save_last: False
     save_top_k: 1
-    monitor: loss
+    monitor: train_loss
     mode: min
+    max_epochs: 3
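The Lightning changes above are two halves of one fix: a checkpoint callback can only monitor a metric the module actually logs, so `training_step` now logs "train_loss" and both trainer configs monitor that same name. A minimal sketch of the pairing, assuming a standard `pytorch_lightning.callbacks.ModelCheckpoint`:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# "train_loss" must match the name passed to self.log(...) in training_step
checkpoint = ModelCheckpoint(
    filename="best",
    save_last=False,
    save_top_k=1,
    monitor="train_loss",
    mode="min",
)
```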
@ -156,7 +156,7 @@ def main(_):
         test(FLAGS, model, device, test_loader, epoch)

     if FLAGS.save_model:
-        torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))
+        torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn_abseil.pt"))


 if __name__ == "__main__":
@ -1,6 +1,7 @@
 # ClearML - Example of Pytorch mnist training integration
 #
 from __future__ import print_function
+
 import argparse
 import os
 from tempfile import gettempdir
@ -128,7 +129,7 @@ def main():
         train(args, model, device, train_loader, optimizer, epoch)
         test(args, model, device, test_loader, epoch)

-    if (args.save_model):
+    if args.save_model:
         torch.save(model.state_dict(), os.path.join(gettempdir(), "mnist_cnn.pt"))

@ -7,3 +7,4 @@ tensorboardX
 torch>=1.1.0
 torchvision>=0.3.0
 tqdm
+protobuf>=4.21.1