mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Merge branch 'master' of https://github.com/allegroai/clearml
This commit is contained in:
@@ -69,6 +69,7 @@ class PipelineController(object):
|
||||
_final_failure = {} # Node.name: bool
|
||||
_task_template_header = CreateFromFunction.default_task_template_header
|
||||
_default_pipeline_version = "1.0.0"
|
||||
_project_section = ".pipelines"
|
||||
|
||||
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
|
||||
|
||||
@@ -176,7 +177,8 @@ class PipelineController(object):
|
||||
always_create_from_code=True, # type: bool
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@@ -266,6 +268,9 @@ class PipelineController(object):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution when creating
|
||||
the steps from a functions, otherwise all global imports will be automatically imported in a safe manner at
|
||||
the beginning of each step’s execution. Default is False
|
||||
"""
|
||||
if auto_version_bump is not None:
|
||||
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
||||
@@ -302,10 +307,11 @@ class PipelineController(object):
|
||||
self._last_progress_update_time = 0
|
||||
self._artifact_serialization_function = artifact_serialization_function
|
||||
self._artifact_deserialization_function = artifact_deserialization_function
|
||||
self._skip_global_imports = skip_global_imports
|
||||
if not self._task:
|
||||
task_name = name or project or '{}'.format(datetime.now())
|
||||
if self._pipeline_as_sub_project:
|
||||
parent_project = "{}.pipelines".format(project+'/' if project else '')
|
||||
parent_project = (project + "/" if project else "") + self._project_section
|
||||
project_name = "{}/{}".format(parent_project, task_name)
|
||||
else:
|
||||
parent_project = None
|
||||
@@ -1422,7 +1428,7 @@ class PipelineController(object):
|
||||
mutually_exclusive(pipeline_id=pipeline_id, pipeline_project=pipeline_project, _require_at_least_one=False)
|
||||
mutually_exclusive(pipeline_id=pipeline_id, pipeline_name=pipeline_name, _require_at_least_one=False)
|
||||
if not pipeline_id:
|
||||
pipeline_project_hidden = "{}/.pipelines/{}".format(pipeline_project, pipeline_name)
|
||||
pipeline_project_hidden = "{}/{}/{}".format(pipeline_project, cls._project_section, pipeline_name)
|
||||
name_with_runtime_number_regex = r"^{}( #[0-9]+)*$".format(re.escape(pipeline_name))
|
||||
pipelines = Task._query_tasks(
|
||||
pipeline_project=[pipeline_project_hidden],
|
||||
@@ -1520,7 +1526,8 @@ class PipelineController(object):
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||
skip_global_imports=self._skip_global_imports
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@@ -2725,7 +2732,7 @@ class PipelineController(object):
|
||||
self._final_failure[node.name] = True
|
||||
|
||||
completed_jobs.append(j)
|
||||
node.executed = node.job.task_id() if not node_failed else False
|
||||
node.executed = node.job.task_id() if not (node_failed or node.job.is_aborted()) else False
|
||||
if j in launched_nodes:
|
||||
launched_nodes.remove(j)
|
||||
# check if we need to stop all running steps
|
||||
@@ -3315,7 +3322,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@@ -3398,6 +3406,9 @@ class PipelineDecorator(PipelineController):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||
Default is False
|
||||
"""
|
||||
super(PipelineDecorator, self).__init__(
|
||||
name=name,
|
||||
@@ -3419,7 +3430,8 @@ class PipelineDecorator(PipelineController):
|
||||
always_create_from_code=False,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
|
||||
# if we are in eager execution, make sure parent class knows it
|
||||
@@ -3482,7 +3494,7 @@ class PipelineDecorator(PipelineController):
|
||||
else:
|
||||
self._final_failure[node.name] = True
|
||||
completed_jobs.append(j)
|
||||
node.executed = node.job.task_id() if not node_failed else False
|
||||
node.executed = node.job.task_id() if not (node_failed or node.job.is_aborted()) else False
|
||||
if j in launched_nodes:
|
||||
launched_nodes.remove(j)
|
||||
# check if we need to stop all running steps
|
||||
@@ -3685,7 +3697,8 @@ class PipelineDecorator(PipelineController):
|
||||
task_template_header=self._task_template_header,
|
||||
_sanitize_function=sanitize,
|
||||
artifact_serialization_function=self._artifact_serialization_function,
|
||||
artifact_deserialization_function=self._artifact_deserialization_function
|
||||
artifact_deserialization_function=self._artifact_deserialization_function,
|
||||
skip_global_imports=self._skip_global_imports
|
||||
)
|
||||
return task_definition
|
||||
|
||||
@@ -3906,10 +3919,12 @@ class PipelineDecorator(PipelineController):
|
||||
:return: function wrapper
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
_name = name or str(func.__name__)
|
||||
# noinspection PyProtectedMember
|
||||
unwrapped_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||
_name = name or str(unwrapped_func.__name__)
|
||||
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
|
||||
|
||||
inspect_func = inspect.getfullargspec(func)
|
||||
inspect_func = inspect.getfullargspec(unwrapped_func)
|
||||
# add default argument values
|
||||
if inspect_func.args:
|
||||
default_values = list(inspect_func.defaults or [])
|
||||
@@ -4127,7 +4142,7 @@ class PipelineDecorator(PipelineController):
|
||||
return task.artifacts[return_name].get(
|
||||
deserialization_function=cls._singleton._artifact_deserialization_function
|
||||
)
|
||||
return task.get_parameters(cast=True)[CreateFromFunction.return_section + "/" + return_name]
|
||||
return task.get_parameters(cast=True).get(CreateFromFunction.return_section + "/" + return_name)
|
||||
|
||||
return_w = [LazyEvalWrapper(
|
||||
callback=functools.partial(result_wrapper, n),
|
||||
@@ -4172,7 +4187,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> Callable
|
||||
"""
|
||||
@@ -4286,6 +4302,9 @@ class PipelineDecorator(PipelineController):
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
|
||||
global imports will be automatically imported in a safe manner at the beginning of each step’s execution.
|
||||
Default is False
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
|
||||
@@ -4332,7 +4351,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
ret_val = func(**pipeline_kwargs)
|
||||
LazyEvalWrapper.trigger_all_remote_references()
|
||||
@@ -4384,7 +4404,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
output_uri=output_uri,
|
||||
skip_global_imports=skip_global_imports
|
||||
)
|
||||
|
||||
a_pipeline._args_map = args_map or {}
|
||||
@@ -4551,6 +4572,13 @@ class PipelineDecorator(PipelineController):
|
||||
_node.parents = (_node.parents or []) + [
|
||||
x for x in cls._evaluated_return_values.get(tid, []) if x in leaves
|
||||
]
|
||||
|
||||
if not cls._singleton._abort_running_steps_on_failure:
|
||||
for parent in _node.parents:
|
||||
if cls._singleton._nodes[parent].status in ["failed", "aborted", "skipped"]:
|
||||
_node.skip_job = True
|
||||
return
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if v is None or isinstance(v, (float, int, bool, six.string_types)):
|
||||
_node.parameters["{}/{}".format(CreateFromFunction.kwargs_section, k)] = v
|
||||
|
||||
@@ -85,7 +85,7 @@ class _TrainsBandsterWorker(Worker):
|
||||
self.optimizer.budget.iterations.update(self._current_job.task_id(), iteration_value[0])
|
||||
|
||||
# check if we exceeded this job budget
|
||||
if iteration_value[0] >= self.budget_iteration_scale * budget:
|
||||
if iteration_value[0][0] >= self.budget_iteration_scale * budget:
|
||||
self._current_job.abort()
|
||||
break
|
||||
|
||||
@@ -95,7 +95,7 @@ class _TrainsBandsterWorker(Worker):
|
||||
# noinspection PyProtectedMember
|
||||
self.optimizer.budget.jobs.update(
|
||||
self._current_job.task_id(),
|
||||
float(iteration_value[0]) / self.optimizer._max_iteration_per_job)
|
||||
float(iteration_value[0][0]) / self.optimizer._max_iteration_per_job)
|
||||
|
||||
result = {
|
||||
# this is the a mandatory field to run hyperband
|
||||
@@ -104,7 +104,7 @@ class _TrainsBandsterWorker(Worker):
|
||||
# can be used for any user-defined information - also mandatory
|
||||
'info': self._current_job.task_id()
|
||||
}
|
||||
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
|
||||
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value[0]))
|
||||
# noinspection PyProtectedMember
|
||||
self.optimizer._current_jobs.remove(self._current_job)
|
||||
return result
|
||||
@@ -299,7 +299,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
|
||||
sleep_interval=int(self.pool_period_minutes * 60),
|
||||
budget_iteration_scale=budget_iteration_scale,
|
||||
base_task_id=self._base_task_id,
|
||||
objective=self._objective_metric,
|
||||
objective=self._objective_metric.objectives[0],
|
||||
queue_name=self._execution_queue,
|
||||
nameserver='127.0.0.1', nameserver_port=self._nameserver_port, run_id=fake_run_id, id=i)
|
||||
w.run(background=True)
|
||||
|
||||
@@ -7,7 +7,8 @@ from itertools import product
|
||||
from logging import getLogger
|
||||
from threading import Thread, Event
|
||||
from time import time
|
||||
from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable
|
||||
from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .job import ClearmlJob, LocalClearmlJob
|
||||
from .parameters import Parameter
|
||||
@@ -19,7 +20,33 @@ from ..task import Task
|
||||
logger = getLogger('clearml.automation.optimization')
|
||||
|
||||
|
||||
class Objective(object):
|
||||
class _ObjectiveInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_objective(self, task_id):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_current_raw_objective(self, task):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_objective_sign(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_objective_metric(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_normalized_objective(self, task_id):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_top_tasks(self, top_k, optimizer_task_id=None, task_filter=None):
|
||||
pass
|
||||
|
||||
|
||||
class Objective(_ObjectiveInterface):
|
||||
"""
|
||||
Optimization ``Objective`` class to maximize / minimize over all experiments. This class will sample a specific
|
||||
scalar from all experiments, and maximize / minimize over single scalar (i.e., title and series combination).
|
||||
@@ -53,7 +80,7 @@ class Objective(object):
|
||||
self.series = series
|
||||
assert order in ('min', 'max',)
|
||||
# normalize value so we always look for the highest objective value
|
||||
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else +1
|
||||
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else 1
|
||||
self._metric = None
|
||||
self.extremum = extremum
|
||||
|
||||
@@ -243,7 +270,7 @@ class Budget(object):
|
||||
# type: () -> (Optional[float])
|
||||
if self.limit is None or not self.current:
|
||||
return None
|
||||
return sum(self.current.values())/float(self.limit)
|
||||
return sum(self.current.values()) / float(self.limit)
|
||||
|
||||
def __init__(self, jobs_limit, iterations_limit, compute_time_limit):
|
||||
# type: (Optional[int], Optional[int], Optional[float]) -> ()
|
||||
@@ -467,7 +494,7 @@ class SearchStrategy(object):
|
||||
|
||||
if self.max_iteration_per_job:
|
||||
iterations = self._get_job_iterations(job)
|
||||
if iterations > 0:
|
||||
if iterations and iterations > 0:
|
||||
self.budget.iterations.update(job.task_id(), iterations)
|
||||
if iterations > self.max_iteration_per_job:
|
||||
abort_job = True
|
||||
@@ -512,12 +539,8 @@ class SearchStrategy(object):
|
||||
|
||||
:return: A list of Task objects, ordered by performance, where index 0 is the best performing Task.
|
||||
"""
|
||||
# noinspection PyProtectedMember
|
||||
top_tasks = self._get_child_tasks(
|
||||
parent_task_id=self._job_parent_id or self._base_task_id,
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field(),
|
||||
additional_filters={'page_size': int(top_k), 'page': 0})
|
||||
return top_tasks
|
||||
return self._objective_metric.get_top_tasks(top_k=top_k,
|
||||
optimizer_task_id=self._job_parent_id or self._base_task_id)
|
||||
|
||||
def get_top_experiments_id_metrics_pair(self, top_k, all_metrics=False, only_completed=False):
|
||||
# type: (int, bool, bool) -> Sequence[(str, dict)]
|
||||
@@ -582,15 +605,17 @@ class SearchStrategy(object):
|
||||
# noinspection PyProtectedMember
|
||||
top_tasks_ids_metric = self._get_child_tasks_ids(
|
||||
parent_task_id=self._job_parent_id or self._base_task_id,
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field(),
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field()[0],
|
||||
additional_filters=additional_filters,
|
||||
additional_fields=['last_metrics']
|
||||
)
|
||||
|
||||
title, series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
|
||||
title_series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
|
||||
titles = [ts[0] for ts in title_series]
|
||||
series = [ts[1] for ts in title_series]
|
||||
return [(i, {'{}/{}'.format(v['metric'], v['variant']): v
|
||||
for variant in metric.values() for v in variant.values()
|
||||
if all_metrics or v['metric'] == title and v['variant'] == series}
|
||||
if all_metrics or (v['metric'] in titles and v['variant'] in series)}
|
||||
) for i, metric in top_tasks_ids_metric]
|
||||
|
||||
def get_top_experiments_details(
|
||||
@@ -669,15 +694,28 @@ class SearchStrategy(object):
|
||||
# noinspection PyProtectedMember
|
||||
top_tasks_ids_metric_params = self._get_child_tasks_ids(
|
||||
parent_task_id=self._job_parent_id or self._base_task_id,
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field(),
|
||||
order_by=self._objective_metric._get_last_metrics_encode_field()[
|
||||
0] if self._objective_metric.len == 1 else None,
|
||||
additional_filters=additional_filters,
|
||||
additional_fields=['last_metrics', 'hyperparams']
|
||||
)
|
||||
if self._objective_metric.len != 1:
|
||||
top_tasks_ids_metric_params_dict = {}
|
||||
for task in top_tasks_ids_metric_params:
|
||||
objective = self._objective_metric.get_objective(task[0])
|
||||
if objective is None or any(o is None for o in objective):
|
||||
continue
|
||||
top_tasks_ids_metric_params_dict[task[0]] = (objective, task)
|
||||
# noinspection PyProtectedMember
|
||||
sorted_ids = self._objective_metric._sort_jobs_by_domination(top_tasks_ids_metric_params_dict)
|
||||
top_tasks_ids_metric_params = [top_tasks_ids_metric_params_dict[s][1] for s in sorted_ids]
|
||||
|
||||
# get hp_parameters:
|
||||
hp_params = set(p.name for p in self._hyper_parameters)
|
||||
|
||||
title, series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
|
||||
title_series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
|
||||
titles = [ts[0] for ts in title_series]
|
||||
series = [ts[1] for ts in title_series]
|
||||
return [
|
||||
{
|
||||
'task_id': tid,
|
||||
@@ -688,19 +726,20 @@ class SearchStrategy(object):
|
||||
},
|
||||
'metrics': {
|
||||
'{}/{}'.format(v['metric'], v['variant']): v for variant in metric.values()
|
||||
for v in variant.values() if all_metrics or v['metric'] == title and v['variant'] == series
|
||||
for v in variant.values() if all_metrics or v['metric'] in titles and v['variant'] in series
|
||||
}
|
||||
} for tid, metric, param_sections in top_tasks_ids_metric_params
|
||||
]
|
||||
|
||||
def get_objective_metric(self):
|
||||
# type: () -> (str, str)
|
||||
# type: () -> Union[Tuple[str, str], List[Tuple[str, str]]]
|
||||
"""
|
||||
Return the metric title, series pair of the objective.
|
||||
|
||||
:return: (title, series)
|
||||
"""
|
||||
return self._objective_metric.get_objective_metric()
|
||||
objective = self._objective_metric.get_objective_metric()
|
||||
return objective[0] if self._objective_metric.len == 1 else objective
|
||||
|
||||
def helper_create_job(
|
||||
self,
|
||||
@@ -823,7 +862,9 @@ class SearchStrategy(object):
|
||||
def _get_job_iterations(self, job):
|
||||
# type: (Union[ClearmlJob, Task]) -> int
|
||||
iteration_value = self._objective_metric.get_current_raw_objective(job)
|
||||
return iteration_value[0] if iteration_value else -1
|
||||
if iteration_value is not None and any(iv is not None and iv[0] is not None for iv in iteration_value):
|
||||
return max(iv[0] for iv in iteration_value if iv is not None)
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def _get_child_tasks_ids(
|
||||
@@ -887,7 +928,7 @@ class SearchStrategy(object):
|
||||
task_objects = Task._query_tasks(**task_filter)
|
||||
if not additional_fields:
|
||||
return [t.id for t in task_objects]
|
||||
return [[t.id]+[getattr(t, f, None) for f in additional_fields] for t in task_objects]
|
||||
return [[t.id] + [getattr(t, f, None) for f in additional_fields] for t in task_objects]
|
||||
|
||||
@classmethod
|
||||
def _get_child_tasks(
|
||||
@@ -927,6 +968,146 @@ class SearchStrategy(object):
|
||||
]
|
||||
|
||||
|
||||
class MultiObjective(_ObjectiveInterface):
|
||||
def __init__(self, title, series, order, extremum):
|
||||
self.title = title
|
||||
self.series = series
|
||||
self.order = order
|
||||
self.extremum = extremum
|
||||
self.objectives = []
|
||||
for title_, series_, order_, extremum_ in zip(title, series, order, extremum):
|
||||
self.objectives.append(Objective(title=title_, series=series_, order=order_, extremum=extremum_))
|
||||
self.len = len(self.objectives)
|
||||
|
||||
def get_objective(self, task_id):
|
||||
# type: (Union[str, Task, ClearmlJob]) -> Optional[List[float]]
|
||||
"""
|
||||
Return a specific task scalar values based on the objective settings (title/series).
|
||||
|
||||
:param str task_id: The Task ID to retrieve scalar from (or ``ClearMLJob`` object).
|
||||
|
||||
:return: The scalar values.
|
||||
"""
|
||||
objective = [o.get_objective(task_id) for o in self.objectives]
|
||||
if any(o is None for o in objective):
|
||||
return None
|
||||
return objective
|
||||
|
||||
def get_current_raw_objective(self, task):
|
||||
# type: (Union[ClearmlJob, Task]) -> Optional[List[Tuple[int, float]]]
|
||||
"""
|
||||
Return the current raw value (without sign normalization) of each objective.
|
||||
|
||||
:param str task: The Task or Job to retrieve scalar from (or ``ClearmlJob`` object).
|
||||
:return: List[Optional[Tuple(iteration, value)]]. None if the metric does not exist.
|
||||
"""
|
||||
objective = [o.get_current_raw_objective(task) for o in self.objectives]
|
||||
if any(o is None for o in objective):
|
||||
return None
|
||||
return objective
|
||||
|
||||
def get_objective_sign(self):
|
||||
# type: () -> List[float]
|
||||
"""
|
||||
Return the sign of the objectives.
|
||||
|
||||
- ``+1`` - If maximizing
|
||||
- ``-1`` - If minimizing
|
||||
|
||||
:return: Objective function signs.
|
||||
"""
|
||||
return [o.get_objective_sign() for o in self.objectives]
|
||||
|
||||
def get_normalized_objective(self, task_id):
|
||||
# type: (Union[str, Task, ClearmlJob]) -> Optional[List[float]]
|
||||
"""
|
||||
Return a normalized task scalar values based on the objective settings (title/series).
|
||||
I.e. objective is always to maximize the returned value
|
||||
|
||||
:param str task_id: The Task ID to retrieve scalars from.
|
||||
|
||||
:return: Normalized scalar values.
|
||||
"""
|
||||
objective = [o.get_normalized_objective(task_id) for o in self.objectives]
|
||||
if any(o is None for o in objective):
|
||||
return None
|
||||
return objective
|
||||
|
||||
def get_objective_metric(self):
|
||||
# type: () -> List[(str, str)]
|
||||
"""
|
||||
Return the metric title, series pairs of the objectives.
|
||||
|
||||
:return: List[(title, series)]
|
||||
"""
|
||||
return [o.get_objective_metric() for o in self.objectives]
|
||||
|
||||
def get_top_tasks(self, top_k, optimizer_task_id=None, task_filter=None):
|
||||
# type: (int, Optional[str], Optional[dict]) -> Sequence[Task]
|
||||
"""
|
||||
Return a list of Tasks of the top performing experiments.
|
||||
If there is only one objective, the tasks are sorted based on that objective.
|
||||
If there are multiple objectives, the tasks are sorted based on successive Pareto fronts.
|
||||
A trial is located at the Pareto front if there are no trials that dominate the trial.
|
||||
A trial dominates another trial if all its objective metrics are greater or equal than the other
|
||||
trial's and there is at least one objective metric that is strictly greater than the other.
|
||||
|
||||
:param int top_k: The number of Tasks (experiments) to return.
|
||||
:param str optimizer_task_id: Parent optimizer Task ID
|
||||
:param dict task_filter: Optional task_filtering for the query
|
||||
|
||||
:return: A list of Task objects, ordered by performance, where index 0 is the best performing Task.
|
||||
"""
|
||||
if self.len == 1:
|
||||
return self.objectives[0].get_top_tasks(
|
||||
top_k=top_k,
|
||||
optimizer_task_id=optimizer_task_id,
|
||||
task_filter=task_filter
|
||||
)
|
||||
task_filter = deepcopy(task_filter) if task_filter else {}
|
||||
if optimizer_task_id:
|
||||
task_filter["parent"] = optimizer_task_id
|
||||
# noinspection PyProtectedMember
|
||||
tasks = Task._query_tasks(**task_filter)
|
||||
candidates = {}
|
||||
for task in tasks:
|
||||
values = self.get_objective(task.id)
|
||||
if values is None or any(v is None for v in values):
|
||||
continue
|
||||
candidates[task.id] = (values, task.id)
|
||||
sorted_ids = self._sort_jobs_by_domination(candidates)
|
||||
if not sorted_ids:
|
||||
return []
|
||||
return Task.get_tasks(task_ids=sorted_ids[:top_k])
|
||||
|
||||
def _get_last_metrics_encode_field(self):
|
||||
# noinspection PyProtectedMember
|
||||
return [o._get_last_metrics_encode_field() for o in self.objectives]
|
||||
|
||||
def _weakly_dominates_normalized(self, lhs, rhs):
|
||||
return all(lhs_elem * o.sign >= rhs_elem * o.sign for lhs_elem, rhs_elem, o in zip(lhs, rhs, self.objectives))
|
||||
|
||||
@staticmethod
|
||||
def _dominates(lhs, rhs):
|
||||
return all(lhs_elem >= rhs_elem for lhs_elem, rhs_elem in zip(lhs, rhs)) and \
|
||||
any(lhs_elem > rhs_elem for lhs_elem, rhs_elem in zip(lhs, rhs))
|
||||
|
||||
def _sort_jobs_by_domination(self, jobs):
|
||||
job_ids = list(jobs.keys())
|
||||
job_ids_sorted = []
|
||||
while len(job_ids_sorted) < len(jobs.keys()):
|
||||
have_result = False
|
||||
for job_id in job_ids:
|
||||
if all(self._weakly_dominates_normalized(jobs[job_id][0], jobs[other_job_id][0])
|
||||
for other_job_id in job_ids):
|
||||
have_result = True
|
||||
job_ids_sorted.append(job_id)
|
||||
if not have_result:
|
||||
job_ids_sorted.extend(job_ids)
|
||||
job_ids = [job_id for job_id in job_ids if job_id not in job_ids_sorted]
|
||||
return job_ids_sorted
|
||||
|
||||
|
||||
class GridSearch(SearchStrategy):
|
||||
"""
|
||||
Grid search strategy controller. Full grid sampling of every hyperparameter combination.
|
||||
@@ -1089,12 +1270,12 @@ class HyperParameterOptimizer(object):
|
||||
self,
|
||||
base_task_id, # type: str
|
||||
hyper_parameters, # type: Sequence[Parameter]
|
||||
objective_metric_title, # type: str
|
||||
objective_metric_series, # type: str
|
||||
objective_metric_sign='min', # type: str
|
||||
objective_metric_title, # type: Union[str, Sequence[str]]
|
||||
objective_metric_series, # type: Union[str, Sequence[str]]
|
||||
objective_metric_sign="min", # type: Union[str, Sequence[str]]
|
||||
optimizer_class=RandomSearch, # type: Union[SearchStrategy, type(SearchStrategy)]
|
||||
max_number_of_concurrent_tasks=10, # type: int
|
||||
execution_queue='default', # type: str
|
||||
execution_queue="default", # type: str
|
||||
optimization_time_limit=None, # type: Optional[float]
|
||||
compute_time_limit=None, # type: Optional[float]
|
||||
auto_connect_task=True, # type: Union[bool, Task]
|
||||
@@ -1109,10 +1290,14 @@ class HyperParameterOptimizer(object):
|
||||
|
||||
:param str base_task_id: The Task ID to be used as template experiment to optimize.
|
||||
:param list hyper_parameters: The list of Parameter objects to optimize over.
|
||||
:param str objective_metric_title: The Objective metric title to maximize / minimize (for example,
|
||||
``validation``).
|
||||
:param str objective_metric_series: The Objective metric series to maximize / minimize (for example, ``loss``).
|
||||
:param str objective_metric_sign: The objective to maximize / minimize.
|
||||
:param Union[str, Sequence[str]] objective_metric_title: The Objective metric title(s) to maximize / minimize
|
||||
(for example, ``validation``, ``["validation", "loss"]``). If ``objective_metric_title`` is a sequence
|
||||
(used to optimize multiple objectives at the same time), then ``objective_metric_series`` and
|
||||
``objective_metric_sign`` have to be sequences of the same length. Each title will be matched
|
||||
with the respective series and sign
|
||||
:param Union[str, Sequence[str]] objective_metric_series: The Objective metric series to maximize / minimize
|
||||
(for example, ``loss_series``, ``["validation_series", "loss_series"]``).
|
||||
:param Union[str, Sequence[str]] objective_metric_sign: The objectives to maximize / minimize.
|
||||
The values are:
|
||||
|
||||
- ``min`` - Minimize the last reported value for the specified title/series scalar.
|
||||
@@ -1190,7 +1375,24 @@ class HyperParameterOptimizer(object):
|
||||
# make sure we stop all jobs
|
||||
an_optimizer.stop()
|
||||
"""
|
||||
|
||||
if type(objective_metric_title) is not type(objective_metric_series) or type(
|
||||
objective_metric_title) is not type(objective_metric_sign):
|
||||
raise TypeError(
|
||||
"objective_metric_series, objective_metric_title and objective_metric_sign have to be of the same type"
|
||||
" (strings if doing single objective optimization and lists of the same length"
|
||||
" if doing multi-objective optimization)"
|
||||
)
|
||||
if isinstance(objective_metric_title, str):
|
||||
objective_metric_series = [objective_metric_series]
|
||||
objective_metric_title = [objective_metric_title]
|
||||
objective_metric_sign = [objective_metric_sign]
|
||||
if len(objective_metric_series) != len(objective_metric_title) or len(objective_metric_series) != len(
|
||||
objective_metric_sign
|
||||
):
|
||||
raise ValueError(
|
||||
"Can not use multiple objective optimization when objective_metric_series, objective_metric_title"
|
||||
" or objective_metric_sign do not have the same length"
|
||||
)
|
||||
# create a new Task, if we do not have one already
|
||||
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
|
||||
self._readonly_task = \
|
||||
@@ -1224,17 +1426,29 @@ class HyperParameterOptimizer(object):
|
||||
self.hyper_parameters = hyper_parameters
|
||||
self.max_number_of_concurrent_tasks = opts['max_number_of_concurrent_tasks']
|
||||
self.execution_queue = opts['execution_queue']
|
||||
self.objective_metric = Objective(
|
||||
title=opts['objective_metric_title'], series=opts['objective_metric_series'],
|
||||
order='min' if opts['objective_metric_sign'] in ('min', 'min_global') else 'max',
|
||||
extremum=opts['objective_metric_sign'].endswith('_global'))
|
||||
self._objective_metric = MultiObjective(
|
||||
title=opts["objective_metric_title"],
|
||||
series=opts["objective_metric_series"],
|
||||
order=["min" if sign_ in ("min", "min_global") else "max" for sign_ in opts["objective_metric_sign"]],
|
||||
extremum=[sign_.endswith("_global") for sign_ in opts["objective_metric_sign"]]
|
||||
)
|
||||
optuna_error_message = "Multi parameter optimization is only supported via Optuna. Please install Optuna via" + \
|
||||
" `pip install optuna and set the `optimizer_class` to `clearml.automation.optuna.OptimizerOptuna`"
|
||||
try:
|
||||
if self._objective_metric.len != 1:
|
||||
from .optuna import OptimizerOptuna
|
||||
|
||||
if optimizer_class != OptimizerOptuna:
|
||||
raise ValueError(optuna_error_message)
|
||||
except Exception:
|
||||
raise ValueError(optuna_error_message)
|
||||
# if optimizer_class is an instance, use it as is.
|
||||
if type(optimizer_class) != type:
|
||||
if not isinstance(optimizer_class, type):
|
||||
self.optimizer = optimizer_class
|
||||
else:
|
||||
self.optimizer = optimizer_class(
|
||||
base_task_id=opts['base_task_id'], hyper_parameters=hyper_parameters,
|
||||
objective_metric=self.objective_metric, execution_queue=opts['execution_queue'],
|
||||
objective_metric=self._objective_metric, execution_queue=opts['execution_queue'],
|
||||
num_concurrent_workers=opts['max_number_of_concurrent_tasks'],
|
||||
compute_time_limit=opts['compute_time_limit'], **opts.get('optimizer_kwargs', {}))
|
||||
self.optimizer.set_optimizer_task(self._task)
|
||||
@@ -1607,9 +1821,9 @@ class HyperParameterOptimizer(object):
|
||||
@classmethod
|
||||
def get_optimizer_top_experiments(
|
||||
cls,
|
||||
objective_metric_title, # type: str
|
||||
objective_metric_series, # type: str
|
||||
objective_metric_sign, # type: str
|
||||
objective_metric_title, # type: Union[str, List[str]]
|
||||
objective_metric_series, # type: Union[str, List[str]]
|
||||
objective_metric_sign, # type: Union[str, List[str]]
|
||||
optimizer_task_id, # type: str
|
||||
top_k, # type: int
|
||||
):
|
||||
@@ -1636,6 +1850,12 @@ class HyperParameterOptimizer(object):
|
||||
title=objective_metric_title, series=objective_metric_series, order=objective_metric_sign)
|
||||
return objective.get_top_tasks(top_k=top_k, optimizer_task_id=optimizer_task_id)
|
||||
|
||||
@property
|
||||
def objective_metric(self):
|
||||
if self._objective_metric.len == 1:
|
||||
return self._objective_metric.objectives[0]
|
||||
return self._objective_metric
|
||||
|
||||
def _connect_args(self, optimizer_class=None, hyper_param_configuration=None, **kwargs):
|
||||
# type: (SearchStrategy, dict, Any) -> (SearchStrategy, list, dict)
|
||||
if not self._task or self._readonly_task:
|
||||
@@ -1705,8 +1925,8 @@ class HyperParameterOptimizer(object):
|
||||
|
||||
def _report_daemon(self):
|
||||
# type: () -> ()
|
||||
title, series = self.objective_metric.get_objective_metric()
|
||||
title = '{}/{}'.format(title, series)
|
||||
title_series = self._objective_metric.get_objective_metric()
|
||||
title = ["{}/{}".format(ts[0], ts[1]) for ts in title_series]
|
||||
counter = 0
|
||||
completed_jobs = dict()
|
||||
task_logger = None
|
||||
@@ -1722,10 +1942,14 @@ class HyperParameterOptimizer(object):
|
||||
params["status"] = str(task.status)
|
||||
# noinspection PyProtectedMember
|
||||
iteration_value = task.get_last_iteration()
|
||||
objective = self.objective_metric.get_objective(task)
|
||||
objective = self._objective_metric.get_objective(task)
|
||||
completed_jobs[task.id] = (
|
||||
objective if objective is not None else -1,
|
||||
iteration_value if iteration_value is not None else -1,
|
||||
objective if objective is not None else (
|
||||
[-1] * self._objective_metric.len
|
||||
),
|
||||
iteration_value if iteration_value is not None else (
|
||||
[-1] * self._objective_metric.len
|
||||
),
|
||||
params
|
||||
)
|
||||
|
||||
@@ -1754,9 +1978,9 @@ class HyperParameterOptimizer(object):
|
||||
self._report_remaining_budget(task_logger, counter)
|
||||
|
||||
if (
|
||||
self.optimizer.budget.compute_time.used
|
||||
and self.optimizer.budget.compute_time.limit
|
||||
and self.optimizer.budget.compute_time.used >= self.optimizer.budget.compute_time.limit
|
||||
self.optimizer.budget.compute_time.used
|
||||
and self.optimizer.budget.compute_time.limit
|
||||
and self.optimizer.budget.compute_time.used >= self.optimizer.budget.compute_time.limit
|
||||
):
|
||||
logger.warning(
|
||||
"Optimizer task reached compute time limit (used {:.2f} out of {:.2f})".format(
|
||||
@@ -1767,8 +1991,9 @@ class HyperParameterOptimizer(object):
|
||||
|
||||
self._report_resources(task_logger, counter)
|
||||
# collect a summary of all the jobs and their final objective values
|
||||
cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - \
|
||||
{j.task_id() for j in self.optimizer.get_running_jobs()}
|
||||
cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - {
|
||||
j.task_id() for j in self.optimizer.get_running_jobs()
|
||||
}
|
||||
self._report_completed_status(completed_jobs, cur_completed_jobs, task_logger, title)
|
||||
self._report_completed_tasks_best_results(set(completed_jobs.keys()), task_logger, title, counter)
|
||||
|
||||
@@ -1790,10 +2015,14 @@ class HyperParameterOptimizer(object):
|
||||
|
||||
def _report_completed_status(self, completed_jobs, cur_completed_jobs, task_logger, title, force=False):
|
||||
job_ids_sorted_by_objective = self.__sort_jobs_by_objective(completed_jobs)
|
||||
best_experiment = \
|
||||
(self.objective_metric.get_normalized_objective(job_ids_sorted_by_objective[0]),
|
||||
job_ids_sorted_by_objective[0]) \
|
||||
if job_ids_sorted_by_objective else (float('-inf'), None)
|
||||
best_experiment = (
|
||||
(
|
||||
self._objective_metric.get_normalized_objective(job_ids_sorted_by_objective[0]),
|
||||
job_ids_sorted_by_objective[0],
|
||||
)
|
||||
if job_ids_sorted_by_objective
|
||||
else ([float("-inf")], None)
|
||||
)
|
||||
if force or cur_completed_jobs != set(completed_jobs.keys()):
|
||||
pairs = []
|
||||
labels = []
|
||||
@@ -1801,13 +2030,14 @@ class HyperParameterOptimizer(object):
|
||||
created_jobs_tasks = self.optimizer.get_created_jobs_tasks()
|
||||
id_status = {j_id: j_run.status() for j_id, j_run in created_jobs_tasks.items()}
|
||||
for i, (job_id, params) in enumerate(created_jobs.items()):
|
||||
value = self.objective_metric.get_objective(job_id)
|
||||
value = self._objective_metric.get_objective(job_id)
|
||||
if job_id in completed_jobs:
|
||||
if value != completed_jobs[job_id][0]:
|
||||
iteration_value = self.objective_metric.get_current_raw_objective(job_id)
|
||||
iteration_value = self._objective_metric.get_current_raw_objective(job_id)
|
||||
iteration = [it_[0] if it_ else -1 for it_ in iteration_value]
|
||||
completed_jobs[job_id] = (
|
||||
value,
|
||||
iteration_value[0] if iteration_value else -1,
|
||||
iteration,
|
||||
copy(dict(status=id_status.get(job_id), **params)))
|
||||
elif completed_jobs.get(job_id):
|
||||
completed_jobs[job_id] = (completed_jobs[job_id][0],
|
||||
@@ -1815,43 +2045,98 @@ class HyperParameterOptimizer(object):
|
||||
copy(dict(status=id_status.get(job_id), **params)))
|
||||
pairs.append((i, completed_jobs[job_id][0]))
|
||||
labels.append(str(completed_jobs[job_id][2])[1:-1])
|
||||
elif value is not None:
|
||||
elif value is not None and all(v is not None for v in value):
|
||||
pairs.append((i, value))
|
||||
labels.append(str(params)[1:-1])
|
||||
iteration_value = self.objective_metric.get_current_raw_objective(job_id)
|
||||
iteration_value = self._objective_metric.get_current_raw_objective(job_id)
|
||||
iteration = [it_[0] if it_ else -1 for it_ in iteration_value]
|
||||
completed_jobs[job_id] = (
|
||||
value,
|
||||
iteration_value[0] if iteration_value else -1,
|
||||
iteration,
|
||||
copy(dict(status=id_status.get(job_id), **params)))
|
||||
# callback new experiment completed
|
||||
if self._experiment_completed_cb:
|
||||
normalized_value = self.objective_metric.get_normalized_objective(job_id)
|
||||
if normalized_value is not None and normalized_value > best_experiment[0]:
|
||||
normalized_value = self._objective_metric.get_normalized_objective(job_id)
|
||||
if self._objective_metric.len == 1 and normalized_value is not None and \
|
||||
normalized_value[0] > best_experiment[0][0]:
|
||||
best_experiment = normalized_value, job_id
|
||||
elif self._objective_metric.len != 1 and normalized_value is not None and \
|
||||
all(n is not None for n in normalized_value) and (best_experiment[0] == float("-inf") or
|
||||
MultiObjective._dominates(
|
||||
normalized_value,
|
||||
best_experiment[0])): # noqa
|
||||
best_experiment = normalized_value, job_id
|
||||
c = completed_jobs[job_id]
|
||||
self._experiment_completed_cb(job_id, c[0], c[1], c[2], best_experiment[1])
|
||||
|
||||
if pairs:
|
||||
print('Updating job performance summary plot/table')
|
||||
|
||||
# update scatter plot
|
||||
task_logger.report_scatter2d(
|
||||
title='Optimization Objective', series=title,
|
||||
scatter=pairs, iteration=0, labels=labels,
|
||||
mode='markers', xaxis='job #', yaxis='objective')
|
||||
print("Updating job performance summary plot/table")
|
||||
if isinstance(title, list):
|
||||
for i, title_ in enumerate(title):
|
||||
# update scatter plot
|
||||
task_logger.report_scatter2d(
|
||||
title="Optimization Objective",
|
||||
series=title_,
|
||||
scatter=[(p[0], p[1][i]) for p in pairs],
|
||||
iteration=0,
|
||||
labels=labels,
|
||||
mode="markers",
|
||||
xaxis="job #",
|
||||
yaxis="objective",
|
||||
)
|
||||
else:
|
||||
task_logger.report_scatter2d(
|
||||
title="Optimization Objective",
|
||||
series=title,
|
||||
scatter=pairs,
|
||||
iteration=0,
|
||||
labels=labels,
|
||||
mode="markers",
|
||||
xaxis="job #",
|
||||
yaxis="objective",
|
||||
)
|
||||
|
||||
# update summary table
|
||||
job_ids = list(completed_jobs.keys())
|
||||
job_ids_sorted_by_objective = sorted(
|
||||
job_ids, key=lambda x: completed_jobs[x][0], reverse=bool(self.objective_metric.sign >= 0))
|
||||
job_ids_sorted_by_objective = self.__sort_jobs_by_objective(completed_jobs)
|
||||
# sort the columns except for 'objective', 'iteration'
|
||||
columns = list(sorted(set([c for k, v in completed_jobs.items() for c in v[2].keys()])))
|
||||
|
||||
# add the index column (task id) and the first two columns 'objective', 'iteration' then the rest
|
||||
table_values = [['task id', 'objective', 'iteration'] + columns]
|
||||
table_values += \
|
||||
[([job, completed_jobs[job][0], completed_jobs[job][1]] +
|
||||
[completed_jobs[job][2].get(c, '') for c in columns]) for job in job_ids_sorted_by_objective]
|
||||
concat_iterations = True
|
||||
if self._objective_metric.len == 1:
|
||||
# add the index column (task id) and the first two columns 'objective', 'iteration' then the rest
|
||||
table_values = [['task id', 'objective', 'iteration'] + columns]
|
||||
table_values += \
|
||||
[([job, completed_jobs[job][0][0], completed_jobs[job][1][0]] +
|
||||
[completed_jobs[job][2].get(c, '') for c in columns]) for job in job_ids_sorted_by_objective]
|
||||
else:
|
||||
table_values = ['task id']
|
||||
for job in job_ids_sorted_by_objective:
|
||||
if not all(iter_ == completed_jobs[job][1][0] for iter_ in completed_jobs[job][1]):
|
||||
concat_iterations = False
|
||||
break
|
||||
if concat_iterations:
|
||||
for objective in self._objective_metric.objectives:
|
||||
table_values.append(objective.title + "/" + objective.series)
|
||||
table_values.append("iteration")
|
||||
table_values = [table_values + columns]
|
||||
for job in job_ids_sorted_by_objective:
|
||||
entry = [job]
|
||||
for val in completed_jobs[job][0]:
|
||||
entry += [val]
|
||||
entry += [completed_jobs[job][1][0]]
|
||||
entry += [completed_jobs[job][2].get(c, '') for c in columns]
|
||||
table_values.append(entry)
|
||||
else:
|
||||
for objective in self._objective_metric.objectives:
|
||||
table_values.append(objective.title + "/" + objective.series)
|
||||
table_values.append("iteration " + objective.title + "/" + objective.series)
|
||||
table_values = [table_values + columns]
|
||||
for job in job_ids_sorted_by_objective:
|
||||
entry = [job]
|
||||
for val, iter_ in zip(completed_jobs[job][0], completed_jobs[job][1]):
|
||||
entry += [val, iter_]
|
||||
entry += [completed_jobs[job][2].get(c, '') for c in columns]
|
||||
table_values.append(entry)
|
||||
|
||||
# create links for task id in the table
|
||||
task_link_template = self._task.get_output_log_web_page() \
|
||||
@@ -1867,15 +2152,42 @@ class HyperParameterOptimizer(object):
|
||||
task_link_template.format(project=project_id, task=task_id), task_id)
|
||||
|
||||
task_logger.report_table(
|
||||
"summary", "job", 0, table_plot=table_values_with_links,
|
||||
extra_layout={"title": "objective: {}".format(title)})
|
||||
"summary",
|
||||
"job",
|
||||
0,
|
||||
table_plot=table_values_with_links,
|
||||
extra_layout={
|
||||
"title": "objective: {}".format(title if not isinstance(title, list) else ", ".join(title))
|
||||
},
|
||||
)
|
||||
|
||||
# Build parallel Coordinates: convert to columns, and reorder accordingly
|
||||
if len(table_values) > 1:
|
||||
table_values_columns = [[row[i] for row in table_values] for i in range(len(table_values[0]))]
|
||||
table_values_columns = \
|
||||
[[table_values_columns[0][0]] + [c[:6]+'...' for c in table_values_columns[0][1:]]] + \
|
||||
table_values_columns[2:-1] + [[title]+table_values_columns[1][1:]]
|
||||
if self._objective_metric.len == 1:
|
||||
table_values_columns = \
|
||||
[[table_values_columns[0][0]] + [c[:6] + '...' for c in table_values_columns[0][1:]]] + \
|
||||
table_values_columns[2:-1] + [[title] + table_values_columns[1][1:]]
|
||||
else:
|
||||
if not concat_iterations:
|
||||
new_table_values_columns = []
|
||||
handled = []
|
||||
for i in range(1, 2 * len(self._objective_metric.objectives), 2):
|
||||
handled.append(i)
|
||||
new_table_values_columns.append(table_values_columns[i])
|
||||
prefix = []
|
||||
for i in range(len(table_values_columns)):
|
||||
if i in handled or table_values_columns[i][0] == "status":
|
||||
continue
|
||||
prefix.append(table_values_columns[i])
|
||||
table_values_columns = prefix + new_table_values_columns
|
||||
else:
|
||||
table_values_columns = ([table_values_columns[0]] +
|
||||
table_values_columns[len(self._objective_metric.objectives) + 1:-1] +
|
||||
table_values_columns[1:len(self._objective_metric.objectives) + 1]
|
||||
)
|
||||
for i in range(len(table_values_columns[0]) - 1):
|
||||
table_values_columns[0][i + 1] = table_values_columns[0][i + 1][:6] + "..."
|
||||
pcc_dims = []
|
||||
for col in table_values_columns:
|
||||
# test if all values are numbers:
|
||||
@@ -1896,16 +2208,22 @@ class HyperParameterOptimizer(object):
|
||||
pcc_dims.append(d)
|
||||
# report parallel coordinates
|
||||
plotly_pcc = dict(
|
||||
data=[dict(
|
||||
type='parcoords',
|
||||
line=dict(colorscale='Viridis',
|
||||
reversescale=bool(self.objective_metric.sign >= 0),
|
||||
color=table_values_columns[-1][1:]),
|
||||
dimensions=pcc_dims)],
|
||||
layout={})
|
||||
task_logger.report_plotly(
|
||||
title='Parallel Coordinates', series='',
|
||||
iteration=0, figure=plotly_pcc)
|
||||
data=[
|
||||
dict(
|
||||
type="parcoords",
|
||||
line=dict(
|
||||
colorscale="Viridis",
|
||||
reversescale=(
|
||||
self._objective_metric.len == 1 and self._objective_metric.objectives[0].sign >= 0,
|
||||
),
|
||||
color=table_values_columns[-1][1:],
|
||||
),
|
||||
dimensions=pcc_dims,
|
||||
)
|
||||
],
|
||||
layout={},
|
||||
)
|
||||
task_logger.report_plotly(title="Parallel Coordinates", series="", iteration=0, figure=plotly_pcc)
|
||||
|
||||
# upload summary as artifact
|
||||
if force:
|
||||
@@ -1937,21 +2255,20 @@ class HyperParameterOptimizer(object):
|
||||
if not completed_jobs:
|
||||
return
|
||||
|
||||
value_func, series_name = (max, "max") if self.objective_metric.get_objective_sign() > 0 else \
|
||||
(min, "min")
|
||||
latest_completed, obj_values = self._get_latest_completed_task_value(completed_jobs, series_name)
|
||||
if latest_completed:
|
||||
val = value_func(obj_values)
|
||||
task_logger.report_scalar(
|
||||
title=title,
|
||||
series=series_name,
|
||||
iteration=counter,
|
||||
value=val)
|
||||
task_logger.report_scalar(
|
||||
title=title,
|
||||
series="last reported",
|
||||
iteration=counter,
|
||||
value=latest_completed)
|
||||
objectives = self._objective_metric.objectives
|
||||
if not isinstance(title, list):
|
||||
title = [title]
|
||||
for objective, title_ in zip(objectives, title):
|
||||
value_func, series_name = (max, "max") if objective.get_objective_sign() > 0 else (min, "min")
|
||||
latest_completed, obj_values = self._get_latest_completed_task_value(
|
||||
completed_jobs, series_name, objective.title, objective.series
|
||||
)
|
||||
if latest_completed:
|
||||
val = value_func(obj_values)
|
||||
task_logger.report_scalar(title=title_, series=series_name, iteration=counter, value=val)
|
||||
task_logger.report_scalar(
|
||||
title=title_, series="last reported", iteration=counter, value=latest_completed
|
||||
)
|
||||
|
||||
def _report_resources(self, task_logger, iteration):
|
||||
# type: (Logger, int) -> ()
|
||||
@@ -1990,7 +2307,7 @@ class HyperParameterOptimizer(object):
|
||||
title="resources", series=series,
|
||||
iteration=iteration, value=val)
|
||||
|
||||
def _get_latest_completed_task_value(self, cur_completed_jobs, series_name):
|
||||
def _get_latest_completed_task_value(self, cur_completed_jobs, series_name, title, series):
|
||||
# type: (Set[str], str) -> (float, List[float])
|
||||
completed_value = None
|
||||
latest_completed = None
|
||||
@@ -2003,16 +2320,15 @@ class HyperParameterOptimizer(object):
|
||||
continue
|
||||
completed_time = datetime_from_isoformat(response.response_data["task"]["completed"].partition("+")[0])
|
||||
completed_time = completed_time.timestamp()
|
||||
completed_values = self._get_last_value(response)
|
||||
completed_values = self._get_last_value(response, title, series)
|
||||
obj_values.append(completed_values['max_value'] if series_name == "max" else completed_values['min_value'])
|
||||
if not latest_completed or completed_time > latest_completed:
|
||||
latest_completed = completed_time
|
||||
completed_value = completed_values['value']
|
||||
return completed_value, obj_values
|
||||
|
||||
def _get_last_value(self, response):
|
||||
metrics, title, series, values = ClearmlJob.get_metric_req_params(self.objective_metric.title,
|
||||
self.objective_metric.series)
|
||||
def _get_last_value(self, response, title, series):
|
||||
metrics, title, series, values = ClearmlJob.get_metric_req_params(title, series)
|
||||
last_values = response.response_data["task"]['last_metrics'][title][series]
|
||||
return last_values
|
||||
|
||||
@@ -2061,6 +2377,14 @@ class HyperParameterOptimizer(object):
|
||||
def __sort_jobs_by_objective(self, completed_jobs):
|
||||
if not completed_jobs:
|
||||
return []
|
||||
job_ids_sorted_by_objective = list(sorted(
|
||||
completed_jobs.keys(), key=lambda x: completed_jobs[x][0], reverse=bool(self.objective_metric.sign >= 0)))
|
||||
return job_ids_sorted_by_objective
|
||||
if self._objective_metric.len != 1:
|
||||
# noinspection PyProtectedMember
|
||||
return self._objective_metric._sort_jobs_by_domination(completed_jobs)
|
||||
else:
|
||||
return list(
|
||||
sorted(
|
||||
completed_jobs.keys(),
|
||||
key=lambda x: completed_jobs[x][0],
|
||||
reverse=bool(self._objective_metric.objectives[0].sign >= 0),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -55,21 +55,29 @@ class OptunaObjective(object):
|
||||
if not is_pending:
|
||||
# noinspection PyProtectedMember
|
||||
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(current_job)
|
||||
if not iteration_value:
|
||||
if not self.optimizer.monitor_job(current_job):
|
||||
break
|
||||
continue
|
||||
|
||||
# make sure we skip None objective values
|
||||
if iteration_value and iteration_value[1] is not None:
|
||||
# update budget
|
||||
trial.report(value=iteration_value[1], step=iteration_value[0])
|
||||
if not any(val is None or val[1] is None for val in iteration_value):
|
||||
iteration = max(iv[0] for iv in iteration_value)
|
||||
# trial pruning based on intermediate values not supported when using multi-objective
|
||||
# noinspection PyProtectedMember
|
||||
if self.optimizer._objective_metric.len == 1:
|
||||
# update budget
|
||||
trial.report(value=iteration_value[0][1], step=iteration)
|
||||
|
||||
# Handle pruning based on the intermediate value.
|
||||
if trial.should_prune() and (
|
||||
not self.min_iteration_per_job or
|
||||
iteration_value[0] >= self.min_iteration_per_job):
|
||||
current_job.abort()
|
||||
raise optuna.TrialPruned()
|
||||
# Handle pruning based on the intermediate value.
|
||||
if trial.should_prune() and (
|
||||
not self.min_iteration_per_job or
|
||||
iteration >= self.min_iteration_per_job):
|
||||
current_job.abort()
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
# check if we exceeded this job budget
|
||||
if self.max_iteration_per_job and iteration_value[0] >= self.max_iteration_per_job:
|
||||
if self.max_iteration_per_job and iteration >= self.max_iteration_per_job:
|
||||
current_job.abort()
|
||||
break
|
||||
|
||||
@@ -79,6 +87,10 @@ class OptunaObjective(object):
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
objective_metric = self.optimizer._objective_metric.get_objective(current_job)
|
||||
# noinspection PyProtectedMember
|
||||
if self.optimizer._objective_metric.len == 1:
|
||||
objective_metric = objective_metric[0]
|
||||
iteration_value = iteration_value[0]
|
||||
print('OptunaObjective result metric={}, iteration {}'.format(objective_metric, iteration_value))
|
||||
# noinspection PyProtectedMember
|
||||
self.optimizer._current_jobs.remove(current_job)
|
||||
@@ -157,13 +169,22 @@ class OptimizerOptuna(SearchStrategy):
|
||||
This function returns only after optimization is completed or :meth:`stop` was called.
|
||||
|
||||
"""
|
||||
self._study = optuna.create_study(
|
||||
direction="minimize" if self._objective_metric.get_objective_sign() < 0 else "maximize",
|
||||
load_if_exists=False,
|
||||
sampler=self._optuna_sampler,
|
||||
pruner=self._optuna_pruner,
|
||||
study_name=self._optimizer_task.id if self._optimizer_task else None,
|
||||
)
|
||||
if self._objective_metric.len != 1:
|
||||
self._study = optuna.create_study(
|
||||
directions=["minimize" if sign_ < 0 else "maximize" for sign_ in self._objective_metric.get_objective_sign()],
|
||||
load_if_exists=False,
|
||||
sampler=self._optuna_sampler,
|
||||
pruner=self._optuna_pruner,
|
||||
study_name=self._optimizer_task.id if self._optimizer_task else None,
|
||||
)
|
||||
else:
|
||||
self._study = optuna.create_study(
|
||||
direction="minimize" if self._objective_metric.get_objective_sign()[0] < 0 else "maximize",
|
||||
load_if_exists=False,
|
||||
sampler=self._optuna_sampler,
|
||||
pruner=self._optuna_pruner,
|
||||
study_name=self._optimizer_task.id if self._optimizer_task else None,
|
||||
)
|
||||
config_space = self._convert_hyper_parameters_to_optuna()
|
||||
self._objective = OptunaObjective(
|
||||
base_task_id=self._base_task_id,
|
||||
|
||||
@@ -9,6 +9,7 @@ from attr import attrs, attrib
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .job import ClearmlJob
|
||||
from .controller import PipelineController
|
||||
from ..backend_interface.util import datetime_from_isoformat, datetime_to_isoformat, mutually_exclusive
|
||||
from ..task import Task
|
||||
|
||||
@@ -59,6 +60,23 @@ class BaseScheduleJob(object):
|
||||
self._executed_instances = []
|
||||
self._executed_instances.append(str(task_id))
|
||||
|
||||
def get_resolved_target_project(self):
|
||||
if not self.base_task_id or not self.target_project:
|
||||
return self.target_project
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
task = Task.get_task(task_id=self.base_task_id)
|
||||
# noinspection PyProtectedMember
|
||||
if (
|
||||
PipelineController._tag in task.get_system_tags()
|
||||
and "/{}/".format(PipelineController._project_section) not in self.target_project
|
||||
):
|
||||
# noinspection PyProtectedMember
|
||||
return "{}/{}/{}".format(self.target_project, PipelineController._project_section, task.name)
|
||||
except Exception:
|
||||
pass
|
||||
return self.target_project
|
||||
|
||||
|
||||
@attrs
|
||||
class ScheduleJob(BaseScheduleJob):
|
||||
@@ -367,7 +385,13 @@ class BaseScheduler(object):
|
||||
|
||||
:param queue: Remote queue to run the scheduler on, default 'services' queue.
|
||||
"""
|
||||
# make sure we serialize the current state if we are running locally
|
||||
if self._task.running_locally():
|
||||
self._serialize_state()
|
||||
self._serialize()
|
||||
# launch on the remote agent
|
||||
self._task.execute_remotely(queue_name=queue, exit_process=True)
|
||||
# we will be deserializing the state inside `start`
|
||||
self.start()
|
||||
|
||||
def _update_execution_plots(self):
|
||||
@@ -447,7 +471,7 @@ class BaseScheduler(object):
|
||||
task_overrides=job.task_overrides,
|
||||
disable_clone_task=not job.clone_task,
|
||||
allow_caching=False,
|
||||
target_project=job.target_project,
|
||||
target_project=job.get_resolved_target_project(),
|
||||
tags=[add_tags] if add_tags and isinstance(add_tags, str) else add_tags,
|
||||
)
|
||||
self._log('Scheduling Job {}, Task {} on queue {}.'.format(
|
||||
|
||||
@@ -38,6 +38,7 @@ from .defs import (
|
||||
from .request import Request, BatchRequest # noqa: F401
|
||||
from .token_manager import TokenManager
|
||||
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
|
||||
from ...backend_config import ConfigurationError
|
||||
from ...backend_config.defs import get_config_file
|
||||
from ...debugging import get_logger
|
||||
from ...debugging.log import resolve_logging_level
|
||||
@@ -773,7 +774,7 @@ class Session(TokenManager):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cls()
|
||||
except MissingConfigError:
|
||||
except (MissingConfigError, ConfigurationError):
|
||||
if raise_error and not ENV_IGNORE_MISSING_CONFIG.get():
|
||||
raise
|
||||
except LoginError:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import base64
|
||||
from distutils.util import strtobool
|
||||
from typing import Union, Optional, Any, TypeVar, Callable, Tuple
|
||||
|
||||
import six
|
||||
@@ -14,6 +13,22 @@ except ImportError:
|
||||
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
|
||||
|
||||
|
||||
def strtobool(val):
|
||||
"""Convert a string representation of truth to true (1) or false (0).
|
||||
|
||||
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
|
||||
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
|
||||
'val' is anything else.
|
||||
"""
|
||||
val = val.lower()
|
||||
if val in ('y', 'yes', 't', 'true', 'on', '1'):
|
||||
return 1
|
||||
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
|
||||
return 0
|
||||
else:
|
||||
raise ValueError("invalid truth value %r" % (val,))
|
||||
|
||||
|
||||
def base64_to_text(value):
|
||||
# type: (Any) -> Text
|
||||
return base64.b64decode(value).decode("utf-8")
|
||||
|
||||
@@ -528,8 +528,10 @@ class _Arguments(object):
|
||||
"Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v)
|
||||
)
|
||||
|
||||
# assume more general purpose type int -> float
|
||||
if v_type == int:
|
||||
# if parameter is empty and default value is None, keep as None
|
||||
if param == '' and v is None:
|
||||
v_type = type(None)
|
||||
elif v_type == int: # assume more general purpose type int -> float
|
||||
if v is not None and int(v) != float(v):
|
||||
v_type = float
|
||||
elif v_type == bool:
|
||||
|
||||
@@ -17,16 +17,14 @@ from ...task import Task
|
||||
|
||||
|
||||
class CreateAndPopulate(object):
|
||||
_VCS_SSH_REGEX = \
|
||||
"^" \
|
||||
"(?:(?P<user>{regular}*?)@)?" \
|
||||
"(?P<host>{regular}*?)" \
|
||||
":" \
|
||||
"(?P<path>{regular}.*)?" \
|
||||
"$" \
|
||||
.format(
|
||||
regular=r"[^/@:#]"
|
||||
)
|
||||
_VCS_SSH_REGEX = (
|
||||
"^"
|
||||
"(?:(?P<user>{regular}*?)@)?"
|
||||
"(?P<host>{regular}*?)"
|
||||
":"
|
||||
"(?P<path>{regular}.*)?"
|
||||
"$".format(regular=r"[^/@:#]")
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -199,9 +197,9 @@ class CreateAndPopulate(object):
|
||||
|
||||
# if there is nothing to populate, return
|
||||
if not any([
|
||||
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
||||
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
||||
):
|
||||
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
|
||||
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
|
||||
):
|
||||
return task
|
||||
|
||||
# clear the script section
|
||||
@@ -219,7 +217,7 @@ class CreateAndPopulate(object):
|
||||
if self.cwd:
|
||||
self.cwd = self.cwd
|
||||
cwd = self.cwd if Path(self.cwd).is_dir() else (
|
||||
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
||||
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
|
||||
if not Path(cwd).is_dir():
|
||||
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
|
||||
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
|
||||
@@ -577,6 +575,7 @@ if __name__ == '__main__':
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
_sanitize_function=None, # type: Optional[Callable[[str], str]]
|
||||
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
|
||||
skip_global_imports=False # type: bool
|
||||
):
|
||||
# type: (...) -> Optional[Dict, Task]
|
||||
"""
|
||||
@@ -659,6 +658,9 @@ if __name__ == '__main__':
|
||||
return dill.loads(bytes_)
|
||||
:param _sanitize_function: Sanitization function for the function string.
|
||||
:param _sanitize_helper_functions: Sanitization function for the helper function string.
|
||||
:param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise
|
||||
all global imports will be automatically imported in a safe manner at the beginning of the function's
|
||||
execution. Default is False
|
||||
:return: Newly created Task object
|
||||
"""
|
||||
# not set -> equals True
|
||||
@@ -671,7 +673,7 @@ if __name__ == '__main__':
|
||||
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
|
||||
|
||||
function_source, function_name = CreateFromFunction.__extract_function_information(
|
||||
a_function, sanitize_function=_sanitize_function
|
||||
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
|
||||
)
|
||||
# add helper functions on top.
|
||||
for f in (helper_functions or []):
|
||||
@@ -846,11 +848,133 @@ if __name__ == '__main__':
|
||||
return function_source
|
||||
|
||||
@staticmethod
|
||||
def __extract_function_information(function, sanitize_function=None):
|
||||
# type: (Callable, Optional[Callable]) -> (str, str)
|
||||
function_name = str(function.__name__)
|
||||
function_source = inspect.getsource(function)
|
||||
def __extract_imports(func):
|
||||
def add_import_guard(import_):
|
||||
return ("try:\n "
|
||||
+ import_.replace("\n", "\n ", import_.count("\n") - 1)
|
||||
+ "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
|
||||
)
|
||||
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
import ast
|
||||
func_module = inspect.getmodule(func)
|
||||
source = inspect.getsource(func_module)
|
||||
parsed_source = ast.parse(source)
|
||||
imports = []
|
||||
for parsed_source_entry in parsed_source.body:
|
||||
# we only include global imports (i.e. at col_offset 0)
|
||||
if parsed_source_entry.col_offset != 0:
|
||||
continue
|
||||
if isinstance(parsed_source_entry, ast.ImportFrom):
|
||||
for sub_entry in parsed_source_entry.names:
|
||||
import_str = "from {} import {}".format(parsed_source_entry.module, sub_entry.name)
|
||||
if sub_entry.asname:
|
||||
import_str += " as {}".format(sub_entry.asname)
|
||||
imports.append(import_str)
|
||||
elif isinstance(parsed_source_entry, ast.Import):
|
||||
for sub_entry in parsed_source_entry.names:
|
||||
import_str = "import {}".format(sub_entry.name)
|
||||
if sub_entry.asname:
|
||||
import_str += " as {}".format(sub_entry.asname)
|
||||
imports.append(import_str)
|
||||
imports = [add_import_guard(import_) for import_ in imports]
|
||||
return "\n".join(imports)
|
||||
except Exception as e:
|
||||
getLogger().warning('Could not fetch function imports: {}'.format(e))
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_wrapped(decorated):
|
||||
if not decorated.__closure__:
|
||||
return None
|
||||
closure = (c.cell_contents for c in decorated.__closure__)
|
||||
if not closure:
|
||||
return None
|
||||
return next((c for c in closure if inspect.isfunction(c)), None)
|
||||
|
||||
@staticmethod
|
||||
def _deep_extract_wrapped(decorated):
|
||||
while True:
|
||||
# noinspection PyProtectedMember
|
||||
func = CreateFromFunction._extract_wrapped(decorated)
|
||||
if not func:
|
||||
return decorated
|
||||
decorated = func
|
||||
|
||||
@staticmethod
|
||||
def __sanitize(func_source, sanitize_function=None):
|
||||
if sanitize_function:
|
||||
function_source = sanitize_function(function_source)
|
||||
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
|
||||
return function_source, function_name
|
||||
func_source = sanitize_function(func_source)
|
||||
return CreateFromFunction.__sanitize_remove_type_hints(func_source)
|
||||
|
||||
@staticmethod
|
||||
def __get_func_members(module):
|
||||
result = []
|
||||
try:
|
||||
import ast
|
||||
|
||||
source = inspect.getsource(module)
|
||||
parsed = ast.parse(source)
|
||||
for f in parsed.body:
|
||||
if isinstance(f, ast.FunctionDef):
|
||||
result.append(f.name)
|
||||
except Exception as e:
|
||||
name = getattr(module, "__name__", module)
|
||||
getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def __get_source_with_decorators(func, original_module=None, sanitize_function=None):
|
||||
if original_module is None:
|
||||
original_module = inspect.getmodule(func)
|
||||
func_members = CreateFromFunction.__get_func_members(original_module)
|
||||
try:
|
||||
func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
|
||||
except Exception as e:
|
||||
name = getattr(original_module, "__name__", original_module)
|
||||
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
|
||||
func_members_dict = {}
|
||||
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
|
||||
decorated_func_source = CreateFromFunction.__sanitize(
|
||||
inspect.getsource(decorated_func),
|
||||
sanitize_function=sanitize_function
|
||||
)
|
||||
try:
|
||||
import ast
|
||||
|
||||
parsed_decorated = ast.parse(decorated_func_source)
|
||||
for body_elem in parsed_decorated.body:
|
||||
if not isinstance(body_elem, ast.FunctionDef):
|
||||
continue
|
||||
for decorator in body_elem.decorator_list:
|
||||
name = None
|
||||
if isinstance(decorator, ast.Name):
|
||||
name = decorator.id
|
||||
elif isinstance(decorator, ast.Call):
|
||||
name = decorator.func.id
|
||||
if not name:
|
||||
continue
|
||||
decorator_func = func_members_dict.get(name)
|
||||
if name not in func_members or not decorator_func:
|
||||
continue
|
||||
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
|
||||
decorator_func,
|
||||
original_module=original_module,
|
||||
sanitize_function=sanitize_function
|
||||
) + "\n\n" + decorated_func_source
|
||||
break
|
||||
except Exception as e:
|
||||
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
|
||||
return decorated_func_source
|
||||
|
||||
@staticmethod
|
||||
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
|
||||
# type: (Callable, Optional[Callable], bool) -> (str, str)
|
||||
function = CreateFromFunction._deep_extract_wrapped(function)
|
||||
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
|
||||
if not skip_global_imports:
|
||||
imports = CreateFromFunction.__extract_imports(function)
|
||||
else:
|
||||
imports = ""
|
||||
return imports + "\n" + function_source, function.__name__
|
||||
|
||||
@@ -97,7 +97,12 @@ class ScriptRequirements(object):
|
||||
for fname, lines in skimage.items():
|
||||
modules.add('scikit_image', fname, lines)
|
||||
|
||||
# if we have torch and it supports tensorboard, we should add that as well
|
||||
if 'tensorflow-intel' in modules:
|
||||
tfmodule = modules.pop('tensorflow-intel', {})
|
||||
for fname, lines in tfmodule.items():
|
||||
modules.add('tensorflow', fname, lines)
|
||||
|
||||
# if we have torch, and it supports tensorboard, we should add that as well
|
||||
# (because it will not be detected automatically)
|
||||
if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
|
||||
# noinspection PyBroadException
|
||||
@@ -331,14 +336,14 @@ class _JupyterObserver(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
from nbconvert.exporters import PythonExporter
|
||||
from nbconvert.exporters import PythonExporter # noqa
|
||||
_script_exporter = PythonExporter()
|
||||
except Exception:
|
||||
_script_exporter = None
|
||||
|
||||
if _script_exporter is None:
|
||||
# noinspection PyPackageRequirements
|
||||
from nbconvert.exporters.script import ScriptExporter
|
||||
from nbconvert.exporters.script import ScriptExporter # noqa
|
||||
_script_exporter = ScriptExporter()
|
||||
|
||||
except Exception as ex:
|
||||
@@ -617,7 +622,7 @@ class ScriptInfo(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
from notebook.notebookapp import list_running_servers # <= Notebook v6
|
||||
from notebook.notebookapp import list_running_servers # noqa <= Notebook v6
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
jupyter_servers += list(list_running_servers())
|
||||
@@ -632,7 +637,7 @@ class ScriptInfo(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyPackageRequirements
|
||||
from jupyter_server.serverapp import list_running_servers
|
||||
from jupyter_server.serverapp import list_running_servers # noqa
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
jupyter_servers += list(list_running_servers())
|
||||
@@ -719,7 +724,7 @@ class ScriptInfo(object):
|
||||
is_google_colab = False
|
||||
log_history = False
|
||||
colab_name = None
|
||||
# check if this is google.colab, then there is no local file
|
||||
# check if this is `google.colab`, then there is no local file
|
||||
is_google_colab = ScriptInfo.is_google_colab()
|
||||
|
||||
if is_google_colab:
|
||||
@@ -748,7 +753,7 @@ class ScriptInfo(object):
|
||||
if not entry_point.exists():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
|
||||
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
|
||||
# now we should try to find the actual file
|
||||
entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
|
||||
if not entry_point_alternative.is_file():
|
||||
@@ -823,7 +828,7 @@ class ScriptInfo(object):
|
||||
# returns tuple (notebook name, raw string notebook)
|
||||
# None, None if fails
|
||||
try:
|
||||
from google.colab import _message
|
||||
from google.colab import _message # noqa
|
||||
|
||||
notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
|
||||
notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
|
||||
@@ -990,6 +995,10 @@ class ScriptInfo(object):
|
||||
working_dir = cls._get_working_dir(repo_root)
|
||||
entry_point = cls._get_entry_point(repo_root, script_path)
|
||||
|
||||
# check if we are running with torch distributed, or transformers accelerate
|
||||
# make sure we change the entry point to reflect it.
|
||||
entry_point = cls._detect_distributed_execution(entry_point, log)
|
||||
|
||||
if check_uncommitted:
|
||||
# if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
|
||||
if jupyter_filepath:
|
||||
@@ -1005,7 +1014,7 @@ class ScriptInfo(object):
|
||||
if len(diff) > cls.max_diff_size_bytes:
|
||||
messages.append(
|
||||
"======> WARNING! Git diff too large to store "
|
||||
"({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
|
||||
"({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
|
||||
auxiliary_git_diff = diff
|
||||
diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
|
||||
'# full git diff available in Artifacts/auxiliary_git_diff\n' \
|
||||
@@ -1060,8 +1069,54 @@ class ScriptInfo(object):
|
||||
return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
|
||||
script_requirements)
|
||||
|
||||
@classmethod
|
||||
def _detect_distributed_execution(cls, entry_point, log):
|
||||
# check if we are running with torch distributed, or transformers accelerate
|
||||
# make sure we change the entry point to reflect it.
|
||||
is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
|
||||
is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
|
||||
if not is_torch_distributed and not is_transformers_distributed:
|
||||
return entry_point
|
||||
|
||||
# this torch distributed
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from psutil import Process # noqa
|
||||
cmdline = Process().parent().cmdline()
|
||||
# first find the torch model call "torch.distributed.run" or "torch.distributed.launch"
|
||||
if is_torch_distributed:
|
||||
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
|
||||
elif is_transformers_distributed:
|
||||
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
|
||||
else:
|
||||
raise Exception() # we should not get here
|
||||
|
||||
cmdline = cmdline[cmdstart_i:]
|
||||
# reverse look into the paths
|
||||
cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
|
||||
filearg = cmdline[cmdend_i]
|
||||
# notice --args (script args) are passed on the Args section, we skip detecting them here
|
||||
# we are also already removing the filearg from the cmd (it is the last before script args)
|
||||
new_cmd = cmdline[:cmdend_i]
|
||||
|
||||
# we assume our entrypoint is the last parameter of the execution cmd line
|
||||
if Path(filearg).stem == Path(entry_point).stem:
|
||||
entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
|
||||
if log:
|
||||
log.info(
|
||||
"{} execution detected: adjusting entrypoint to "
|
||||
"reflect distributed execution arguments".format(
|
||||
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
|
||||
)
|
||||
except Exception:
|
||||
if log:
|
||||
log.warning("{} execution detected: Failed Detecting launch arguments, skipping".format(
|
||||
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
|
||||
|
||||
return entry_point
|
||||
|
||||
@staticmethod
|
||||
def __legacy_jupyter_notebook_server_json_parsing(self):
|
||||
def __legacy_jupyter_notebook_server_json_parsing():
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# on some jupyter notebook versions this function can crash on parsing the json file,
|
||||
|
||||
@@ -1464,52 +1464,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
execution.docker_cmd = image + (' {}'.format(arguments) if arguments else '')
|
||||
self._edit(execution=execution)
|
||||
|
||||
def set_packages(self, packages):
|
||||
# type: (Union[str, Sequence[str]]) -> ()
|
||||
"""
|
||||
Manually specify a list of required packages or a local requirements.txt file.
|
||||
|
||||
:param packages: The list of packages or the path to the requirements.txt file.
|
||||
Example: ["tqdm>=2.1", "scikit-learn"] or "./requirements.txt"
|
||||
"""
|
||||
if not packages:
|
||||
return
|
||||
if not isinstance(packages, str) or not os.path.exists(packages):
|
||||
# noinspection PyProtectedMember
|
||||
self._update_requirements(packages)
|
||||
return
|
||||
with open(packages) as f:
|
||||
# noinspection PyProtectedMember
|
||||
self._update_requirements([line.strip() for line in f.readlines()])
|
||||
|
||||
def set_repo(self, repo, branch=None, commit=None):
|
||||
# type: (str, Optional[str], Optional[str]) -> ()
|
||||
"""
|
||||
Specify a repository to attach to the function.
|
||||
Allow users to execute the task inside the specified repository, enabling them to load modules/script
|
||||
from the repository. Notice the execution work directory will be the repository root folder.
|
||||
Supports both git repo url link, and local repository path (automatically converted into the remote
|
||||
git/commit as is currently checkout).
|
||||
Example remote url: 'https://github.com/user/repo.git'.
|
||||
Example local repo copy: './repo' -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy.
|
||||
|
||||
:param repo: Remote URL for the repository to use, OR path to local copy of the git repository
|
||||
Example: 'https://github.com/allegroai/clearml.git' or '~/project/repo'
|
||||
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
"""
|
||||
if not repo:
|
||||
return
|
||||
with self._edit_lock:
|
||||
self.reload()
|
||||
self.data.script.repository = repo
|
||||
if branch:
|
||||
self.data.script.branch = branch
|
||||
if commit:
|
||||
self.data.script.version_num = commit
|
||||
self._edit(script=self.data.script)
|
||||
|
||||
def get_base_docker(self):
|
||||
# type: () -> str
|
||||
"""Get the base Docker command (image) that is set for this experiment."""
|
||||
@@ -2326,7 +2280,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
import pkg_resources
|
||||
except ImportError:
|
||||
get_logger("task").warning(
|
||||
"Requirement file %s skipped since pkg_resources is not installed" % package_name)
|
||||
"Requirement file `{}` skipped since pkg_resources is not installed".format(package_name))
|
||||
else:
|
||||
with Path(package_name).open() as requirements_txt:
|
||||
for req in pkg_resources.parse_requirements(requirements_txt):
|
||||
|
||||
@@ -175,10 +175,11 @@ class PatchOsFork(object):
|
||||
try:
|
||||
return PatchOsFork._original_process_run(self, *args, **kwargs)
|
||||
finally:
|
||||
if task:
|
||||
if task and patched_worker:
|
||||
try:
|
||||
if patched_worker:
|
||||
# remove at exit hooks, we will deadlock when the
|
||||
# noinspection PyProtectedMember
|
||||
if task._report_subprocess_enabled:
|
||||
# just in case, remove at exit hooks, we will deadlock when the
|
||||
# main Pool manager will terminate this process, and it will...
|
||||
# noinspection PyProtectedMember
|
||||
task._at_exit_called = True
|
||||
@@ -214,12 +215,30 @@ class PatchOsFork(object):
|
||||
if not task:
|
||||
return
|
||||
|
||||
if not Task._report_subprocess_enabled:
|
||||
# https://stackoverflow.com/a/34507557
|
||||
# NOTICE: subprocesses do not exit through exit we have to register signals
|
||||
if task._Task__exit_hook:
|
||||
task._Task__exit_hook.register_signal_and_exception_hooks()
|
||||
else:
|
||||
# noinspection PyProtectedMember
|
||||
task._remove_signal_hooks()
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
if Task._report_subprocess_enabled:
|
||||
# noinspection PyProtectedMember
|
||||
task._remove_exception_hooks()
|
||||
|
||||
PatchOsFork._current_task = task
|
||||
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
|
||||
# noinspection PyProtectedMember
|
||||
if not bool(task._report_subprocess_enabled):
|
||||
BackgroundMonitor.start_all(task=task)
|
||||
|
||||
# if we are reporting into a subprocess, no need to further patch the exit functions
|
||||
if Task._report_subprocess_enabled:
|
||||
return
|
||||
|
||||
# The signal handler method is Not enough, for the time being, we have both
|
||||
# even though it makes little sense
|
||||
# # if we got here patch the os._exit of our instance to call us
|
||||
@@ -244,6 +263,10 @@ class PatchOsFork(object):
|
||||
# noinspection PyProtectedMember, PyUnresolvedReferences
|
||||
os._org_exit = os._exit
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
# https://stackoverflow.com/a/34507557
|
||||
# NOTICE: subprocesses do not exit through exit, and in most cases not with _exit,
|
||||
# this means at_exit calls are Not registered respected
|
||||
os._exit = _at_exit_callback
|
||||
|
||||
@staticmethod
|
||||
@@ -261,3 +284,23 @@ class PatchOsFork(object):
|
||||
PatchOsFork._fork_callback_after_child()
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def unpatch_fork():
|
||||
try:
|
||||
if PatchOsFork._original_fork and os._exit != PatchOsFork._original_fork:
|
||||
os._exit = PatchOsFork._original_fork
|
||||
PatchOsFork._original_fork = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def unpatch_process_run():
|
||||
try:
|
||||
from multiprocessing.process import BaseProcess
|
||||
|
||||
if PatchOsFork._original_process_run and BaseProcess.run != PatchOsFork._original_process_run:
|
||||
BaseProcess.run = PatchOsFork._original_process_run
|
||||
PatchOsFork._original_process_run = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -131,7 +131,7 @@ class PatchHydra(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
override = PatchHydra._parse_override(override)
|
||||
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
|
||||
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
|
||||
return group_exists
|
||||
except Exception:
|
||||
if not PatchHydra._config_group_warning_sent:
|
||||
|
||||
@@ -31,6 +31,7 @@ class PatchJsonArgParse(object):
|
||||
_special_fields = ["config", "subcommand"]
|
||||
_section_name = "Args"
|
||||
_allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
|
||||
_ignore_ui_overrides = "_ignore_ui_overrides_"
|
||||
__remote_task_params = {}
|
||||
__remote_task_params_dict = {}
|
||||
__patched = False
|
||||
@@ -43,7 +44,7 @@ class PatchJsonArgParse(object):
|
||||
cls.patch(task)
|
||||
|
||||
@classmethod
|
||||
def patch(cls, task):
|
||||
def patch(cls, task=None):
|
||||
if ArgumentParser is None:
|
||||
return
|
||||
PatchJsonArgParse._update_task_args()
|
||||
@@ -72,7 +73,9 @@ class PatchJsonArgParse(object):
|
||||
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
|
||||
if isinstance(v, Namespace) or (
|
||||
isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)
|
||||
):
|
||||
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
|
||||
args_type[key_with_section] = PatchJsonArgParse.namespace_type
|
||||
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
||||
@@ -86,9 +89,9 @@ class PatchJsonArgParse(object):
|
||||
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
||||
if have_config_file:
|
||||
cls._current_task.set_parameter(
|
||||
cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
|
||||
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
|
||||
False,
|
||||
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
|
||||
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -110,8 +113,6 @@ class PatchJsonArgParse(object):
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||
if not PatchJsonArgParse._current_task:
|
||||
return original_fn(obj, *args, **kwargs)
|
||||
if len(args) == 1:
|
||||
kwargs["args"] = args[0]
|
||||
args = []
|
||||
@@ -124,13 +125,19 @@ class PatchJsonArgParse(object):
|
||||
params_namespace = Namespace()
|
||||
for k, v in params.items():
|
||||
params_namespace[k] = v
|
||||
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
|
||||
allow_jsonargparse_overrides_value = True
|
||||
if PatchJsonArgParse._allow_jsonargparse_overrides in params:
|
||||
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides)
|
||||
if PatchJsonArgParse._ignore_ui_overrides in params:
|
||||
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
|
||||
if not allow_jsonargparse_overrides_value:
|
||||
params_namespace = PatchJsonArgParse.__restore_args(
|
||||
obj,
|
||||
params_namespace,
|
||||
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||
)
|
||||
if PatchJsonArgParse._allow_jsonargparse_overrides in params_namespace:
|
||||
del params_namespace[PatchJsonArgParse._allow_jsonargparse_overrides]
|
||||
if PatchJsonArgParse._ignore_ui_overrides in params_namespace:
|
||||
del params_namespace[PatchJsonArgParse._ignore_ui_overrides]
|
||||
return params_namespace
|
||||
except Exception as e:
|
||||
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
|
||||
@@ -149,6 +156,7 @@ class PatchJsonArgParse(object):
|
||||
except ImportError:
|
||||
try:
|
||||
import pytorch_lightning
|
||||
|
||||
lightning = pytorch_lightning
|
||||
except ImportError:
|
||||
lightning = None
|
||||
@@ -178,20 +186,14 @@ class PatchJsonArgParse(object):
|
||||
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
|
||||
for key, section_param in cls.__remote_task_params[cls._section_name].items():
|
||||
if section_param.type == cls.namespace_type:
|
||||
params_dict[
|
||||
"{}/{}".format(cls._section_name, key)
|
||||
] = cls._get_namespace_from_json(section_param.value)
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value)
|
||||
elif section_param.type == cls.path_type:
|
||||
params_dict[
|
||||
"{}/{}".format(cls._section_name, key)
|
||||
] = cls._get_path_from_json(section_param.value)
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value)
|
||||
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
|
||||
params_dict["{}/{}".format(cls._section_name, key)] = None
|
||||
skip = len(cls._section_name) + 1
|
||||
cls.__remote_task_params_dict = {
|
||||
k[skip:]: v
|
||||
for k, v in params_dict.items()
|
||||
if k.startswith(cls._section_name + cls._args_sep)
|
||||
k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep)
|
||||
}
|
||||
cls.__update_remote_task_params_dict_based_on_paths(parser)
|
||||
|
||||
@@ -200,9 +202,7 @@ class PatchJsonArgParse(object):
|
||||
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
|
||||
for path in paths:
|
||||
args = PatchJsonArgParse.__get_args_from_path(
|
||||
parser,
|
||||
path,
|
||||
subcommand=cls.__remote_task_params_dict.get("subcommand")
|
||||
parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand")
|
||||
)
|
||||
for subarg_key, subarg_value in args.items():
|
||||
if subarg_key not in cls.__remote_task_params_dict:
|
||||
@@ -222,7 +222,12 @@ class PatchJsonArgParse(object):
|
||||
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
||||
if subcommand:
|
||||
parsed_cfg = {
|
||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
|
||||
(
|
||||
(subcommand + PatchJsonArgParse._commands_sep)
|
||||
if k not in PatchJsonArgParse._special_fields
|
||||
else ""
|
||||
)
|
||||
+ k: v
|
||||
for k, v in parsed_cfg.items()
|
||||
}
|
||||
return parsed_cfg
|
||||
@@ -252,3 +257,7 @@ class PatchJsonArgParse(object):
|
||||
if isinstance(json_, list):
|
||||
return [Path(**dict_) for dict_ in json_]
|
||||
return Path(**json_)
|
||||
|
||||
|
||||
# patch jsonargparse before anything else
|
||||
PatchJsonArgParse.patch()
|
||||
|
||||
@@ -186,7 +186,15 @@ def get_node_id(default=0):
|
||||
if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
|
||||
node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
|
||||
|
||||
# if node is is till None, use the default
|
||||
# if node is still None, use the global RANK
|
||||
if node_id is None:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
node_id = int(os.environ.get("RANK"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# if node is still None, use the default
|
||||
if node_id is None:
|
||||
node_id = default
|
||||
|
||||
|
||||
@@ -355,6 +355,20 @@ def stdout_print(*args, **kwargs):
|
||||
sys.stdout.write(line)
|
||||
|
||||
|
||||
def debug_print(*args, **kwargs):
|
||||
"""
|
||||
Print directly to stdout, with process and timestamp from last print call
|
||||
Example: [pid=123, t=0.003] message here
|
||||
"""
|
||||
global tic
|
||||
tic = globals().get('tic', time.time())
|
||||
stdout_print(
|
||||
"\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time()-tic)
|
||||
+ str(args[0] if len(args) == 1 else ("" if not args else args)) + "\033[0m", **kwargs
|
||||
)
|
||||
tic = time.time()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# from clearml import Task
|
||||
# task = Task.init(project_name="examples", task_name="trace test")
|
||||
|
||||
@@ -26,7 +26,7 @@ from .storage.helper import StorageHelper
|
||||
from .utilities.plotly_reporter import SeriesInfo
|
||||
|
||||
# Make sure that DeprecationWarning within this package always gets printed
|
||||
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
|
||||
warnings.filterwarnings("always", category=DeprecationWarning, module=__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -56,9 +56,10 @@ class Logger(object):
|
||||
|
||||
|
||||
"""
|
||||
|
||||
SeriesInfo = SeriesInfo
|
||||
_tensorboard_logging_auto_group_scalars = False
|
||||
_tensorboard_single_series_per_graph = deferred_config('metrics.tensorboard_single_series_per_graph', False)
|
||||
_tensorboard_single_series_per_graph = deferred_config("metrics.tensorboard_single_series_per_graph", False)
|
||||
|
||||
def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
|
||||
"""
|
||||
@@ -66,8 +67,9 @@ class Logger(object):
|
||||
**Do not construct Logger manually!**
|
||||
Please use :meth:`Logger.get_current`
|
||||
"""
|
||||
assert isinstance(private_task, _Task), \
|
||||
'Logger object cannot be instantiated externally, use Logger.current_logger()'
|
||||
assert isinstance(
|
||||
private_task, _Task
|
||||
), "Logger object cannot be instantiated externally, use Logger.current_logger()"
|
||||
super(Logger, self).__init__()
|
||||
self._task = private_task
|
||||
self._default_upload_destination = None
|
||||
@@ -75,16 +77,19 @@ class Logger(object):
|
||||
self._report_worker = None
|
||||
self._graph_titles = {}
|
||||
self._tensorboard_series_force_prefix = None
|
||||
self._task_handler = TaskHandler(task=self._task, capacity=100) \
|
||||
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging) else None
|
||||
self._task_handler = (
|
||||
TaskHandler(task=self._task, capacity=100)
|
||||
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging)
|
||||
else None
|
||||
)
|
||||
self._connect_std_streams = connect_stdout or connect_stderr
|
||||
self._connect_logging = connect_logging
|
||||
self._default_max_sample_history = None
|
||||
|
||||
# Make sure urllib is never in debug/info,
|
||||
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
|
||||
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
|
||||
logging.getLogger('urllib3').setLevel(logging.WARNING)
|
||||
disable_urllib3_info = config.get("log.disable_urllib3_info", True)
|
||||
if disable_urllib3_info and logging.getLogger("urllib3").isEnabledFor(logging.INFO):
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
|
||||
if self._task.is_main_task():
|
||||
StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
|
||||
@@ -112,6 +117,7 @@ class Logger(object):
|
||||
:return: The Logger object (a singleton) for the current running Task.
|
||||
"""
|
||||
from .task import Task
|
||||
|
||||
task = Task.current_task()
|
||||
if not task:
|
||||
return None
|
||||
@@ -181,20 +187,20 @@ class Logger(object):
|
||||
:param name: Metric's name
|
||||
:param value: Metric's value
|
||||
"""
|
||||
return self.report_scalar(title="Summary", series=name, value=value, iteration=-2**31)
|
||||
return self.report_scalar(title="Summary", series=name, value=value, iteration=-(2**31))
|
||||
|
||||
def report_vector(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
values, # type: Sequence[Union[int, float]]
|
||||
iteration=None, # type: Optional[int]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
mode=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
values, # type: Sequence[Union[int, float]]
|
||||
iteration=None, # type: Optional[int]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
mode=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot a vector as (default stacked) histogram.
|
||||
@@ -229,27 +235,36 @@ class Logger(object):
|
||||
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
|
||||
"""
|
||||
warnings.warn(
|
||||
":meth:`Logger.report_vector` is deprecated;"
|
||||
"use :meth:`Logger.report_histogram` instead.",
|
||||
DeprecationWarning
|
||||
":meth:`Logger.report_vector` is deprecated; use :meth:`Logger.report_histogram` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self._touch_title_series(title, series)
|
||||
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
|
||||
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
|
||||
return self.report_histogram(
|
||||
title,
|
||||
series,
|
||||
values,
|
||||
iteration or 0,
|
||||
labels=labels,
|
||||
xlabels=xlabels,
|
||||
xaxis=xaxis,
|
||||
yaxis=yaxis,
|
||||
mode=mode,
|
||||
extra_layout=extra_layout
|
||||
)
|
||||
|
||||
def report_histogram(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
values, # type: Sequence[Union[int, float]]
|
||||
iteration=None, # type: Optional[int]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
mode=None, # type: Optional[str]
|
||||
data_args=None, # type: Optional[dict]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
values, # type: Sequence[Union[int, float]]
|
||||
iteration=None, # type: Optional[int]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
mode=None, # type: Optional[str]
|
||||
data_args=None, # type: Optional[dict]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot a (default grouped) histogram.
|
||||
@@ -301,21 +316,21 @@ class Logger(object):
|
||||
xlabels=xlabels,
|
||||
xtitle=xaxis,
|
||||
ytitle=yaxis,
|
||||
mode=mode or 'group',
|
||||
mode=mode or "group",
|
||||
data_args=data_args,
|
||||
layout_config=extra_layout,
|
||||
)
|
||||
|
||||
def report_table(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
|
||||
csv=None, # type: Optional[str]
|
||||
url=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
extra_data=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
|
||||
csv=None, # type: Optional[str]
|
||||
url=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
extra_data=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, report a table plot.
|
||||
@@ -374,10 +389,7 @@ class Logger(object):
|
||||
)
|
||||
|
||||
"""
|
||||
mutually_exclusive(
|
||||
UsageError, _check_none=True,
|
||||
table_plot=table_plot, csv=csv, url=url
|
||||
)
|
||||
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
|
||||
table = table_plot
|
||||
if url or csv:
|
||||
if not pd:
|
||||
@@ -412,16 +424,16 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_line_plot(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: Sequence[SeriesInfo]
|
||||
xaxis, # type: str
|
||||
yaxis, # type: str
|
||||
mode='lines', # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
reverse_xaxis=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: Sequence[SeriesInfo]
|
||||
xaxis, # type: str
|
||||
yaxis, # type: str
|
||||
mode="lines", # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
reverse_xaxis=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot one or more series as lines.
|
||||
@@ -464,7 +476,7 @@ class Logger(object):
|
||||
|
||||
# if task was not started, we have to start it
|
||||
self._start_task_if_needed()
|
||||
self._touch_title_series(title, series[0].name if series else '')
|
||||
self._touch_title_series(title, series[0].name if series else "")
|
||||
# noinspection PyProtectedMember
|
||||
return self._task._reporter.report_line_plot(
|
||||
title=title,
|
||||
@@ -479,17 +491,17 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_scatter2d(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
scatter, # type: Union[Sequence[Tuple[float, float]], np.ndarray]
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
mode='lines', # type: str
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
scatter, # type: Union[Sequence[Tuple[float, float]], np.ndarray]
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
mode="lines", # type: str
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, report a 2d scatter plot.
|
||||
@@ -558,19 +570,19 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_scatter3d(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
scatter, # type: Union[Sequence[Tuple[float, float, float]], np.ndarray]
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
zaxis=None, # type: Optional[str]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
mode='markers', # type: str
|
||||
fill=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
scatter, # type: Union[Sequence[Tuple[float, float, float]], np.ndarray]
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
zaxis=None, # type: Optional[str]
|
||||
labels=None, # type: Optional[List[str]]
|
||||
mode="markers", # type: str
|
||||
fill=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot a 3d scatter graph (with markers).
|
||||
@@ -605,16 +617,9 @@ class Logger(object):
|
||||
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
|
||||
"""
|
||||
# check if multiple series
|
||||
multi_series = (
|
||||
isinstance(scatter, list)
|
||||
and (
|
||||
isinstance(scatter[0], np.ndarray)
|
||||
or (
|
||||
scatter[0]
|
||||
and isinstance(scatter[0], list)
|
||||
and isinstance(scatter[0][0], list)
|
||||
)
|
||||
)
|
||||
multi_series = isinstance(scatter, list) and (
|
||||
isinstance(scatter[0], np.ndarray)
|
||||
or (scatter[0] and isinstance(scatter[0], list) and isinstance(scatter[0][0], list))
|
||||
)
|
||||
|
||||
if not multi_series:
|
||||
@@ -647,18 +652,18 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_confusion_matrix(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
yaxis_reversed=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
yaxis_reversed=False, # type: bool
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot a heat-map matrix.
|
||||
@@ -690,7 +695,7 @@ class Logger(object):
|
||||
matrix = np.array(matrix)
|
||||
|
||||
if extra_layout is None:
|
||||
extra_layout = {'texttemplate': '%{z}'}
|
||||
extra_layout = {"texttemplate": "%{z}"}
|
||||
|
||||
# if task was not started, we have to start it
|
||||
self._start_task_if_needed()
|
||||
@@ -711,17 +716,17 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_matrix(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
yaxis_reversed=False, # type: bool
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
yaxis_reversed=False, # type: bool
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, plot a confusion matrix.
|
||||
@@ -744,30 +749,37 @@ class Logger(object):
|
||||
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
|
||||
"""
|
||||
warnings.warn(
|
||||
":meth:`Logger.report_matrix` is deprecated;"
|
||||
"use :meth:`Logger.report_confusion_matrix` instead.",
|
||||
":meth:`Logger.report_matrix` is deprecated;" "use :meth:`Logger.report_confusion_matrix` instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
self._touch_title_series(title, series)
|
||||
return self.report_confusion_matrix(title, series, matrix, iteration or 0,
|
||||
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,
|
||||
yaxis_reversed=yaxis_reversed,
|
||||
extra_layout=extra_layout)
|
||||
return self.report_confusion_matrix(
|
||||
title,
|
||||
series,
|
||||
matrix,
|
||||
iteration or 0,
|
||||
xaxis=xaxis,
|
||||
yaxis=yaxis,
|
||||
xlabels=xlabels,
|
||||
ylabels=ylabels,
|
||||
yaxis_reversed=yaxis_reversed,
|
||||
extra_layout=extra_layout
|
||||
)
|
||||
|
||||
def report_surface(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
zaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
camera=None, # type: Optional[Sequence[float]]
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
matrix, # type: np.ndarray
|
||||
iteration=None, # type: Optional[int]
|
||||
xaxis=None, # type: Optional[str]
|
||||
yaxis=None, # type: Optional[str]
|
||||
zaxis=None, # type: Optional[str]
|
||||
xlabels=None, # type: Optional[List[str]]
|
||||
ylabels=None, # type: Optional[List[str]]
|
||||
camera=None, # type: Optional[Sequence[float]]
|
||||
comment=None, # type: Optional[str]
|
||||
extra_layout=None, # type: Optional[dict]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, report a 3d surface plot.
|
||||
@@ -821,16 +833,16 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_image(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
local_path=None, # type: Optional[str]
|
||||
image=None, # type: Optional[Union[np.ndarray, Image.Image]]
|
||||
matrix=None, # type: Optional[np.ndarray]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
url=None # type: Optional[str]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
local_path=None, # type: Optional[str]
|
||||
image=None, # type: Optional[Union[np.ndarray, Image.Image]]
|
||||
matrix=None, # type: Optional[np.ndarray]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
url=None, # type: Optional[str]
|
||||
):
|
||||
"""
|
||||
For explicit reporting, report an image and upload its contents.
|
||||
@@ -875,8 +887,7 @@ class Logger(object):
|
||||
- ``False`` - Do not delete after upload. (default)
|
||||
"""
|
||||
mutually_exclusive(
|
||||
UsageError, _check_none=True,
|
||||
local_path=local_path or None, url=url or None, image=image, matrix=matrix
|
||||
UsageError, _check_none=True, local_path=local_path or None, url=url or None, image=image, matrix=matrix
|
||||
)
|
||||
if matrix is not None:
|
||||
warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning)
|
||||
@@ -902,7 +913,7 @@ class Logger(object):
|
||||
else:
|
||||
upload_uri = self.get_default_upload_destination()
|
||||
if not upload_uri:
|
||||
upload_uri = Path(get_cache_dir()) / 'debug_images'
|
||||
upload_uri = Path(get_cache_dir()) / "debug_images"
|
||||
upload_uri.mkdir(parents=True, exist_ok=True)
|
||||
# Verify that we can upload to this destination
|
||||
upload_uri = str(upload_uri)
|
||||
@@ -925,16 +936,16 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_media(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
local_path=None, # type: Optional[str]
|
||||
stream=None, # type: Optional[Union[six.BytesIO, six.StringIO]]
|
||||
file_extension=None, # type: Optional[str]
|
||||
max_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
url=None # type: Optional[str]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
local_path=None, # type: Optional[str]
|
||||
stream=None, # type: Optional[Union[six.BytesIO, six.StringIO]]
|
||||
file_extension=None, # type: Optional[str]
|
||||
max_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
url=None, # type: Optional[str]
|
||||
):
|
||||
"""
|
||||
Report media upload its contents, including images, audio, and video.
|
||||
@@ -966,8 +977,11 @@ class Logger(object):
|
||||
|
||||
"""
|
||||
mutually_exclusive(
|
||||
UsageError, _check_none=True,
|
||||
local_path=local_path or None, url=url or None, stream=stream,
|
||||
UsageError,
|
||||
_check_none=True,
|
||||
local_path=local_path or None,
|
||||
url=url or None,
|
||||
stream=stream
|
||||
)
|
||||
if stream is not None and not file_extension:
|
||||
raise ValueError("No file extension provided for stream media upload")
|
||||
@@ -989,7 +1003,7 @@ class Logger(object):
|
||||
else:
|
||||
upload_uri = self.get_default_upload_destination()
|
||||
if not upload_uri:
|
||||
upload_uri = Path(get_cache_dir()) / 'debug_images'
|
||||
upload_uri = Path(get_cache_dir()) / "debug_images"
|
||||
upload_uri.mkdir(parents=True, exist_ok=True)
|
||||
# Verify that we can upload to this destination
|
||||
upload_uri = str(upload_uri)
|
||||
@@ -1009,11 +1023,11 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_plotly(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
figure, # type: Union[Dict, "Figure"] # noqa: F821
|
||||
iteration=None, # type: Optional[int]
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
figure, # type: Union[Dict, "Figure"] # noqa: F821
|
||||
iteration=None, # type: Optional[int]
|
||||
):
|
||||
"""
|
||||
Report a ``Plotly`` figure (plot) directly
|
||||
@@ -1033,7 +1047,7 @@ class Logger(object):
|
||||
plot = figure if isinstance(figure, dict) else figure.to_plotly_json()
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
plot['layout']['title'] = series
|
||||
plot["layout"]["title"] = series
|
||||
except Exception:
|
||||
pass
|
||||
# noinspection PyProtectedMember
|
||||
@@ -1045,13 +1059,13 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def report_matplotlib_figure(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
figure, # type: Union[MatplotlibFigure, pyplot]
|
||||
iteration=None, # type: Optional[int]
|
||||
report_image=False, # type: bool
|
||||
report_interactive=True, # type: bool
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
figure, # type: Union[MatplotlibFigure, pyplot]
|
||||
iteration=None, # type: Optional[int]
|
||||
report_image=False, # type: bool
|
||||
report_interactive=True, # type: bool
|
||||
):
|
||||
"""
|
||||
Report a ``matplotlib`` figure / plot directly
|
||||
@@ -1078,8 +1092,7 @@ class Logger(object):
|
||||
figure=figure,
|
||||
iter=iteration or 0,
|
||||
logger=self,
|
||||
force_save_as_image=False if report_interactive and not report_image
|
||||
else ('png' if report_image else True),
|
||||
force_save_as_image=False if report_interactive and not report_image else ("png" if report_image else True),
|
||||
)
|
||||
|
||||
def set_default_upload_destination(self, uri):
|
||||
@@ -1201,21 +1214,28 @@ class Logger(object):
|
||||
return int(UploadEvent._file_history_size)
|
||||
|
||||
def report_image_and_upload(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
matrix=None, # type: Optional[Union[np.ndarray, Image.Image]]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False # type: bool
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
matrix=None, # type: Optional[Union[np.ndarray, Image.Image]]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
):
|
||||
"""
|
||||
.. deprecated:: 0.13.0
|
||||
Use :meth:`Logger.report_image` instead
|
||||
"""
|
||||
self.report_image(title=title, series=series, iteration=iteration or 0, local_path=path, image=matrix,
|
||||
max_image_history=max_image_history, delete_after_upload=delete_after_upload)
|
||||
self.report_image(
|
||||
title=title,
|
||||
series=series,
|
||||
iteration=iteration or 0,
|
||||
local_path=path,
|
||||
image=matrix,
|
||||
max_image_history=max_image_history,
|
||||
delete_after_upload=delete_after_upload
|
||||
)
|
||||
|
||||
def capture_logging(self):
|
||||
# type: () -> "_LoggingContext"
|
||||
@@ -1224,6 +1244,7 @@ class Logger(object):
|
||||
|
||||
:return: a ContextManager
|
||||
"""
|
||||
|
||||
class _LoggingContext(object):
|
||||
def __init__(self, a_logger):
|
||||
self.logger = a_logger
|
||||
@@ -1285,6 +1306,7 @@ class Logger(object):
|
||||
:param force: If True, all matplotlib figures are converted automatically to non-interactive plots.
|
||||
"""
|
||||
from clearml.backend_interface.metrics import Reporter
|
||||
|
||||
Reporter.matplotlib_force_report_non_interactive(force=force)
|
||||
|
||||
@classmethod
|
||||
@@ -1327,8 +1349,7 @@ class Logger(object):
|
||||
try:
|
||||
return int(level)
|
||||
except (TypeError, ValueError):
|
||||
self._task.log.log(level=logging.ERROR,
|
||||
msg='Logger failed casting log level "%s" to integer' % str(level))
|
||||
self._task.log.log(level=logging.ERROR, msg='Logger failed casting log level "%s" to integer' % str(level))
|
||||
return logging.INFO
|
||||
|
||||
def _console(self, msg, level=logging.INFO, omit_console=False, force_send=False, *args, **_):
|
||||
@@ -1356,7 +1377,7 @@ class Logger(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
record = self._task.log.makeRecord(
|
||||
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
|
||||
"console", level=level, fn="", lno=0, func="", msg=msg, args=args, exc_info=None
|
||||
)
|
||||
# find the task handler that matches our task
|
||||
self._task_handler.emit(record)
|
||||
@@ -1366,7 +1387,8 @@ class Logger(object):
|
||||
try:
|
||||
# make sure we are writing to the original stdout
|
||||
StdStreamPatch.stderr_original_write(
|
||||
'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg))
|
||||
'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
@@ -1379,7 +1401,7 @@ class Logger(object):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# make sure we are writing to the original stdout
|
||||
StdStreamPatch.stdout_original_write(str(msg) + '\n')
|
||||
StdStreamPatch.stdout_original_write(str(msg) + "\n")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
@@ -1389,14 +1411,14 @@ class Logger(object):
|
||||
self._start_task_if_needed()
|
||||
|
||||
def _report_image_plot_and_upload(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
matrix=None, # type: Optional[np.ndarray]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False # type: bool
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
matrix=None, # type: Optional[np.ndarray]
|
||||
max_image_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
):
|
||||
"""
|
||||
Report an image, upload its contents, and present in plots section using plotly
|
||||
@@ -1418,7 +1440,7 @@ class Logger(object):
|
||||
self._start_task_if_needed()
|
||||
upload_uri = self.get_default_upload_destination()
|
||||
if not upload_uri:
|
||||
upload_uri = Path(get_cache_dir()) / 'debug_images'
|
||||
upload_uri = Path(get_cache_dir()) / "debug_images"
|
||||
upload_uri.mkdir(parents=True, exist_ok=True)
|
||||
# Verify that we can upload to this destination
|
||||
upload_uri = str(upload_uri)
|
||||
@@ -1438,13 +1460,13 @@ class Logger(object):
|
||||
)
|
||||
|
||||
def _report_file_and_upload(
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
max_file_history=None, # type: Optional[int]
|
||||
delete_after_upload=False # type: bool
|
||||
self,
|
||||
title, # type: str
|
||||
series, # type: str
|
||||
iteration=None, # type: Optional[int]
|
||||
path=None, # type: Optional[str]
|
||||
max_file_history=None, # type: Optional[int]
|
||||
delete_after_upload=False, # type: bool
|
||||
):
|
||||
"""
|
||||
Upload a file and report it as link in the debug images section.
|
||||
@@ -1465,7 +1487,7 @@ class Logger(object):
|
||||
self._start_task_if_needed()
|
||||
upload_uri = self.get_default_upload_destination()
|
||||
if not upload_uri:
|
||||
upload_uri = Path(get_cache_dir()) / 'debug_images'
|
||||
upload_uri = Path(get_cache_dir()) / "debug_images"
|
||||
upload_uri.mkdir(parents=True, exist_ok=True)
|
||||
# Verify that we can upload to this destination
|
||||
upload_uri = str(upload_uri)
|
||||
|
||||
@@ -2347,6 +2347,7 @@ class OutputModel(BaseModel):
|
||||
iteration=None, # type: Optional[int]
|
||||
update_comment=True, # type: bool
|
||||
is_package=False, # type: bool
|
||||
async_enable=True, # type: bool
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
@@ -2374,6 +2375,8 @@ class OutputModel(BaseModel):
|
||||
- ``True`` - Update model comment (Default)
|
||||
- ``False`` - Do not update
|
||||
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
|
||||
:param bool async_enable: Whether to upload model in background or to block.
|
||||
Will raise an error in the main thread if the weights failed to be uploaded or not.
|
||||
|
||||
:return: The uploaded URI.
|
||||
"""
|
||||
@@ -2421,6 +2424,7 @@ class OutputModel(BaseModel):
|
||||
target_filename=target_filename or Path(weights_filename).name,
|
||||
auto_delete_file=auto_delete_file,
|
||||
iteration=iteration,
|
||||
async_enable=async_enable
|
||||
)
|
||||
|
||||
# make sure we delete the previous file, if it exists
|
||||
@@ -2502,7 +2506,7 @@ class OutputModel(BaseModel):
|
||||
output_uri = model.update_and_upload(
|
||||
model_file=weights_filename,
|
||||
task_id=self._task.id,
|
||||
async_enable=True,
|
||||
async_enable=async_enable,
|
||||
target_filename=target_filename,
|
||||
framework=self.framework or framework,
|
||||
comment=comment,
|
||||
@@ -2535,6 +2539,7 @@ class OutputModel(BaseModel):
|
||||
target_filename=None, # type: Optional[str]
|
||||
auto_delete_file=True, # type: bool
|
||||
iteration=None, # type: Optional[int]
|
||||
async_enable=True, # type: bool
|
||||
):
|
||||
# type: (...) -> str
|
||||
"""
|
||||
@@ -2559,6 +2564,8 @@ class OutputModel(BaseModel):
|
||||
- ``False`` - Do not delete
|
||||
|
||||
:param int iteration: The iteration number.
|
||||
:param bool async_enable: Whether to upload model in background or to block.
|
||||
Will raise an error in the main thread if the weights failed to be uploaded or not.
|
||||
|
||||
:return: The uploaded URI for the weights package.
|
||||
"""
|
||||
@@ -2626,6 +2633,7 @@ class OutputModel(BaseModel):
|
||||
target_filename=target_filename or "model_package.zip",
|
||||
iteration=iteration,
|
||||
update_comment=False,
|
||||
async_enable=async_enable
|
||||
)
|
||||
# set the model tag (by now we should have a model object) so we know we have packaged file
|
||||
self._set_package_tag()
|
||||
|
||||
@@ -4,12 +4,19 @@ from time import time
|
||||
from typing import Optional, AnyStr, IO
|
||||
from ..config import config
|
||||
|
||||
try:
|
||||
from tqdm import tqdm # noqa
|
||||
except ImportError:
|
||||
tqdm = None
|
||||
|
||||
|
||||
class ProgressReport(object):
|
||||
report_upload_chunk_size_mb = None
|
||||
report_download_chunk_size_mb = None
|
||||
|
||||
def __init__(self, verbose, total_size, log, report_chunk_size_mb):
|
||||
def __init__(self, verbose, total_size, log, report_chunk_size_mb,
|
||||
description_prefix=None, description_suffix=None,
|
||||
max_time_between_reports_sec=10.0, report_start=None):
|
||||
self.current_status_mb = 0.
|
||||
self.last_reported = 0.
|
||||
self._tic = time()
|
||||
@@ -18,45 +25,117 @@ class ProgressReport(object):
|
||||
self._log = log
|
||||
self._log_flag = False
|
||||
self._total_size = total_size
|
||||
self._description_prefix = description_prefix
|
||||
self._description_suffix = description_suffix
|
||||
self._max_time_between_reports_sec = max_time_between_reports_sec
|
||||
self._report_start = report_start if report_start is not None else bool(tqdm is not None)
|
||||
self._tqdm = None
|
||||
self._tqdm_init = False
|
||||
|
||||
def close(self, report_completed=False, report_summary=False, report_prefix=None, report_suffix=None):
|
||||
# call this one when we are done
|
||||
if self._tqdm is not None:
|
||||
# if we created a self._tqdm object we need to close it
|
||||
if report_completed:
|
||||
self._tqdm.update(
|
||||
self._tqdm.total - min(self._tqdm.total, self.last_reported)
|
||||
)
|
||||
self._tqdm.close()
|
||||
self._tqdm = None
|
||||
|
||||
if report_summary:
|
||||
self._log.info(
|
||||
'{} {:.2f} MB successfully {}'.format(
|
||||
report_prefix or self._description_prefix, self._total_size,
|
||||
report_suffix or self._description_suffix).strip()
|
||||
)
|
||||
|
||||
def _get_tqdm(self):
|
||||
if self._tqdm_init:
|
||||
return self._tqdm
|
||||
|
||||
self._tqdm_init = True
|
||||
|
||||
# create the tqdm progress bar
|
||||
if tqdm:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._tqdm = tqdm(
|
||||
total=round(float(self._total_size), 2),
|
||||
# desc="{} {}".format(description_prefix, description_suffix).strip(),
|
||||
unit="MB",
|
||||
unit_scale=False,
|
||||
ncols=80,
|
||||
bar_format="{bar} {percentage:3.0f}% | {n_fmt}/{total_fmt} MB "
|
||||
"[{elapsed}<{remaining}, {rate_fmt}{postfix}]: {desc}",
|
||||
)
|
||||
except Exception:
|
||||
# failed initializing TQDM (maybe interface changed?)
|
||||
self._tqdm = None
|
||||
|
||||
return self._tqdm
|
||||
|
||||
def __call__(self, chunk_size, *_, **__):
|
||||
chunk_size /= 1024. * 1024.
|
||||
self.current_status_mb += chunk_size
|
||||
last_part = self.current_status_mb - self.last_reported
|
||||
|
||||
if self._verbose or (last_part >= self._report_chunk_size):
|
||||
if (self._verbose or (last_part >= self._report_chunk_size) or
|
||||
(self.last_reported and self.current_status_mb >= self._total_size-0.01) or
|
||||
(time()-self._tic > self._max_time_between_reports_sec)):
|
||||
time_diff = time() - self._tic
|
||||
self.speed = (last_part / time_diff) if time_diff != 0 else 0
|
||||
self._report(self._total_size, self.current_status_mb, self.speed)
|
||||
self._tic = time()
|
||||
self.last_reported = self.current_status_mb
|
||||
self._report(self._total_size, self.current_status_mb, self.speed)
|
||||
|
||||
def _report(self, total_mb, current_mb, speed_mbps):
|
||||
# type: (float, float, float) -> None
|
||||
pass
|
||||
if self._report_start and self.last_reported <= 0:
|
||||
# first time - print before initializing the tqdm bar
|
||||
self._log.info(
|
||||
"{}: {:.2f}MB {}".format(
|
||||
self._description_prefix, total_mb, self._description_suffix).strip(" :")
|
||||
)
|
||||
|
||||
# initialize or reuse the bar
|
||||
_tqdm = self._get_tqdm()
|
||||
if _tqdm:
|
||||
# make sure we do not spill over due to rounding
|
||||
if round(float(current_mb), 2) >= _tqdm.total:
|
||||
_tqdm.update(_tqdm.total - self.last_reported)
|
||||
else:
|
||||
_tqdm.update(current_mb - self.last_reported)
|
||||
else:
|
||||
self._log.info(
|
||||
"{}: {:.2f}MB / {:.2f}MB @ {:.2f}MBs {}".format(
|
||||
self._description_prefix,
|
||||
current_mb,
|
||||
total_mb,
|
||||
speed_mbps,
|
||||
self._description_suffix
|
||||
).strip(" :")
|
||||
)
|
||||
|
||||
|
||||
class UploadProgressReport(ProgressReport):
|
||||
def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=None):
|
||||
def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=None, report_start=None):
|
||||
report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
|
||||
else ProgressReport.report_upload_chunk_size_mb or \
|
||||
int(config.get("storage.log.report_upload_chunk_size_mb", 5))
|
||||
super(UploadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
|
||||
self._filename = filename
|
||||
|
||||
def _report(self, total_mb, current_mb, speed_mbps):
|
||||
# type: (float, float, float) -> None
|
||||
self._log.info(
|
||||
'Uploading: %.2fMB / %.2fMB @ %.2fMBs from %s' %
|
||||
(current_mb, total_mb, speed_mbps, self._filename)
|
||||
super(UploadProgressReport, self).__init__(
|
||||
verbose, total_size, log, report_chunk_size_mb,
|
||||
description_prefix="Uploading", description_suffix="to {}".format(filename),
|
||||
report_start=report_start,
|
||||
)
|
||||
self._filename = filename
|
||||
|
||||
@classmethod
|
||||
def from_stream(cls, stream, filename, verbose, log):
|
||||
# type: (IO[AnyStr], str, bool, logging.Logger) -> Optional[UploadProgressReport]
|
||||
if hasattr(stream, 'seek'):
|
||||
total_size = cls._get_stream_length(stream)
|
||||
return UploadProgressReport(filename, verbose, total_size, log)
|
||||
total_size_mb = cls._get_stream_length(stream) // (1024 * 1024)
|
||||
return UploadProgressReport(filename, verbose, total_size_mb, log)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, filename, verbose, log):
|
||||
@@ -78,14 +157,13 @@ class UploadProgressReport(ProgressReport):
|
||||
|
||||
|
||||
class DownloadProgressReport(ProgressReport):
|
||||
def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=None):
|
||||
def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=None, report_start=None):
|
||||
report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
|
||||
else ProgressReport.report_download_chunk_size_mb or \
|
||||
int(config.get("storage.log.report_download_chunk_size_mb", 5))
|
||||
super(DownloadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
|
||||
super(DownloadProgressReport, self).__init__(
|
||||
verbose, total_size, log, report_chunk_size_mb,
|
||||
description_prefix="Downloading", description_suffix="from {}".format(remote_path),
|
||||
report_start=report_start,
|
||||
)
|
||||
self._remote_path = remote_path
|
||||
|
||||
def _report(self, total_mb, current_mb, speed_mbps):
|
||||
# type: (float, float, float) -> None
|
||||
self._log.info('Downloading: %.2fMB / %.2fMB @ %.2fMBs from %s' %
|
||||
(current_mb, total_mb, speed_mbps, self._remote_path))
|
||||
|
||||
@@ -615,7 +615,11 @@ class _Boto3Driver(_Driver):
|
||||
def async_download(a_obj, a_stream, cb, cfg):
|
||||
try:
|
||||
a_obj.download_fileobj(a_stream, Callback=cb, Config=cfg)
|
||||
if cb:
|
||||
cb.close(report_completed=True)
|
||||
except Exception as ex:
|
||||
if cb:
|
||||
cb.close()
|
||||
(log or self.get_logger()).error('Failed downloading: %s' % ex)
|
||||
a_stream.close()
|
||||
|
||||
@@ -780,8 +784,8 @@ class _GoogleCloudStorageDriver(_Driver):
|
||||
class _Container(object):
|
||||
def __init__(self, name, cfg):
|
||||
try:
|
||||
from google.cloud import storage
|
||||
from google.oauth2 import service_account
|
||||
from google.cloud import storage # noqa
|
||||
from google.oauth2 import service_account # noqa
|
||||
except ImportError:
|
||||
raise UsageError(
|
||||
'Google cloud driver not found. '
|
||||
@@ -862,7 +866,7 @@ class _GoogleCloudStorageDriver(_Driver):
|
||||
object.delete()
|
||||
except Exception as ex:
|
||||
try:
|
||||
from google.cloud.exceptions import NotFound
|
||||
from google.cloud.exceptions import NotFound # noqa
|
||||
if isinstance(ex, NotFound):
|
||||
return False
|
||||
except ImportError:
|
||||
@@ -949,7 +953,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
except ImportError:
|
||||
try:
|
||||
from azure.storage.blob import BlockBlobService # noqa
|
||||
from azure.common import AzureHttpError # noqa: F401
|
||||
from azure.common import AzureHttpError # noqa
|
||||
|
||||
self.__legacy = True
|
||||
except ImportError:
|
||||
@@ -1193,6 +1197,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
|
||||
obj.blob_name,
|
||||
progress_callback=cb,
|
||||
)
|
||||
cb.close()
|
||||
if container.is_legacy():
|
||||
return blob.content
|
||||
else:
|
||||
@@ -1663,7 +1668,7 @@ class _FileStorageDriver(_Driver):
|
||||
|
||||
try:
|
||||
os.unlink(path)
|
||||
except Exception:
|
||||
except Exception: # noqa
|
||||
return False
|
||||
|
||||
# # Check and delete all the empty parent folders
|
||||
@@ -1767,14 +1772,14 @@ class _FileStorageDriver(_Driver):
|
||||
if six.PY3:
|
||||
from io import FileIO as file
|
||||
|
||||
if isinstance(iterator, (file)):
|
||||
if isinstance(iterator, file):
|
||||
get_data = iterator.read
|
||||
args = (chunk_size,)
|
||||
else:
|
||||
get_data = next
|
||||
args = (iterator,)
|
||||
|
||||
data = bytes('')
|
||||
data = bytes(b'')
|
||||
empty = False
|
||||
|
||||
while not empty or len(data) > 0:
|
||||
@@ -2320,7 +2325,7 @@ class StorageHelper(object):
|
||||
return self._get_object_size_bytes(obj, silence_errors)
|
||||
|
||||
def _get_object_size_bytes(self, obj, silence_errors=False):
|
||||
# type: (object) -> [int, None]
|
||||
# type: (object, bool) -> [int, None]
|
||||
"""
|
||||
Auxiliary function for `get_object_size_bytes`.
|
||||
Get size of the remote object in bytes.
|
||||
@@ -2448,6 +2453,10 @@ class StorageHelper(object):
|
||||
stream.seek(0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if cb:
|
||||
cb.close(report_completed=not bool(last_ex))
|
||||
|
||||
if last_ex:
|
||||
raise last_ex
|
||||
|
||||
@@ -2601,9 +2610,10 @@ class StorageHelper(object):
|
||||
return direct_access_path
|
||||
|
||||
temp_local_path = None
|
||||
cb = None
|
||||
try:
|
||||
if verbose:
|
||||
self._log.info('Start downloading from %s' % remote_path)
|
||||
self._log.info("Start downloading from {}".format(remote_path))
|
||||
if not overwrite_existing and Path(local_path).is_file():
|
||||
self._log.debug(
|
||||
'File {} already exists, no need to download, thread id = {}'.format(
|
||||
@@ -2643,8 +2653,9 @@ class StorageHelper(object):
|
||||
|
||||
# if driver supports download with callback, use it (it might be faster)
|
||||
if hasattr(self._driver, 'download_object'):
|
||||
# callback
|
||||
cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self._log)
|
||||
# callback if verbose we already reported download start, no need to do that again
|
||||
cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self._log,
|
||||
report_start=True if verbose else None)
|
||||
self._driver.download_object(obj, temp_local_path, callback=cb)
|
||||
download_reported = bool(cb.last_reported)
|
||||
dl_total_mb = cb.current_status_mb
|
||||
@@ -2686,15 +2697,28 @@ class StorageHelper(object):
|
||||
raise Exception('Failed renaming partial file, downloaded file exists and a 0-sized file')
|
||||
|
||||
# report download if we are on the second chunk
|
||||
if verbose or download_reported:
|
||||
if cb:
|
||||
cb.close(
|
||||
report_completed=True,
|
||||
report_summary=verbose or download_reported,
|
||||
report_prefix="Downloaded",
|
||||
report_suffix="from {} , saved to {}".format(remote_path, local_path)
|
||||
)
|
||||
elif verbose or download_reported:
|
||||
self._log.info(
|
||||
'Downloaded %.2f MB successfully from %s , saved to %s' % (dl_total_mb, remote_path, local_path))
|
||||
"Downloaded {:.2f} MB successfully from {} , saved to {}".format(
|
||||
dl_total_mb, remote_path, local_path)
|
||||
)
|
||||
return local_path
|
||||
except DownloadError:
|
||||
if cb:
|
||||
cb.close()
|
||||
raise
|
||||
except Exception as e:
|
||||
if cb:
|
||||
cb.close()
|
||||
self._log.error("Could not download {} , err: {} ".format(remote_path, e))
|
||||
if delete_on_failure:
|
||||
if delete_on_failure and temp_local_path:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
os.remove(temp_local_path)
|
||||
@@ -2880,7 +2904,9 @@ class StorageHelper(object):
|
||||
|
||||
def _do_async_upload(self, data):
|
||||
assert isinstance(data, self._UploadData)
|
||||
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, verbose=True, retries=data.retries, return_canonized=data.return_canonized)
|
||||
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path,
|
||||
extra=data.extra, cb=data.callback, verbose=True,
|
||||
retries=data.retries, return_canonized=data.return_canonized)
|
||||
|
||||
def _upload_from_file(self, local_path, dest_path, extra=None):
|
||||
if not hasattr(self._driver, 'upload_object'):
|
||||
@@ -2897,9 +2923,12 @@ class StorageHelper(object):
|
||||
object_name=object_name,
|
||||
callback=cb,
|
||||
extra=extra)
|
||||
if cb:
|
||||
cb.close()
|
||||
return res
|
||||
|
||||
def _do_upload(self, src_path, dest_path, canonized_dest_path, extra=None, cb=None, verbose=False, retries=1, return_canonized=False):
|
||||
def _do_upload(self, src_path, dest_path, canonized_dest_path,
|
||||
extra=None, cb=None, verbose=False, retries=1, return_canonized=False):
|
||||
object_name = self._normalize_object_name(canonized_dest_path)
|
||||
if cb:
|
||||
try:
|
||||
|
||||
307
clearml/task.py
307
clearml/task.py
@@ -1,9 +1,7 @@
|
||||
import atexit
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -42,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
|
||||
from .backend_api.services import tasks, projects, events
|
||||
from .backend_api.session.session import (
|
||||
Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
|
||||
from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
|
||||
from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
|
||||
ENV_OFFLINE_MODE, MissingConfigError)
|
||||
from .backend_interface.metrics import Metrics
|
||||
from .backend_interface.model import Model as BackendModel
|
||||
from .backend_interface.base import InterfaceBase
|
||||
@@ -99,13 +98,17 @@ from .utilities.proxy_object import (
|
||||
from .utilities.resource_monitor import ResourceMonitor
|
||||
from .utilities.seed import make_deterministic
|
||||
from .utilities.lowlevel.threads import get_current_thread_id
|
||||
from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
|
||||
create_torch_distributed_anchor
|
||||
from .utilities.process.mp import BackgroundMonitor, leave_process
|
||||
from .utilities.process.exit_hooks import ExitHooks
|
||||
from .utilities.matching import matches_any_wildcard
|
||||
from .utilities.parallel import FutureTaskCaller
|
||||
from .utilities.networking import get_private_ip
|
||||
# noinspection PyProtectedMember
|
||||
from .backend_interface.task.args import _Arguments
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pandas
|
||||
import numpy
|
||||
@@ -487,8 +490,6 @@ class Task(_Task):
|
||||
# unregister signal hooks, they cause subprocess to hang
|
||||
# noinspection PyProtectedMember
|
||||
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
|
||||
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
|
||||
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
|
||||
|
||||
# start all reporting threads
|
||||
BackgroundMonitor.start_all(task=cls.__main_task)
|
||||
@@ -530,10 +531,16 @@ class Task(_Task):
|
||||
is_deferred = False
|
||||
try:
|
||||
if not running_remotely():
|
||||
# check remote status
|
||||
_local_rank = get_torch_local_rank()
|
||||
if _local_rank is not None and _local_rank > 0:
|
||||
is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
|
||||
|
||||
# only allow if running locally and creating the first Task
|
||||
# otherwise we ignore and perform in order
|
||||
if ENV_DEFERRED_TASK_INIT.get():
|
||||
deferred_init = True
|
||||
|
||||
if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
|
||||
def completed_cb(x):
|
||||
Task.__main_task = x
|
||||
@@ -574,6 +581,11 @@ class Task(_Task):
|
||||
not auto_connect_frameworks.get('detect_repository', True)) else True,
|
||||
auto_connect_streams=auto_connect_streams,
|
||||
)
|
||||
# check if we are local rank 0 (local master),
|
||||
# create an anchor with task ID for the other processes
|
||||
if _local_rank == 0:
|
||||
create_torch_distributed_anchor(task_id=task.id)
|
||||
|
||||
except MissingConfigError as e:
|
||||
if not ENV_IGNORE_MISSING_CONFIG.get():
|
||||
raise
|
||||
@@ -636,7 +648,11 @@ class Task(_Task):
|
||||
# register at exist only on the real (none deferred) Task
|
||||
if not is_deferred:
|
||||
# register the main task for at exit hooks (there should only be one)
|
||||
# noinspection PyProtectedMember
|
||||
task.__register_at_exit(task._at_exit)
|
||||
# noinspection PyProtectedMember
|
||||
if cls.__exit_hook:
|
||||
cls.__exit_hook.register_signal_and_exception_hooks()
|
||||
|
||||
# always patch OS forking because of ProcessPool and the alike
|
||||
PatchOsFork.patch_fork(task)
|
||||
@@ -1552,6 +1568,69 @@ class Task(_Task):
|
||||
|
||||
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
|
||||
|
||||
def set_packages(self, packages):
|
||||
# type: (Union[str, Path, Sequence[str]]) -> ()
|
||||
"""
|
||||
Manually specify a list of required packages or a local requirements.txt file.
|
||||
|
||||
When running remotely this call is ignored
|
||||
|
||||
:param packages: The list of packages or the path to the requirements.txt file.
|
||||
|
||||
Example: ["tqdm>=2.1", "scikit-learn"] or "./requirements.txt" or ""
|
||||
Use an empty string (packages="") to clear the requirements section (remote execution will use
|
||||
requirements.txt from the git repository if the file exists)
|
||||
"""
|
||||
if running_remotely() or packages is None:
|
||||
return
|
||||
self._wait_for_repo_detection(timeout=300.)
|
||||
|
||||
if packages and isinstance(packages, (str, Path)) and Path(packages).is_file():
|
||||
with open(Path(packages).as_posix(), "rt") as f:
|
||||
# noinspection PyProtectedMember
|
||||
self._update_requirements([line.strip() for line in f.readlines()])
|
||||
return
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
self._update_requirements(packages or "")
|
||||
|
||||
def set_repo(self, repo=None, branch=None, commit=None):
|
||||
# type: (Optional[str], Optional[str], Optional[str]) -> ()
|
||||
"""
|
||||
Specify a repository to attach to the function.
|
||||
Allow users to execute the task inside the specified repository, enabling them to load modules/script
|
||||
from the repository. Notice the execution work directory will be the repository root folder.
|
||||
Supports both git repo url link, and local repository path (automatically converted into the remote
|
||||
git/commit as is currently checkout).
|
||||
Example remote url: "https://github.com/user/repo.git".
|
||||
Example local repo copy: "./repo" -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy.
|
||||
When executing remotely, this call will not override the repository data (it is ignored)
|
||||
|
||||
:param repo: Optional, remote URL for the repository to use, OR path to local copy of the git repository.
|
||||
Use an empty string to clear the repo.
|
||||
Example: "https://github.com/allegroai/clearml.git" or "~/project/repo" or ""
|
||||
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used).
|
||||
Use an empty string to clear the branch.
|
||||
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used).
|
||||
Use an empty string to clear the commit.
|
||||
"""
|
||||
if running_remotely():
|
||||
return
|
||||
self._wait_for_repo_detection(timeout=300.)
|
||||
with self._edit_lock:
|
||||
self.reload()
|
||||
if repo is not None:
|
||||
# we cannot have None on the value itself
|
||||
self.data.script.repository = repo or ""
|
||||
if branch is not None:
|
||||
# we cannot have None on the value itself
|
||||
self.data.script.branch = branch or ""
|
||||
if commit is not None:
|
||||
# we cannot have None on the value itself
|
||||
self.data.script.version_num = commit or ""
|
||||
self._edit(script=self.data.script)
|
||||
|
||||
def connect_configuration(self, configuration, name=None, description=None):
|
||||
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str]
|
||||
"""
|
||||
@@ -2015,6 +2094,8 @@ class Task(_Task):
|
||||
# unregister atexit callbacks and signal hooks, if we are the main task
|
||||
if is_main:
|
||||
self.__register_at_exit(None)
|
||||
self._remove_signal_hooks()
|
||||
self._remove_exception_hooks()
|
||||
if not is_sub_process:
|
||||
# make sure we enable multiple Task.init callas with reporting sub-processes
|
||||
BackgroundMonitor.clear_main_process(self)
|
||||
@@ -2715,41 +2796,6 @@ class Task(_Task):
|
||||
docker_setup_bash_script=docker_setup_bash_script
|
||||
)
|
||||
|
||||
def set_packages(self, packages):
|
||||
# type: (Union[str, Sequence[str]]) -> ()
|
||||
"""
|
||||
Manually specify a list of required packages or a local requirements.txt file.
|
||||
When running remotely the call is ignored
|
||||
|
||||
:param packages: The list of packages or the path to the requirements.txt file.
|
||||
Example: ["tqdm>=2.1", "scikit-learn"] or "./requirements.txt"
|
||||
"""
|
||||
if running_remotely():
|
||||
return
|
||||
super(Task, self).set_packages(packages)
|
||||
|
||||
def set_repo(self, repo, branch=None, commit=None):
|
||||
# type: (str, Optional[str], Optional[str]) -> ()
|
||||
"""
|
||||
Specify a repository to attach to the function.
|
||||
Allow users to execute the task inside the specified repository, enabling them to load modules/script
|
||||
from the repository. Notice the execution work directory will be the repository root folder.
|
||||
Supports both git repo url link, and local repository path (automatically converted into the remote
|
||||
git/commit as is currently checkout).
|
||||
Example remote url: 'https://github.com/user/repo.git'.
|
||||
Example local repo copy: './repo' -> will automatically store the remote
|
||||
repo url and commit ID based on the locally cloned copy.
|
||||
When executing remotely, this call will not override the repository data (it is ignored)
|
||||
|
||||
:param repo: Remote URL for the repository to use, OR path to local copy of the git repository
|
||||
Example: 'https://github.com/allegroai/clearml.git' or '~/project/repo'
|
||||
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
|
||||
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
|
||||
"""
|
||||
if running_remotely():
|
||||
return
|
||||
super(Task, self).set_repo(repo, branch=branch, commit=commit)
|
||||
|
||||
def set_resource_monitor_iteration_timeout(self, seconds_from_start=1800):
|
||||
# type: (float) -> bool
|
||||
"""
|
||||
@@ -4192,172 +4238,21 @@ class Task(_Task):
|
||||
if not is_sub_process and BackgroundMonitor.is_subprocess_enabled():
|
||||
BackgroundMonitor.wait_for_sub_process(self)
|
||||
|
||||
# we are done
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
|
||||
class ExitHooks(object):
|
||||
_orig_exit = None
|
||||
_orig_exc_handler = None
|
||||
remote_user_aborted = False
|
||||
def _remove_exception_hooks(cls):
|
||||
if cls.__exit_hook:
|
||||
cls.__exit_hook.remove_exception_hooks()
|
||||
|
||||
def __init__(self, callback):
|
||||
self.exit_code = None
|
||||
self.exception = None
|
||||
self.signal = None
|
||||
self._exit_callback = callback
|
||||
self._org_handlers = {}
|
||||
self._signal_recursion_protection_flag = False
|
||||
self._except_recursion_protection_flag = False
|
||||
self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")
|
||||
|
||||
def update_callback(self, callback):
|
||||
if self._exit_callback and not six.PY2:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
atexit.unregister(self._exit_callback)
|
||||
except Exception:
|
||||
pass
|
||||
self._exit_callback = callback
|
||||
if callback:
|
||||
self.hook()
|
||||
else:
|
||||
# un register int hook
|
||||
if self._orig_exc_handler:
|
||||
sys.excepthook = self._orig_exc_handler
|
||||
self._orig_exc_handler = None
|
||||
for h in self._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(h, self._org_handlers[h])
|
||||
except Exception:
|
||||
pass
|
||||
self._org_handlers = {}
|
||||
|
||||
def hook(self):
|
||||
if self._orig_exit is None:
|
||||
self._orig_exit = sys.exit
|
||||
sys.exit = self.exit
|
||||
|
||||
if self._orig_exc_handler is None:
|
||||
self._orig_exc_handler = sys.excepthook
|
||||
sys.excepthook = self.exc_handler
|
||||
|
||||
if self._exit_callback:
|
||||
atexit.register(self._exit_callback)
|
||||
|
||||
# TODO: check if sub-process hooks are safe enough, for the time being allow it
|
||||
if not self._org_handlers: # ## and not Task._Task__is_subprocess():
|
||||
if sys.platform == 'win32':
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE]
|
||||
else:
|
||||
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
|
||||
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
|
||||
for c in catch_signals:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._org_handlers[c] = signal.getsignal(c)
|
||||
signal.signal(c, self.signal_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def exit(self, code=0):
|
||||
self.exit_code = code
|
||||
self._orig_exit(code)
|
||||
|
||||
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
|
||||
if self._except_recursion_protection_flag:
|
||||
# noinspection PyArgumentList
|
||||
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
|
||||
|
||||
self._except_recursion_protection_flag = True
|
||||
self.exception = value
|
||||
|
||||
try:
|
||||
# remove us from import errors
|
||||
if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
|
||||
prev = cur = traceback
|
||||
while cur is not None:
|
||||
tb_next = cur.tb_next
|
||||
# if this is the import frame, we should remove it
|
||||
if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
|
||||
# remove this frame by connecting the previous one to the next one
|
||||
prev.tb_next = tb_next
|
||||
cur.tb_next = None
|
||||
del cur
|
||||
cur = prev
|
||||
|
||||
prev = cur
|
||||
cur = tb_next
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
if self._orig_exc_handler:
|
||||
# noinspection PyArgumentList
|
||||
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
|
||||
else:
|
||||
# noinspection PyNoneFunctionAssignment, PyArgumentList
|
||||
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
|
||||
self._except_recursion_protection_flag = False
|
||||
|
||||
return ret
|
||||
|
||||
def signal_handler(self, sig, frame):
|
||||
self.signal = sig
|
||||
|
||||
org_handler = self._org_handlers.get(sig)
|
||||
signal.signal(sig, org_handler or signal.SIG_DFL)
|
||||
|
||||
# if this is a sig term, we wait until __at_exit is called (basically do nothing)
|
||||
if sig == signal.SIGINT:
|
||||
# return original handler result
|
||||
return org_handler if not callable(org_handler) else org_handler(sig, frame)
|
||||
|
||||
if self._signal_recursion_protection_flag:
|
||||
# call original
|
||||
os.kill(os.getpid(), sig)
|
||||
return org_handler if not callable(org_handler) else org_handler(sig, frame)
|
||||
|
||||
self._signal_recursion_protection_flag = True
|
||||
|
||||
# call exit callback
|
||||
if self._exit_callback:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self._exit_callback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# remove stdout logger, just in case
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# noinspection PyProtectedMember
|
||||
Logger._remove_std_logger()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
os.kill(os.getpid(), sig)
|
||||
|
||||
self._signal_recursion_protection_flag = False
|
||||
# return handler result
|
||||
return org_handler if not callable(org_handler) else org_handler(sig, frame)
|
||||
|
||||
# we only remove the signals since this will hang subprocesses
|
||||
if only_remove_signal_and_exception_hooks:
|
||||
if not cls.__exit_hook:
|
||||
return
|
||||
if cls.__exit_hook._orig_exc_handler:
|
||||
sys.excepthook = cls.__exit_hook._orig_exc_handler
|
||||
cls.__exit_hook._orig_exc_handler = None
|
||||
for s in cls.__exit_hook._org_handlers:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
signal.signal(s, cls.__exit_hook._org_handlers[s])
|
||||
except Exception:
|
||||
pass
|
||||
cls.__exit_hook._org_handlers = {}
|
||||
return
|
||||
@classmethod
|
||||
def _remove_signal_hooks(cls):
|
||||
if cls.__exit_hook:
|
||||
cls.__exit_hook.remove_signal_hooks()
|
||||
|
||||
@classmethod
|
||||
def __register_at_exit(cls, exit_callback):
|
||||
if cls.__exit_hook is None:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
@@ -4368,12 +4263,6 @@ class Task(_Task):
|
||||
else:
|
||||
cls.__exit_hook.update_callback(exit_callback)
|
||||
|
||||
def _remove_at_exit_callbacks(self):
|
||||
self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
|
||||
# noinspection PyProtectedMember
|
||||
atexit.unregister(self.__exit_hook._exit_callback)
|
||||
self._at_exit_called = True
|
||||
|
||||
@classmethod
|
||||
def __get_task(
|
||||
cls,
|
||||
|
||||
336
clearml/utilities/distutils_version.py
Normal file
336
clearml/utilities/distutils_version.py
Normal file
@@ -0,0 +1,336 @@
|
||||
#
|
||||
# distutils/version.py
|
||||
#
|
||||
# Implements multiple version numbering conventions for the
|
||||
# Python Module Distribution Utilities.
|
||||
#
|
||||
# $Id$
|
||||
#
|
||||
|
||||
"""Provides classes to represent module version numbers (one class for
|
||||
each style of version numbering). There are currently two such classes
|
||||
implemented: StrictVersion and LooseVersion.
|
||||
|
||||
Every version number class implements the following interface:
|
||||
* the 'parse' method takes a string and parses it to some internal
|
||||
representation; if the string is an invalid version number,
|
||||
'parse' raises a ValueError exception
|
||||
* the class constructor takes an optional string argument which,
|
||||
if supplied, is passed to 'parse'
|
||||
* __str__ reconstructs the string that was passed to 'parse' (or
|
||||
an equivalent string -- ie. one that will generate an equivalent
|
||||
version number instance)
|
||||
* __repr__ generates Python code to recreate the version number instance
|
||||
* _cmp compares the current instance with either another instance
|
||||
of the same class or a string (which will be parsed to an instance
|
||||
of the same class, thus must follow the same rules)
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
class Version:
|
||||
"""Abstract base class for version numbering classes. Just provides
|
||||
constructor (__init__) and reproducer (__repr__), because those
|
||||
seem to be the same for all version numbering classes; and route
|
||||
rich comparisons to _cmp.
|
||||
"""
|
||||
|
||||
def __init__(self, vstring=None):
|
||||
if vstring:
|
||||
self.parse(vstring)
|
||||
|
||||
def __repr__(self):
|
||||
return "%s ('%s')" % (self.__class__.__name__, str(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
c = self._cmp(other)
|
||||
if c is NotImplemented:
|
||||
return c
|
||||
return c == 0
|
||||
|
||||
def __lt__(self, other):
|
||||
c = self._cmp(other)
|
||||
if c is NotImplemented:
|
||||
return c
|
||||
return c < 0
|
||||
|
||||
def __le__(self, other):
|
||||
c = self._cmp(other)
|
||||
if c is NotImplemented:
|
||||
return c
|
||||
return c <= 0
|
||||
|
||||
def __gt__(self, other):
|
||||
c = self._cmp(other)
|
||||
if c is NotImplemented:
|
||||
return c
|
||||
return c > 0
|
||||
|
||||
def __ge__(self, other):
|
||||
c = self._cmp(other)
|
||||
if c is NotImplemented:
|
||||
return c
|
||||
return c >= 0
|
||||
|
||||
|
||||
# Interface for version-number classes -- must be implemented
|
||||
# by the following classes (the concrete ones -- Version should
|
||||
# be treated as an abstract class).
|
||||
# __init__ (string) - create and take same action as 'parse'
|
||||
# (string parameter is optional)
|
||||
# parse (string) - convert a string representation to whatever
|
||||
# internal representation is appropriate for
|
||||
# this style of version numbering
|
||||
# __str__ (self) - convert back to a string; should be very similar
|
||||
# (if not identical to) the string supplied to parse
|
||||
# __repr__ (self) - generate Python code to recreate
|
||||
# the instance
|
||||
# _cmp (self, other) - compare two version numbers ('other' may
|
||||
# be an unparsed version string, or another
|
||||
# instance of your version class)
|
||||
|
||||
|
||||
class StrictVersion(Version):
|
||||
|
||||
"""Version numbering for anal retentives and software idealists.
|
||||
Implements the standard interface for version number classes as
|
||||
described above. A version number consists of two or three
|
||||
dot-separated numeric components, with an optional "pre-release" tag
|
||||
on the end. The pre-release tag consists of the letter 'a' or 'b'
|
||||
followed by a number. If the numeric components of two version
|
||||
numbers are equal, then one with a pre-release tag will always
|
||||
be deemed earlier (lesser) than one without.
|
||||
|
||||
The following are valid version numbers (shown in the order that
|
||||
would be obtained by sorting according to the supplied cmp function):
|
||||
|
||||
0.4 0.4.0 (these two are equivalent)
|
||||
0.4.1
|
||||
0.5a1
|
||||
0.5b3
|
||||
0.5
|
||||
0.9.6
|
||||
1.0
|
||||
1.0.4a3
|
||||
1.0.4b1
|
||||
1.0.4
|
||||
|
||||
The following are examples of invalid version numbers:
|
||||
|
||||
1
|
||||
2.7.2.2
|
||||
1.3.a4
|
||||
1.3pl1
|
||||
1.3c4
|
||||
|
||||
The rationale for this version numbering system will be explained
|
||||
in the distutils documentation.
|
||||
"""
|
||||
|
||||
version_re = re.compile(r"^(\d+) \. (\d+) (\. (\d+))? ([ab](\d+))?$", re.VERBOSE | re.ASCII)
|
||||
|
||||
def parse(self, vstring):
|
||||
match = self.version_re.match(vstring)
|
||||
if not match:
|
||||
raise ValueError("invalid version number '%s'" % vstring)
|
||||
|
||||
(major, minor, patch, prerelease, prerelease_num) = match.group(1, 2, 4, 5, 6)
|
||||
|
||||
if patch:
|
||||
self.version = tuple(map(int, [major, minor, patch]))
|
||||
else:
|
||||
self.version = tuple(map(int, [major, minor])) + (0,)
|
||||
|
||||
if prerelease:
|
||||
self.prerelease = (prerelease[0], int(prerelease_num))
|
||||
else:
|
||||
self.prerelease = None
|
||||
|
||||
def __str__(self):
|
||||
|
||||
if self.version[2] == 0:
|
||||
vstring = ".".join(map(str, self.version[0:2]))
|
||||
else:
|
||||
vstring = ".".join(map(str, self.version))
|
||||
|
||||
if self.prerelease:
|
||||
vstring = vstring + self.prerelease[0] + str(self.prerelease[1])
|
||||
|
||||
return vstring
|
||||
|
||||
def _cmp(self, other):
|
||||
if isinstance(other, str):
|
||||
other = StrictVersion(other)
|
||||
|
||||
if self.version != other.version:
|
||||
# numeric versions don't match
|
||||
# prerelease stuff doesn't matter
|
||||
if self.version < other.version:
|
||||
return -1
|
||||
else:
|
||||
return 1
|
||||
|
||||
# have to compare prerelease
|
||||
# case 1: neither has prerelease; they're equal
|
||||
# case 2: self has prerelease, other doesn't; other is greater
|
||||
# case 3: self doesn't have prerelease, other does: self is greater
|
||||
# case 4: both have prerelease: must compare them!
|
||||
|
||||
if not self.prerelease and not other.prerelease:
|
||||
return 0
|
||||
elif self.prerelease and not other.prerelease:
|
||||
return -1
|
||||
elif not self.prerelease and other.prerelease:
|
||||
return 1
|
||||
elif self.prerelease and other.prerelease:
|
||||
if self.prerelease == other.prerelease:
|
||||
return 0
|
||||
elif self.prerelease < other.prerelease:
|
||||
return -1
|
||||
else:
|
||||
return 1
|
||||
else:
|
||||
assert False, "never get here"
|
||||
|
||||
|
||||
# end class StrictVersion
|
||||
|
||||
|
||||
# The rules according to Greg Stein:
|
||||
# 1) a version number has 1 or more numbers separated by a period or by
|
||||
# sequences of letters. If only periods, then these are compared
|
||||
# left-to-right to determine an ordering.
|
||||
# 2) sequences of letters are part of the tuple for comparison and are
|
||||
# compared lexicographically
|
||||
# 3) recognize the numeric components may have leading zeroes
|
||||
#
|
||||
# The LooseVersion class below implements these rules: a version number
|
||||
# string is split up into a tuple of integer and string components, and
|
||||
# comparison is a simple tuple comparison. This means that version
|
||||
# numbers behave in a predictable and obvious way, but a way that might
|
||||
# not necessarily be how people *want* version numbers to behave. There
|
||||
# wouldn't be a problem if people could stick to purely numeric version
|
||||
# numbers: just split on period and compare the numbers as tuples.
|
||||
# However, people insist on putting letters into their version numbers;
|
||||
# the most common purpose seems to be:
|
||||
# - indicating a "pre-release" version
|
||||
# ('alpha', 'beta', 'a', 'b', 'pre', 'p')
|
||||
# - indicating a post-release patch ('p', 'pl', 'patch')
|
||||
# but of course this can't cover all version number schemes, and there's
|
||||
# no way to know what a programmer means without asking him.
|
||||
#
|
||||
# The problem is what to do with letters (and other non-numeric
|
||||
# characters) in a version number. The current implementation does the
|
||||
# obvious and predictable thing: keep them as strings and compare
|
||||
# lexically within a tuple comparison. This has the desired effect if
|
||||
# an appended letter sequence implies something "post-release":
|
||||
# eg. "0.99" < "0.99pl14" < "1.0", and "5.001" < "5.001m" < "5.002".
|
||||
#
|
||||
# However, if letters in a version number imply a pre-release version,
|
||||
# the "obvious" thing isn't correct. Eg. you would expect that
|
||||
# "1.5.1" < "1.5.2a2" < "1.5.2", but under the tuple/lexical comparison
|
||||
# implemented here, this just isn't so.
|
||||
#
|
||||
# Two possible solutions come to mind. The first is to tie the
|
||||
# comparison algorithm to a particular set of semantic rules, as has
|
||||
# been done in the StrictVersion class above. This works great as long
|
||||
# as everyone can go along with bondage and discipline. Hopefully a
|
||||
# (large) subset of Python module programmers will agree that the
|
||||
# particular flavour of bondage and discipline provided by StrictVersion
|
||||
# provides enough benefit to be worth using, and will submit their
|
||||
# version numbering scheme to its domination. The free-thinking
|
||||
# anarchists in the lot will never give in, though, and something needs
|
||||
# to be done to accommodate them.
|
||||
#
|
||||
# Perhaps a "moderately strict" version class could be implemented that
|
||||
# lets almost anything slide (syntactically), and makes some heuristic
|
||||
# assumptions about non-digits in version number strings. This could
|
||||
# sink into special-case-hell, though; if I was as talented and
|
||||
# idiosyncratic as Larry Wall, I'd go ahead and implement a class that
|
||||
# somehow knows that "1.2.1" < "1.2.2a2" < "1.2.2" < "1.2.2pl3", and is
|
||||
# just as happy dealing with things like "2g6" and "1.13++". I don't
|
||||
# think I'm smart enough to do it right though.
|
||||
#
|
||||
# In any case, I've coded the test suite for this module (see
|
||||
# ../test/test_version.py) specifically to fail on things like comparing
|
||||
# "1.2a2" and "1.2". That's not because the *code* is doing anything
|
||||
# wrong, it's because the simple, obvious design doesn't match my
|
||||
# complicated, hairy expectations for real-world version numbers. It
|
||||
# would be a snap to fix the test suite to say, "Yep, LooseVersion does
|
||||
# the Right Thing" (ie. the code matches the conception). But I'd rather
|
||||
# have a conception that matches common notions about version numbers.
|
||||
|
||||
|
||||
class LooseVersion(Version):

    """Version numbering for anarchists and software realists.

    Implements the standard interface for version number classes as
    described above. A version number consists of a series of numbers,
    separated by either periods or strings of letters. When comparing
    version numbers, the numeric components will be compared
    numerically, and the alphabetic components lexically. The following
    are all valid version numbers, in no particular order:

        1.5.1
        1.5.2b2
        161
        3.10a
        8.02
        3.4j
        1996.07.12
        3.2.pl0
        3.1.1.6
        2g6
        11g
        0.960923
        2.2beta29
        1.13++
        5.5.kw
        2.0b1pl0

    In fact, there is no such thing as an invalid version number under
    this scheme; the rules for comparison are simple and predictable,
    but may not always give the results you want (for some definition
    of "want").
    """

    # Splits a version string into runs of digits, runs of letters, and "."
    # separators; the separators are discarded later so only the digit/alpha
    # components remain.
    component_re = re.compile(r"(\d+ | [a-z]+ | \.)", re.VERBOSE)

    def __init__(self, vstring=None):
        # vstring may be omitted to allow two-phase construction via parse()
        if vstring:
            self.parse(vstring)

    def parse(self, vstring):
        """Parse *vstring* into self.version, a list of int/str components."""
        # I've given up on thinking I can reconstruct the version string
        # from the parsed tuple -- so I just store the string here for
        # use by __str__
        self.vstring = vstring
        # drop empty matches and the "." separators, keep digit/alpha runs
        components = [x for x in self.component_re.split(vstring) if x and x != "."]
        for i, obj in enumerate(components):
            try:
                # numeric components compare numerically, the rest lexically
                components[i] = int(obj)
            except ValueError:
                pass

        self.version = components

    def __str__(self):
        return self.vstring

    def __repr__(self):
        return "LooseVersion ('%s')" % str(self)

    def _cmp(self, other):
        """Return 0/-1/1 comparing self with *other* (a str or LooseVersion).

        NOTE(review): comparing versions whose component lists mix int and
        str at the same position (e.g. "1.2" vs "1.a") raises TypeError on
        Python 3 -- this mirrors the original distutils behavior.
        """
        if isinstance(other, str):
            other = LooseVersion(other)

        if self.version == other.version:
            return 0
        if self.version < other.version:
            return -1
        if self.version > other.version:
            return 1
|
||||
96
clearml/utilities/lowlevel/distributed.py
Normal file
96
clearml/utilities/lowlevel/distributed.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
from logging import getLogger
|
||||
from time import sleep, time
|
||||
|
||||
from pathlib2 import Path
|
||||
|
||||
|
||||
def get_torch_local_rank():
    """
    Return the torch-elastic local rank of this process as an int.

    Note: local rank 0 does not imply global rank 0.
    Returns None when torch distributed (torchelastic) is not running,
    or when the LOCAL_RANK environment variable cannot be parsed.
    """
    # torchelastic always sets TORCHELASTIC_RUN_ID; absence means
    # no distributed run is active
    if os.environ.get("TORCHELASTIC_RUN_ID") is None:
        return None

    # noinspection PyBroadException
    try:
        return int(os.environ.get("LOCAL_RANK"))
    except Exception:
        # LOCAL_RANK missing or not an integer
        return None
|
||||
|
||||
|
||||
def create_torch_distributed_anchor(task_id):
    """
    Write a temporary anchor file carrying the Task ID created by local
    rank 0, so sibling local ranks can report to the same Task.

    Only call when running locally (i.e. without an agent); when running
    remotely there is no need to pass the Task ID, it is passed externally.
    """
    anchor_name = ".clearml_torch_distributed_id"

    # only the local-rank-0 process publishes the anchor
    if get_torch_local_rank() != 0:
        return

    anchor_dir = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not anchor_dir:
        return

    # noinspection PyBroadException
    try:
        # the run directory is three levels above the torchelastic error file
        anchor_dir = Path(anchor_dir).parent.parent.parent
        # write the Task ID for the other local ranks to pick up
        with open(anchor_dir / anchor_name, "wt") as f:
            f.write(str(task_id) + "\n")
    except Exception:
        # best effort - just log and keep going
        getLogger().warning("Failed creating torch task ID anchor file: {}".format(anchor_dir))
|
||||
|
||||
|
||||
def get_torch_distributed_anchor_task_id(timeout=None):
    """
    Wait until the temporary anchor file written by local rank 0 appears,
    then read the Task ID stored in it.

    Only call when running locally (i.e. without an agent); when running
    remotely there is no need to pass the Task ID, it is passed externally.

    :param timeout: maximum seconds to wait for the anchor file;
        None waits indefinitely
    :return Task ID of the local task to report to (or None on failure)
    """
    # rank 0 (or a non-distributed run) writes the anchor, it never reads it
    _local_rank = get_torch_local_rank()
    if not _local_rank:
        return

    anchor_name = ".clearml_torch_distributed_id"

    err_file = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not err_file:
        return

    task_id = None
    # noinspection PyBroadException
    try:
        # the anchor lives three levels above the torchelastic error file
        anchor_path = Path(err_file).parent.parent.parent / anchor_name

        tic = time()
        # poll until the anchor file written by rank 0 shows up
        while not anchor_path.is_file():
            if timeout is not None and time() - tic > timeout:
                # gave up waiting
                getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
                return None
            sleep(0.25)

        # read the Task ID left by local rank 0
        with open(anchor_path, "rt") as f:
            task_id = f.read().strip(" \n")
    except Exception:
        # best effort - fall through with task_id still None
        pass

    getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(_local_rank, task_id))
    return task_id
|
||||
@@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
import itertools
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from clearml.utilities.distutils_version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
|
||||
187
clearml/utilities/process/exit_hooks.py
Normal file
187
clearml/utilities/process/exit_hooks.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import atexit
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
import six
|
||||
|
||||
from ...logger import Logger
|
||||
|
||||
|
||||
class ExitHooks(object):
    """
    Hooks sys.exit, sys.excepthook and OS signal handlers so that a
    user-supplied callback runs on any form of process termination, while
    recording the exit code / exception / signal that caused it.
    """
    # original sys.exit, saved once by hook()
    _orig_exit = None
    # original sys.excepthook, saved by register_signal_and_exception_hooks()
    _orig_exc_handler = None
    # set externally when the task was aborted from the remote side
    remote_user_aborted = False

    def __init__(self, callback):
        self.exit_code = None  # exit code passed to sys.exit(), if any
        self.exception = None  # last uncaught exception value, if any
        self.signal = None  # last signal number handled, if any
        self._exit_callback = callback
        self._org_handlers = {}  # signal number -> original handler
        self._signal_recursion_protection_flag = False
        self._except_recursion_protection_flag = False
        # frames from this file are stripped out of ImportError tracebacks
        self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")

    def update_callback(self, callback):
        """Replace the exit callback; passing None unregisters the hooks."""
        if self._exit_callback and not six.PY2:
            # atexit.unregister() only exists on Python 3
            # noinspection PyBroadException
            try:
                atexit.unregister(self._exit_callback)
            except Exception:
                pass
        self._exit_callback = callback
        if callback:
            self.hook()
        else:
            # un register int hook
            if self._orig_exc_handler:
                sys.excepthook = self._orig_exc_handler
                self._orig_exc_handler = None
            # restore every signal handler we previously replaced
            for h in self._org_handlers:
                # noinspection PyBroadException
                try:
                    signal.signal(h, self._org_handlers[h])
                except Exception:
                    pass
            self._org_handlers = {}

    def hook(self):
        """Install the sys.exit wrapper and register the atexit callback."""
        if self._orig_exit is None:
            self._orig_exit = sys.exit
            sys.exit = self.exit

        if self._exit_callback:
            atexit.register(self._exit_callback)

    def register_signal_and_exception_hooks(self):
        """Install the sys.excepthook wrapper and fatal-signal handlers."""
        if self._orig_exc_handler is None:
            self._orig_exc_handler = sys.excepthook

        sys.excepthook = self.exc_handler

        if not self._org_handlers:
            if sys.platform == "win32":
                # SIGQUIT does not exist on Windows
                catch_signals = [
                    signal.SIGINT,
                    signal.SIGTERM,
                    signal.SIGSEGV,
                    signal.SIGABRT,
                    signal.SIGILL,
                    signal.SIGFPE,
                ]
            else:
                catch_signals = [
                    signal.SIGINT,
                    signal.SIGTERM,
                    signal.SIGSEGV,
                    signal.SIGABRT,
                    signal.SIGILL,
                    signal.SIGFPE,
                    signal.SIGQUIT,
                ]
            for c in catch_signals:
                # noinspection PyBroadException
                try:
                    # remember the original handler so it can be restored/chained
                    self._org_handlers[c] = signal.getsignal(c)
                    signal.signal(c, self.signal_handler)
                except Exception:
                    pass

    def remove_signal_hooks(self):
        """Restore all signal handlers saved at registration time."""
        for org_handler_k, org_handler_v in self._org_handlers.items():
            # noinspection PyBroadException
            try:
                signal.signal(org_handler_k, org_handler_v)
            except Exception:
                pass
        self._org_handlers = {}

    def remove_exception_hooks(self):
        """Restore the original sys.excepthook."""
        if self._orig_exc_handler:
            sys.excepthook = self._orig_exc_handler
            self._orig_exc_handler = None

    def exit(self, code=0):
        """Replacement for sys.exit(): record the code, then really exit."""
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exctype, value, traceback, *args, **kwargs):
        """sys.excepthook replacement: record the exception, clean the
        traceback of import-bind frames, and chain to the original hook."""
        if self._except_recursion_protection_flag or not self._orig_exc_handler:
            # noinspection PyArgumentList
            return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)

        self._except_recursion_protection_flag = True
        self.exception = value

        try:
            # remove us from import errors
            if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
                prev = cur = traceback
                while cur is not None:
                    tb_next = cur.tb_next
                    # if this is the import frame, we should remove it
                    if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
                        # remove this frame by connecting the previous one to the next one
                        prev.tb_next = tb_next
                        cur.tb_next = None
                        del cur
                        cur = prev

                    prev = cur
                    cur = tb_next
        except:  # noqa
            pass

        if self._orig_exc_handler:
            # noinspection PyArgumentList
            ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
        else:
            # noinspection PyNoneFunctionAssignment, PyArgumentList
            ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
        self._except_recursion_protection_flag = False

        return ret

    def signal_handler(self, sig, frame):
        """Common handler for fatal signals: record the signal, run the exit
        callback, then re-raise the signal with the original handler."""
        org_handler = self._org_handlers.get(sig)
        if not org_handler:
            return signal.SIG_DFL

        self.signal = sig
        # restore the original handler before re-raising below
        signal.signal(sig, org_handler or signal.SIG_DFL)

        # if this is a sig term, we wait until __at_exit is called (basically do nothing)
        if sig == signal.SIGINT:
            # return original handler result
            return org_handler if not callable(org_handler) else org_handler(sig, frame)

        if self._signal_recursion_protection_flag:
            # call original
            os.kill(os.getpid(), sig)
            return org_handler if not callable(org_handler) else org_handler(sig, frame)

        self._signal_recursion_protection_flag = True

        # call exit callback
        if self._exit_callback:
            # noinspection PyBroadException
            try:
                self._exit_callback()
            except Exception:
                pass

        # remove stdout logger, just in case
        # noinspection PyBroadException
        try:
            # noinspection PyProtectedMember
            Logger._remove_std_logger()
        except Exception:
            pass

        # noinspection PyUnresolvedReferences
        os.kill(os.getpid(), sig)

        self._signal_recursion_protection_flag = False
        # return handler result
        return org_handler if not callable(org_handler) else org_handler(sig, frame)
|
||||
@@ -541,6 +541,15 @@ class BackgroundMonitor(object):
|
||||
self._event.set()
|
||||
|
||||
if isinstance(self._thread, Thread):
|
||||
# should we wait for the thread to finish
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# there is a race here, and if someone else closes the
|
||||
# thread it can become True/None and we will fail, it is fine
|
||||
self._thread.join()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
try:
|
||||
self._get_instances().remove(self)
|
||||
except ValueError:
|
||||
@@ -669,21 +678,23 @@ class BackgroundMonitor(object):
|
||||
@classmethod
|
||||
def _background_process_start(cls, task_obj_id, event_start=None, parent_pid=None):
|
||||
# type: (int, Optional[SafeEvent], Optional[int]) -> None
|
||||
# noinspection PyProtectedMember
|
||||
is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace())
|
||||
# make sure we update the pid to our own
|
||||
cls._main_process = os.getpid()
|
||||
cls._main_process_proc_obj = psutil.Process(cls._main_process)
|
||||
# restore original signal, this will prevent any deadlocks
|
||||
# Do not change the exception we need to catch base exception as well
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from ... import Task
|
||||
# make sure we do not call Task.current_task() it will create a Task object for us on a subprocess!
|
||||
if Task._Task__current_task and Task._Task__current_task._Task__exit_hook: # noqa
|
||||
Task._Task__current_task._Task__exit_hook.register_signal_and_exception_hooks() # noqa
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
if Task._has_current_task_obj():
|
||||
# noinspection PyProtectedMember
|
||||
Task.current_task()._remove_at_exit_callbacks()
|
||||
from ...binding.environ_bind import PatchOsFork
|
||||
PatchOsFork.unpatch_fork()
|
||||
PatchOsFork.unpatch_process_run()
|
||||
except: # noqa
|
||||
# Do not change the exception we need to catch base exception as well
|
||||
pass
|
||||
|
||||
# if a debugger is running, wait for it to attach to the subprocess
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.13.2'
|
||||
__version__ = '1.14.1'
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# ClearML example: multi-objective hyper-parameter optimization.
# Optimizes a Keras base task over layer sizes / batch size / epochs while
# simultaneously minimizing evaluate/score and maximizing evaluate/accuracy.
# Multi-objective optimization is only supported with the Optuna backend.
import sys

from clearml import Task
from clearml.automation import DiscreteParameterRange, HyperParameterOptimizer, UniformIntegerParameterRange


try:
    from clearml.automation.optuna import OptimizerOptuna  # noqa
except ImportError:
    print("Multi-objective HPO is currently only supported via Optuna")
    sys.exit(0)


if __name__ == "__main__":
    # the optimizer itself is tracked as a ClearML task of type "optimizer"
    task = Task.init(
        project_name="Hyper-Parameter Optimization",
        task_name="Multi-objective HPO",
        task_type=Task.TaskTypes.optimizer,
        reuse_last_task_id=False,
    )

    # experiment template to optimize in the hyper-parameter optimization
    args = {
        "template_task_id": None,
        "run_as_service": False,
    }
    args = task.connect(args)

    # Get the template task experiment that we want to optimize
    if not args["template_task_id"]:
        args["template_task_id"] = Task.get_task(project_name="examples", task_name="Keras HP optimization base").id

    # Set default queue name for the Training tasks themselves.
    # later can be overridden in the UI
    execution_queue = "1xGPU"

    an_optimizer = HyperParameterOptimizer(
        # This is the experiment we want to optimize
        base_task_id=args["template_task_id"],
        # here we define the hyper-parameters to optimize
        # Notice: The parameter name should exactly match what you see in the UI: <section_name>/<parameter>
        # For Example, here we see in the base experiment a section Named: "General"
        # under it a parameter named "batch_size", this becomes "General/batch_size"
        # If you have `argparse` for example, then arguments will appear under the "Args" section,
        # and you should instead pass "Args/batch_size"
        hyper_parameters=[
            UniformIntegerParameterRange("General/layer_1", min_value=128, max_value=512, step_size=128),
            UniformIntegerParameterRange("General/layer_2", min_value=128, max_value=512, step_size=128),
            DiscreteParameterRange("General/batch_size", values=[96, 128, 160]),
            DiscreteParameterRange("General/epochs", values=[30]),
        ],
        # this is the objectives' metric/series we want to maximize/minimize
        objective_metric_title=["evaluate", "evaluate"],
        objective_metric_series=["score", "accuracy"],
        # now we decide if we want to maximize it or minimize them
        # in this case, we want to minimize evaluate/score and maximize evaluate/accuracy
        objective_metric_sign=["min", "max"],
        # let us limit the number of concurrent experiments,
        # this in turn will make sure we do dont bombard the scheduler with experiments.
        # if we have an auto-scaler connected, this, by proxy, will limit the number of machine
        max_number_of_concurrent_tasks=1,
        # optimizer_class has to be OptimizerOptuna
        optimizer_class=OptimizerOptuna,
        # Select an execution queue to schedule the experiments for execution
        execution_queue=execution_queue,
        # If specified all Tasks created by the HPO process will be created under the `spawned_project` project
        spawn_project=None,  # 'HPO spawn project',
        # If specified only the top K performing Tasks will be kept, the others will be automatically archived
        save_top_k_tasks_only=None,  # 5,
        # Optional: Limit the execution time of a single experiment, in minutes.
        # (this is optional, and if using OptimizerBOHB, it is ignored)
        time_limit_per_job=10.0,
        # Check the experiments every 12 seconds is way too often, we should probably set it to 5 min,
        # assuming a single experiment is usually hours...
        pool_period_min=0.2,
        # set the maximum number of jobs to launch for the optimization, default (None) unlimited
        # If OptimizerBOHB is used, it defined the maximum budget in terms of full jobs
        # basically the cumulative number of iterations will not exceed total_max_jobs * max_iteration_per_job
        total_max_jobs=10,
        # set the minimum number of iterations for an experiment, before early stopping.
        # Does not apply for simple strategies such as RandomSearch or GridSearch
        min_iteration_per_job=10,
        # Set the maximum number of iterations for an experiment to execute
        # (This is optional, unless using OptimizerBOHB where this is a must)
        max_iteration_per_job=30,
    )

    # if we are running as a service, just enqueue ourselves into the services queue and let it run the optimization
    if args["run_as_service"]:
        # if this code is executed by `clearml-agent` the function call does nothing.
        # if executed locally, the local process will be terminated, and a remote copy will be executed instead
        task.execute_remotely(queue_name="services", exit_process=True)

    # report every 12 seconds, this is way too often, but we are testing here
    an_optimizer.set_report_period(0.2)
    # start the optimization process, callback function to be called every time an experiment is completed
    # this function returns immediately
    an_optimizer.start()
    # You can also use the line below instead to run all the optimizer tasks locally, without using queues or agent
    # an_optimizer.start_locally(job_complete_callback=job_complete_callback)
    # set the time limit for the optimization process (2 hours)
    an_optimizer.set_time_limit(in_minutes=120.0)
    # wait until process is done (notice we are controlling the optimization process in the background)
    an_optimizer.wait()
    # optimization is completed, print the top performing experiments id
    top_exp = an_optimizer.get_top_experiments(top_k=3)
    print([t.id for t in top_exp])
    # make sure background optimization stopped
    an_optimizer.stop()

    print("We are done, good bye")
||||
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
24
examples/pipeline/decorated_pipeline_step_decorators.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from clearml import PipelineDecorator
|
||||
|
||||
|
||||
def our_decorator(func):
    """Decorator: call *func* and return its result incremented by one."""
    def function_wrapper(*args, **kwargs):
        # forward everything to the wrapped callable, then add one
        result = func(*args, **kwargs)
        return result + 1

    return function_wrapper
|
||||
|
||||
|
||||
@PipelineDecorator.component()
@our_decorator
def step():
    # the plain function returns 1; our_decorator (applied before the
    # component wrapper) bumps the step's reported result to 2
    return 1
|
||||
|
||||
|
||||
@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated")
def pipeline():
    # verify the step's locally-applied decorator survives pipeline execution
    result = step()
    assert result == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # run all pipeline components locally (no agent / queues needed)
    PipelineDecorator.run_locally()
    pipeline()
|
||||
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
27
examples/pipeline/decorated_pipeline_step_functions.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from clearml import PipelineController
|
||||
|
||||
|
||||
def our_decorator(func):
    """Return a wrapper around *func* whose result is ``func(...) + 1``."""
    def function_wrapper(*args, **kwargs):
        # call straight through and bump the returned value by one
        return 1 + func(*args, **kwargs)

    return function_wrapper
|
||||
|
||||
|
||||
@our_decorator
def step():
    # returns 1; our_decorator bumps the step's result to 2
    return 1
|
||||
|
||||
|
||||
def evaluate(step_return):
    # step() returns 1 and our_decorator adds 1, so the pipeline is
    # expected to hand us 2; anything else fails this step
    assert step_return == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # two-step pipeline: "step" feeds its (decorated) return value into "evaluate"
    pipeline = PipelineController(name="test_decorated", project="test_decorated")
    pipeline.add_function_step(name="step", function=step, function_return=["step_return"])
    pipeline.add_function_step(
        name="evaluate",
        function=evaluate,
        # '${step.step_return}' resolves to the artifact returned by the "step" node
        function_kwargs=dict(step_return='${step.step_return}')
    )
    pipeline.start_locally(run_pipeline_steps_locally=True)
||||
179
examples/pipeline/full_tabular_data_process_pipeline_example.py
Normal file
179
examples/pipeline/full_tabular_data_process_pipeline_example.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from clearml import PipelineDecorator, Task
|
||||
|
||||
|
||||
@PipelineDecorator.component(cache=True)
def create_dataset(source_url: str, project: str, dataset_name: str) -> str:
    """Download a header-less CSV from *source_url* and register it as a new
    ClearML Dataset.

    :param source_url: remote URL of the raw CSV file
    :param project: ClearML project to create the dataset under
    :param dataset_name: name for the new dataset
    :return: the ID of the finalized dataset
    """
    print("starting create_dataset")
    # imports are local so the step is self-contained when executed remotely
    from clearml import StorageManager, Dataset
    import pandas as pd
    local_file = StorageManager.get_local_copy(source_url)
    df = pd.read_csv(local_file, header=None)
    df.to_csv(path_or_buf="./dataset.csv", index=False)
    dataset = Dataset.create(dataset_project=project, dataset_name=dataset_name)
    dataset.add_files("./dataset.csv")
    # log a preview of the data alongside the dataset
    dataset.get_logger().report_table(title="sample", series="head", table_plot=df.head())
    dataset.finalize(auto_upload=True)

    print("done create_dataset")
    return dataset.id
|
||||
|
||||
|
||||
@PipelineDecorator.component(cache=True)
def preprocess_dataset(dataset_id: str):
    """Name the CSV columns and store the result as a child Dataset version.

    :param dataset_id: ID of the dataset produced by create_dataset
    :return: the ID of the new (preprocessed) dataset
    """
    print("starting preprocess_dataset")
    from clearml import Dataset
    from pathlib import Path
    import pandas as pd
    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv", header=None)
    # "preprocessing" - adding columns
    df.columns = [
        'age', 'workclass', 'fnlwgt', 'degree', 'education-yrs', 'marital-status',
        'occupation', 'relationship', 'ethnicity', 'gender', 'capital-gain',
        'capital-loss', 'hours-per-week', 'native-country', 'income-cls',
    ]
    df.to_csv(path_or_buf="./dataset.csv", index=False)

    # store in a new dataset, parented to the raw one for lineage
    new_dataset = Dataset.create(
        dataset_project=dataset.project, dataset_name="{} v2".format(dataset.name),
        parent_datasets=[dataset]
    )
    new_dataset.add_files("./dataset.csv")
    new_dataset.get_logger().report_table(title="sample", series="head", table_plot=df.head())
    new_dataset.finalize(auto_upload=True)

    print("done preprocess_dataset")
    return new_dataset.id
|
||||
|
||||
|
||||
@PipelineDecorator.component(cache=True)
def verify_dataset_integrity(dataset_id: str, expected_num_columns: int):
    """Check the preprocessed dataset has the expected number of columns.

    NOTE(review): a failed check raises AssertionError (failing the step)
    rather than returning False -- confirm this is the intended contract
    with the calling pipeline.

    :param dataset_id: ID of the dataset to verify
    :param expected_num_columns: number of columns the CSV must contain
    :return: True when verification passed
    """
    print("starting verify_dataset_integrity")
    from clearml import Dataset, Logger
    from pathlib import Path
    import numpy as np
    import pandas as pd
    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv")
    print("Verifying dataset")
    assert len(df.columns) == expected_num_columns
    print("PASSED")
    # log some stats on the age column
    Logger.current_logger().report_histogram(
        title="histogram", series="age", values=np.histogram(df["age"])
    )

    print("done verify_dataset_integrity")
    return True
|
||||
|
||||
|
||||
@PipelineDecorator.component(output_uri=True)
def train_model(dataset_id: str, training_args: dict):
    """Train an XGBoost regressor predicting "age" on the preprocessed dataset.

    :param dataset_id: ID of the preprocessed dataset
    :param training_args: training options; only "num_boost_round" is read
        (default 100)
    :return: dict with "error" (L2 norm of the test residuals) and
        "model_id" (ID of the auto-uploaded output model)
    """
    print("starting train_model")
    # imports are local so the step is self-contained when executed remotely
    # (the unused OutputModel import was removed)
    from clearml import Dataset, Task
    from pathlib import Path
    import pandas as pd
    import numpy as np
    import xgboost as xgb
    from sklearn.model_selection import train_test_split
    import matplotlib.pyplot as plt

    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv")

    # prepare data (i.e. select specific columns)
    columns = ["age", "fnlwgt", "education-yrs", "capital-gain", "capital-loss", "hours-per-week"]
    X = df[columns].drop("age", axis=1)
    y = df["age"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # create matrix
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)

    # train with XGBoost
    params = {"objective": "reg:squarederror", "eval_metric": "rmse"}
    bst = xgb.train(
        params,
        dtrain,
        num_boost_round=training_args.get("num_boost_round", 100),
        evals=[(dtrain, "train"), (dtest, "test")],
        verbose_eval=0,
    )
    # evaluate; the plot is presumably auto-captured by clearml's
    # matplotlib binding -- it is never shown/reported explicitly
    y_pred = bst.predict(dtest)
    plt.plot(y_test, 'r')
    plt.plot(y_pred, 'b')

    # let's store the eval score
    error = np.linalg.norm(y_test - y_pred)
    # saving the model triggers the auto-upload (output_uri=True)
    bst.save_model("a_model.xgb")

    # reload so the freshly uploaded output model shows up in task.models
    Task.current_task().reload()
    model_id = Task.current_task().models['output'][-1].id
    print("done train_model")
    return dict(error=error, model_id=model_id)
|
||||
|
||||
|
||||
@PipelineDecorator.component(monitor_models=["best"])
def select_best_model(models_score: list):
    """Pick the candidate with the lowest "error" and attach it to the
    pipeline task under the name "best".

    :param models_score: list of dicts as returned by train_model
        (each with "error" and "model_id" keys)
    :return: ID of the selected OutputModel
    """
    print("starting select_best_model:", models_score)
    from clearml import OutputModel, Task
    best_model = None
    # linear scan for the minimum-error candidate
    for m in models_score:
        if not best_model or m["error"] < best_model["error"]:
            best_model = m

    print("The best model is {}".format(best_model))
    # lets store it on the pipeline
    best_model = OutputModel(base_model_id=best_model["model_id"])
    # let's make sure we have it
    best_model.connect(task=Task.current_task(), name="best")

    print("done select_best_model")
    return best_model.id
|
||||
|
||||
|
||||
@PipelineDecorator.pipeline(
    name='xgboost_pipeline',
    project='xgboost_pipe_demo',
    version='0.1'
)
def pipeline(data_url: str, project: str):
    """End-to-end flow: ingest -> preprocess -> verify -> train x2 -> select best.

    :param data_url: URL of the raw CSV to ingest
    :param project: ClearML project for the created datasets
    """
    dataset_id = create_dataset(source_url=data_url, project=project, dataset_name="mock")

    preprocessed_dataset_id = preprocess_dataset(dataset_id=dataset_id)

    if not bool(verify_dataset_integrity(
            dataset_id=preprocessed_dataset_id,
            expected_num_columns=15)
    ):
        print("Verification Failed!")
        return False

    print("start training models")
    models_score = []
    # train two candidates with different boosting-round budgets
    for i in [100, 150]:
        model_score = train_model(
            dataset_id=preprocessed_dataset_id, training_args=dict(num_boost_round=i)
        )
        models_score.append(model_score)

    model_id = select_best_model(models_score=models_score)
    print("selected model_id = {}".format(model_id))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # UCI Adult (census income) dataset, used as the demo tabular data
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"

    # comment to run the entire pipeline remotely
    if Task.running_locally():
        # this is for demonstration purpose only,
        # it will run the entire pipeline logic and components locally
        PipelineDecorator.run_locally()

    pipeline(data_url=url, project="xgboost_pipe_demo")
|
||||
@@ -2,4 +2,5 @@ joblib>=0.13.2
|
||||
matplotlib >= 3.1.1 ; python_version >= '3.6'
|
||||
matplotlib >= 2.2.4 ; python_version < '3.6'
|
||||
scikit-learn
|
||||
pandas
|
||||
clearml
|
||||
@@ -46,7 +46,7 @@ def report_plots(logger, iteration=0):
|
||||
|
||||
# report confusion matrix
|
||||
confusion = np.random.randint(10, size=(10, 10))
|
||||
logger.report_matrix(
|
||||
logger.report_confusion_matrix(
|
||||
"example_confusion",
|
||||
"ignored",
|
||||
iteration=iteration,
|
||||
@@ -56,7 +56,7 @@ def report_plots(logger, iteration=0):
|
||||
)
|
||||
|
||||
# report confusion matrix with 0,0 is at the top left
|
||||
logger.report_matrix(
|
||||
logger.report_confusion_matrix(
|
||||
"example_confusion_0_0_at_top",
|
||||
"ignored",
|
||||
iteration=iteration,
|
||||
|
||||
Reference in New Issue
Block a user