This commit is contained in:
revital
2024-01-22 09:44:07 +02:00
33 changed files with 2353 additions and 733 deletions

View File

@@ -69,6 +69,7 @@ class PipelineController(object):
_final_failure = {} # Node.name: bool
_task_template_header = CreateFromFunction.default_task_template_header
_default_pipeline_version = "1.0.0"
_project_section = ".pipelines"
valid_job_status = ["failed", "cached", "completed", "aborted", "queued", "running", "skipped", "pending"]
@@ -176,7 +177,8 @@ class PipelineController(object):
always_create_from_code=True, # type: bool
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> None
"""
@@ -266,6 +268,9 @@ class PipelineController(object):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution when creating
the steps from functions, otherwise all global imports will be automatically imported in a safe manner at
the beginning of each step's execution. Default is False
"""
if auto_version_bump is not None:
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
@@ -302,10 +307,11 @@ class PipelineController(object):
self._last_progress_update_time = 0
self._artifact_serialization_function = artifact_serialization_function
self._artifact_deserialization_function = artifact_deserialization_function
self._skip_global_imports = skip_global_imports
if not self._task:
task_name = name or project or '{}'.format(datetime.now())
if self._pipeline_as_sub_project:
parent_project = "{}.pipelines".format(project+'/' if project else '')
parent_project = (project + "/" if project else "") + self._project_section
project_name = "{}/{}".format(parent_project, task_name)
else:
parent_project = None
@@ -1422,7 +1428,7 @@ class PipelineController(object):
mutually_exclusive(pipeline_id=pipeline_id, pipeline_project=pipeline_project, _require_at_least_one=False)
mutually_exclusive(pipeline_id=pipeline_id, pipeline_name=pipeline_name, _require_at_least_one=False)
if not pipeline_id:
pipeline_project_hidden = "{}/.pipelines/{}".format(pipeline_project, pipeline_name)
pipeline_project_hidden = "{}/{}/{}".format(pipeline_project, cls._project_section, pipeline_name)
name_with_runtime_number_regex = r"^{}( #[0-9]+)*$".format(re.escape(pipeline_name))
pipelines = Task._query_tasks(
pipeline_project=[pipeline_project_hidden],
@@ -1520,7 +1526,8 @@ class PipelineController(object):
dry_run=True,
task_template_header=self._task_template_header,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
artifact_deserialization_function=self._artifact_deserialization_function,
skip_global_imports=self._skip_global_imports
)
return task_definition
@@ -2725,7 +2732,7 @@ class PipelineController(object):
self._final_failure[node.name] = True
completed_jobs.append(j)
node.executed = node.job.task_id() if not node_failed else False
node.executed = node.job.task_id() if not (node_failed or node.job.is_aborted()) else False
if j in launched_nodes:
launched_nodes.remove(j)
# check if we need to stop all running steps
@@ -3315,7 +3322,8 @@ class PipelineDecorator(PipelineController):
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> ()
"""
@@ -3398,6 +3406,9 @@ class PipelineDecorator(PipelineController):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
global imports will be automatically imported in a safe manner at the beginning of each step's execution.
Default is False
"""
super(PipelineDecorator, self).__init__(
name=name,
@@ -3419,7 +3430,8 @@ class PipelineDecorator(PipelineController):
always_create_from_code=False,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
# if we are in eager execution, make sure parent class knows it
@@ -3482,7 +3494,7 @@ class PipelineDecorator(PipelineController):
else:
self._final_failure[node.name] = True
completed_jobs.append(j)
node.executed = node.job.task_id() if not node_failed else False
node.executed = node.job.task_id() if not (node_failed or node.job.is_aborted()) else False
if j in launched_nodes:
launched_nodes.remove(j)
# check if we need to stop all running steps
@@ -3685,7 +3697,8 @@ class PipelineDecorator(PipelineController):
task_template_header=self._task_template_header,
_sanitize_function=sanitize,
artifact_serialization_function=self._artifact_serialization_function,
artifact_deserialization_function=self._artifact_deserialization_function
artifact_deserialization_function=self._artifact_deserialization_function,
skip_global_imports=self._skip_global_imports
)
return task_definition
@@ -3906,10 +3919,12 @@ class PipelineDecorator(PipelineController):
:return: function wrapper
"""
def decorator_wrap(func):
_name = name or str(func.__name__)
# noinspection PyProtectedMember
unwrapped_func = CreateFromFunction._deep_extract_wrapped(func)
_name = name or str(unwrapped_func.__name__)
function_return = return_values if isinstance(return_values, (tuple, list)) else [return_values]
inspect_func = inspect.getfullargspec(func)
inspect_func = inspect.getfullargspec(unwrapped_func)
# add default argument values
if inspect_func.args:
default_values = list(inspect_func.defaults or [])
@@ -4127,7 +4142,7 @@ class PipelineDecorator(PipelineController):
return task.artifacts[return_name].get(
deserialization_function=cls._singleton._artifact_deserialization_function
)
return task.get_parameters(cast=True)[CreateFromFunction.return_section + "/" + return_name]
return task.get_parameters(cast=True).get(CreateFromFunction.return_section + "/" + return_name)
return_w = [LazyEvalWrapper(
callback=functools.partial(result_wrapper, n),
@@ -4172,7 +4187,8 @@ class PipelineDecorator(PipelineController):
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
output_uri=None, # type: Optional[Union[str, bool]]
skip_global_imports=False # type: bool
):
# type: (...) -> Callable
"""
@@ -4286,6 +4302,9 @@ class PipelineDecorator(PipelineController):
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
:param skip_global_imports: If True, global imports will not be included in the steps' execution, otherwise all
global imports will be automatically imported in a safe manner at the beginning of each step's execution.
Default is False
"""
def decorator_wrap(func):
@@ -4332,7 +4351,8 @@ class PipelineDecorator(PipelineController):
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
ret_val = func(**pipeline_kwargs)
LazyEvalWrapper.trigger_all_remote_references()
@@ -4384,7 +4404,8 @@ class PipelineDecorator(PipelineController):
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
output_uri=output_uri,
skip_global_imports=skip_global_imports
)
a_pipeline._args_map = args_map or {}
@@ -4551,6 +4572,13 @@ class PipelineDecorator(PipelineController):
_node.parents = (_node.parents or []) + [
x for x in cls._evaluated_return_values.get(tid, []) if x in leaves
]
if not cls._singleton._abort_running_steps_on_failure:
for parent in _node.parents:
if cls._singleton._nodes[parent].status in ["failed", "aborted", "skipped"]:
_node.skip_job = True
return
for k, v in kwargs.items():
if v is None or isinstance(v, (float, int, bool, six.string_types)):
_node.parameters["{}/{}".format(CreateFromFunction.kwargs_section, k)] = v

View File

@@ -85,7 +85,7 @@ class _TrainsBandsterWorker(Worker):
self.optimizer.budget.iterations.update(self._current_job.task_id(), iteration_value[0])
# check if we exceeded this job budget
if iteration_value[0] >= self.budget_iteration_scale * budget:
if iteration_value[0][0] >= self.budget_iteration_scale * budget:
self._current_job.abort()
break
@@ -95,7 +95,7 @@ class _TrainsBandsterWorker(Worker):
# noinspection PyProtectedMember
self.optimizer.budget.jobs.update(
self._current_job.task_id(),
float(iteration_value[0]) / self.optimizer._max_iteration_per_job)
float(iteration_value[0][0]) / self.optimizer._max_iteration_per_job)
result = {
# this is a mandatory field to run hyperband
@@ -104,7 +104,7 @@ class _TrainsBandsterWorker(Worker):
# can be used for any user-defined information - also mandatory
'info': self._current_job.task_id()
}
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value))
print('TrainsBandsterWorker result {}, iteration {}'.format(result, iteration_value[0]))
# noinspection PyProtectedMember
self.optimizer._current_jobs.remove(self._current_job)
return result
@@ -299,7 +299,7 @@ class OptimizerBOHB(SearchStrategy, RandomSeed):
sleep_interval=int(self.pool_period_minutes * 60),
budget_iteration_scale=budget_iteration_scale,
base_task_id=self._base_task_id,
objective=self._objective_metric,
objective=self._objective_metric.objectives[0],
queue_name=self._execution_queue,
nameserver='127.0.0.1', nameserver_port=self._nameserver_port, run_id=fake_run_id, id=i)
w.run(background=True)

View File

@@ -7,7 +7,8 @@ from itertools import product
from logging import getLogger
from threading import Thread, Event
from time import time
from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable
from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable, Tuple
from abc import ABC, abstractmethod
from .job import ClearmlJob, LocalClearmlJob
from .parameters import Parameter
@@ -19,7 +20,33 @@ from ..task import Task
logger = getLogger('clearml.automation.optimization')
class Objective(object):
class _ObjectiveInterface(ABC):
@abstractmethod
def get_objective(self, task_id):
pass
@abstractmethod
def get_current_raw_objective(self, task):
pass
@abstractmethod
def get_objective_sign(self):
pass
@abstractmethod
def get_objective_metric(self):
pass
@abstractmethod
def get_normalized_objective(self, task_id):
pass
@abstractmethod
def get_top_tasks(self, top_k, optimizer_task_id=None, task_filter=None):
pass
class Objective(_ObjectiveInterface):
"""
Optimization ``Objective`` class to maximize / minimize over all experiments. This class will sample a specific
scalar from all experiments, and maximize / minimize over single scalar (i.e., title and series combination).
@@ -53,7 +80,7 @@ class Objective(object):
self.series = series
assert order in ('min', 'max',)
# normalize value so we always look for the highest objective value
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else +1
self.sign = -1 if (isinstance(order, str) and order.lower().strip() == 'min') else 1
self._metric = None
self.extremum = extremum
@@ -243,7 +270,7 @@ class Budget(object):
# type: () -> (Optional[float])
if self.limit is None or not self.current:
return None
return sum(self.current.values())/float(self.limit)
return sum(self.current.values()) / float(self.limit)
def __init__(self, jobs_limit, iterations_limit, compute_time_limit):
# type: (Optional[int], Optional[int], Optional[float]) -> ()
@@ -467,7 +494,7 @@ class SearchStrategy(object):
if self.max_iteration_per_job:
iterations = self._get_job_iterations(job)
if iterations > 0:
if iterations and iterations > 0:
self.budget.iterations.update(job.task_id(), iterations)
if iterations > self.max_iteration_per_job:
abort_job = True
@@ -512,12 +539,8 @@ class SearchStrategy(object):
:return: A list of Task objects, ordered by performance, where index 0 is the best performing Task.
"""
# noinspection PyProtectedMember
top_tasks = self._get_child_tasks(
parent_task_id=self._job_parent_id or self._base_task_id,
order_by=self._objective_metric._get_last_metrics_encode_field(),
additional_filters={'page_size': int(top_k), 'page': 0})
return top_tasks
return self._objective_metric.get_top_tasks(top_k=top_k,
optimizer_task_id=self._job_parent_id or self._base_task_id)
def get_top_experiments_id_metrics_pair(self, top_k, all_metrics=False, only_completed=False):
# type: (int, bool, bool) -> Sequence[(str, dict)]
@@ -582,15 +605,17 @@ class SearchStrategy(object):
# noinspection PyProtectedMember
top_tasks_ids_metric = self._get_child_tasks_ids(
parent_task_id=self._job_parent_id or self._base_task_id,
order_by=self._objective_metric._get_last_metrics_encode_field(),
order_by=self._objective_metric._get_last_metrics_encode_field()[0],
additional_filters=additional_filters,
additional_fields=['last_metrics']
)
title, series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
title_series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
titles = [ts[0] for ts in title_series]
series = [ts[1] for ts in title_series]
return [(i, {'{}/{}'.format(v['metric'], v['variant']): v
for variant in metric.values() for v in variant.values()
if all_metrics or v['metric'] == title and v['variant'] == series}
if all_metrics or (v['metric'] in titles and v['variant'] in series)}
) for i, metric in top_tasks_ids_metric]
def get_top_experiments_details(
@@ -669,15 +694,28 @@ class SearchStrategy(object):
# noinspection PyProtectedMember
top_tasks_ids_metric_params = self._get_child_tasks_ids(
parent_task_id=self._job_parent_id or self._base_task_id,
order_by=self._objective_metric._get_last_metrics_encode_field(),
order_by=self._objective_metric._get_last_metrics_encode_field()[
0] if self._objective_metric.len == 1 else None,
additional_filters=additional_filters,
additional_fields=['last_metrics', 'hyperparams']
)
if self._objective_metric.len != 1:
top_tasks_ids_metric_params_dict = {}
for task in top_tasks_ids_metric_params:
objective = self._objective_metric.get_objective(task[0])
if objective is None or any(o is None for o in objective):
continue
top_tasks_ids_metric_params_dict[task[0]] = (objective, task)
# noinspection PyProtectedMember
sorted_ids = self._objective_metric._sort_jobs_by_domination(top_tasks_ids_metric_params_dict)
top_tasks_ids_metric_params = [top_tasks_ids_metric_params_dict[s][1] for s in sorted_ids]
# get hp_parameters:
hp_params = set(p.name for p in self._hyper_parameters)
title, series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
title_series = self._objective_metric.get_objective_metric() if not all_metrics else (None, None)
titles = [ts[0] for ts in title_series]
series = [ts[1] for ts in title_series]
return [
{
'task_id': tid,
@@ -688,19 +726,20 @@ class SearchStrategy(object):
},
'metrics': {
'{}/{}'.format(v['metric'], v['variant']): v for variant in metric.values()
for v in variant.values() if all_metrics or v['metric'] == title and v['variant'] == series
for v in variant.values() if all_metrics or v['metric'] in titles and v['variant'] in series
}
} for tid, metric, param_sections in top_tasks_ids_metric_params
]
def get_objective_metric(self):
# type: () -> (str, str)
# type: () -> Union[Tuple[str, str], List[Tuple[str, str]]]
"""
Return the metric title, series pair of the objective.
:return: (title, series)
"""
return self._objective_metric.get_objective_metric()
objective = self._objective_metric.get_objective_metric()
return objective[0] if self._objective_metric.len == 1 else objective
def helper_create_job(
self,
@@ -823,7 +862,9 @@ class SearchStrategy(object):
def _get_job_iterations(self, job):
# type: (Union[ClearmlJob, Task]) -> int
iteration_value = self._objective_metric.get_current_raw_objective(job)
return iteration_value[0] if iteration_value else -1
if iteration_value is not None and any(iv is not None and iv[0] is not None for iv in iteration_value):
return max(iv[0] for iv in iteration_value if iv is not None)
return -1
@classmethod
def _get_child_tasks_ids(
@@ -887,7 +928,7 @@ class SearchStrategy(object):
task_objects = Task._query_tasks(**task_filter)
if not additional_fields:
return [t.id for t in task_objects]
return [[t.id]+[getattr(t, f, None) for f in additional_fields] for t in task_objects]
return [[t.id] + [getattr(t, f, None) for f in additional_fields] for t in task_objects]
@classmethod
def _get_child_tasks(
@@ -927,6 +968,146 @@ class SearchStrategy(object):
]
class MultiObjective(_ObjectiveInterface):
def __init__(self, title, series, order, extremum):
self.title = title
self.series = series
self.order = order
self.extremum = extremum
self.objectives = []
for title_, series_, order_, extremum_ in zip(title, series, order, extremum):
self.objectives.append(Objective(title=title_, series=series_, order=order_, extremum=extremum_))
self.len = len(self.objectives)
def get_objective(self, task_id):
# type: (Union[str, Task, ClearmlJob]) -> Optional[List[float]]
"""
Return a specific task scalar values based on the objective settings (title/series).
:param str task_id: The Task ID to retrieve scalar from (or ``ClearMLJob`` object).
:return: The scalar values.
"""
objective = [o.get_objective(task_id) for o in self.objectives]
if any(o is None for o in objective):
return None
return objective
def get_current_raw_objective(self, task):
# type: (Union[ClearmlJob, Task]) -> Optional[List[Tuple[int, float]]]
"""
Return the current raw value (without sign normalization) of each objective.
:param str task: The Task or Job to retrieve scalar from (or ``ClearmlJob`` object).
:return: List[Optional[Tuple(iteration, value)]]. None if the metric does not exist.
"""
objective = [o.get_current_raw_objective(task) for o in self.objectives]
if any(o is None for o in objective):
return None
return objective
def get_objective_sign(self):
# type: () -> List[float]
"""
Return the sign of the objectives.
- ``+1`` - If maximizing
- ``-1`` - If minimizing
:return: Objective function signs.
"""
return [o.get_objective_sign() for o in self.objectives]
def get_normalized_objective(self, task_id):
# type: (Union[str, Task, ClearmlJob]) -> Optional[List[float]]
"""
Return a normalized task scalar values based on the objective settings (title/series).
I.e. objective is always to maximize the returned value
:param str task_id: The Task ID to retrieve scalars from.
:return: Normalized scalar values.
"""
objective = [o.get_normalized_objective(task_id) for o in self.objectives]
if any(o is None for o in objective):
return None
return objective
def get_objective_metric(self):
# type: () -> List[(str, str)]
"""
Return the metric title, series pairs of the objectives.
:return: List[(title, series)]
"""
return [o.get_objective_metric() for o in self.objectives]
def get_top_tasks(self, top_k, optimizer_task_id=None, task_filter=None):
# type: (int, Optional[str], Optional[dict]) -> Sequence[Task]
"""
Return a list of Tasks of the top performing experiments.
If there is only one objective, the tasks are sorted based on that objective.
If there are multiple objectives, the tasks are sorted based on successive Pareto fronts.
A trial is located at the Pareto front if there are no trials that dominate the trial.
A trial dominates another trial if all its objective metrics are greater than or equal to the other
trial's, and there is at least one objective metric that is strictly greater than the other's.
:param int top_k: The number of Tasks (experiments) to return.
:param str optimizer_task_id: Parent optimizer Task ID
:param dict task_filter: Optional task_filtering for the query
:return: A list of Task objects, ordered by performance, where index 0 is the best performing Task.
"""
if self.len == 1:
return self.objectives[0].get_top_tasks(
top_k=top_k,
optimizer_task_id=optimizer_task_id,
task_filter=task_filter
)
task_filter = deepcopy(task_filter) if task_filter else {}
if optimizer_task_id:
task_filter["parent"] = optimizer_task_id
# noinspection PyProtectedMember
tasks = Task._query_tasks(**task_filter)
candidates = {}
for task in tasks:
values = self.get_objective(task.id)
if values is None or any(v is None for v in values):
continue
candidates[task.id] = (values, task.id)
sorted_ids = self._sort_jobs_by_domination(candidates)
if not sorted_ids:
return []
return Task.get_tasks(task_ids=sorted_ids[:top_k])
def _get_last_metrics_encode_field(self):
# noinspection PyProtectedMember
return [o._get_last_metrics_encode_field() for o in self.objectives]
def _weakly_dominates_normalized(self, lhs, rhs):
return all(lhs_elem * o.sign >= rhs_elem * o.sign for lhs_elem, rhs_elem, o in zip(lhs, rhs, self.objectives))
@staticmethod
def _dominates(lhs, rhs):
return all(lhs_elem >= rhs_elem for lhs_elem, rhs_elem in zip(lhs, rhs)) and \
any(lhs_elem > rhs_elem for lhs_elem, rhs_elem in zip(lhs, rhs))
def _sort_jobs_by_domination(self, jobs):
job_ids = list(jobs.keys())
job_ids_sorted = []
while len(job_ids_sorted) < len(jobs.keys()):
have_result = False
for job_id in job_ids:
if all(self._weakly_dominates_normalized(jobs[job_id][0], jobs[other_job_id][0])
for other_job_id in job_ids):
have_result = True
job_ids_sorted.append(job_id)
if not have_result:
job_ids_sorted.extend(job_ids)
job_ids = [job_id for job_id in job_ids if job_id not in job_ids_sorted]
return job_ids_sorted
class GridSearch(SearchStrategy):
"""
Grid search strategy controller. Full grid sampling of every hyperparameter combination.
@@ -1089,12 +1270,12 @@ class HyperParameterOptimizer(object):
self,
base_task_id, # type: str
hyper_parameters, # type: Sequence[Parameter]
objective_metric_title, # type: str
objective_metric_series, # type: str
objective_metric_sign='min', # type: str
objective_metric_title, # type: Union[str, Sequence[str]]
objective_metric_series, # type: Union[str, Sequence[str]]
objective_metric_sign="min", # type: Union[str, Sequence[str]]
optimizer_class=RandomSearch, # type: Union[SearchStrategy, type(SearchStrategy)]
max_number_of_concurrent_tasks=10, # type: int
execution_queue='default', # type: str
execution_queue="default", # type: str
optimization_time_limit=None, # type: Optional[float]
compute_time_limit=None, # type: Optional[float]
auto_connect_task=True, # type: Union[bool, Task]
@@ -1109,10 +1290,14 @@ class HyperParameterOptimizer(object):
:param str base_task_id: The Task ID to be used as template experiment to optimize.
:param list hyper_parameters: The list of Parameter objects to optimize over.
:param str objective_metric_title: The Objective metric title to maximize / minimize (for example,
``validation``).
:param str objective_metric_series: The Objective metric series to maximize / minimize (for example, ``loss``).
:param str objective_metric_sign: The objective to maximize / minimize.
:param Union[str, Sequence[str]] objective_metric_title: The Objective metric title(s) to maximize / minimize
(for example, ``validation``, ``["validation", "loss"]``). If ``objective_metric_title`` is a sequence
(used to optimize multiple objectives at the same time), then ``objective_metric_series`` and
``objective_metric_sign`` have to be sequences of the same length. Each title will be matched
with the respective series and sign
:param Union[str, Sequence[str]] objective_metric_series: The Objective metric series to maximize / minimize
(for example, ``loss_series``, ``["validation_series", "loss_series"]``).
:param Union[str, Sequence[str]] objective_metric_sign: The objectives to maximize / minimize.
The values are:
- ``min`` - Minimize the last reported value for the specified title/series scalar.
@@ -1190,7 +1375,24 @@ class HyperParameterOptimizer(object):
# make sure we stop all jobs
an_optimizer.stop()
"""
if type(objective_metric_title) is not type(objective_metric_series) or type(
objective_metric_title) is not type(objective_metric_sign):
raise TypeError(
"objective_metric_series, objective_metric_title and objective_metric_sign have to be of the same type"
" (strings if doing single objective optimization and lists of the same length"
" if doing multi-objective optimization)"
)
if isinstance(objective_metric_title, str):
objective_metric_series = [objective_metric_series]
objective_metric_title = [objective_metric_title]
objective_metric_sign = [objective_metric_sign]
if len(objective_metric_series) != len(objective_metric_title) or len(objective_metric_series) != len(
objective_metric_sign
):
raise ValueError(
"Can not use multiple objective optimization when objective_metric_series, objective_metric_title"
" or objective_metric_sign do not have the same length"
)
# create a new Task, if we do not have one already
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
self._readonly_task = \
@@ -1224,17 +1426,29 @@ class HyperParameterOptimizer(object):
self.hyper_parameters = hyper_parameters
self.max_number_of_concurrent_tasks = opts['max_number_of_concurrent_tasks']
self.execution_queue = opts['execution_queue']
self.objective_metric = Objective(
title=opts['objective_metric_title'], series=opts['objective_metric_series'],
order='min' if opts['objective_metric_sign'] in ('min', 'min_global') else 'max',
extremum=opts['objective_metric_sign'].endswith('_global'))
self._objective_metric = MultiObjective(
title=opts["objective_metric_title"],
series=opts["objective_metric_series"],
order=["min" if sign_ in ("min", "min_global") else "max" for sign_ in opts["objective_metric_sign"]],
extremum=[sign_.endswith("_global") for sign_ in opts["objective_metric_sign"]]
)
optuna_error_message = "Multi parameter optimization is only supported via Optuna. Please install Optuna via" + \
" `pip install optuna and set the `optimizer_class` to `clearml.automation.optuna.OptimizerOptuna`"
try:
if self._objective_metric.len != 1:
from .optuna import OptimizerOptuna
if optimizer_class != OptimizerOptuna:
raise ValueError(optuna_error_message)
except Exception:
raise ValueError(optuna_error_message)
# if optimizer_class is an instance, use it as is.
if type(optimizer_class) != type:
if not isinstance(optimizer_class, type):
self.optimizer = optimizer_class
else:
self.optimizer = optimizer_class(
base_task_id=opts['base_task_id'], hyper_parameters=hyper_parameters,
objective_metric=self.objective_metric, execution_queue=opts['execution_queue'],
objective_metric=self._objective_metric, execution_queue=opts['execution_queue'],
num_concurrent_workers=opts['max_number_of_concurrent_tasks'],
compute_time_limit=opts['compute_time_limit'], **opts.get('optimizer_kwargs', {}))
self.optimizer.set_optimizer_task(self._task)
@@ -1607,9 +1821,9 @@ class HyperParameterOptimizer(object):
@classmethod
def get_optimizer_top_experiments(
cls,
objective_metric_title, # type: str
objective_metric_series, # type: str
objective_metric_sign, # type: str
objective_metric_title, # type: Union[str, List[str]]
objective_metric_series, # type: Union[str, List[str]]
objective_metric_sign, # type: Union[str, List[str]]
optimizer_task_id, # type: str
top_k, # type: int
):
@@ -1636,6 +1850,12 @@ class HyperParameterOptimizer(object):
title=objective_metric_title, series=objective_metric_series, order=objective_metric_sign)
return objective.get_top_tasks(top_k=top_k, optimizer_task_id=optimizer_task_id)
@property
def objective_metric(self):
if self._objective_metric.len == 1:
return self._objective_metric.objectives[0]
return self._objective_metric
def _connect_args(self, optimizer_class=None, hyper_param_configuration=None, **kwargs):
# type: (SearchStrategy, dict, Any) -> (SearchStrategy, list, dict)
if not self._task or self._readonly_task:
@@ -1705,8 +1925,8 @@ class HyperParameterOptimizer(object):
def _report_daemon(self):
# type: () -> ()
title, series = self.objective_metric.get_objective_metric()
title = '{}/{}'.format(title, series)
title_series = self._objective_metric.get_objective_metric()
title = ["{}/{}".format(ts[0], ts[1]) for ts in title_series]
counter = 0
completed_jobs = dict()
task_logger = None
@@ -1722,10 +1942,14 @@ class HyperParameterOptimizer(object):
params["status"] = str(task.status)
# noinspection PyProtectedMember
iteration_value = task.get_last_iteration()
objective = self.objective_metric.get_objective(task)
objective = self._objective_metric.get_objective(task)
completed_jobs[task.id] = (
objective if objective is not None else -1,
iteration_value if iteration_value is not None else -1,
objective if objective is not None else (
[-1] * self._objective_metric.len
),
iteration_value if iteration_value is not None else (
[-1] * self._objective_metric.len
),
params
)
@@ -1754,9 +1978,9 @@ class HyperParameterOptimizer(object):
self._report_remaining_budget(task_logger, counter)
if (
self.optimizer.budget.compute_time.used
and self.optimizer.budget.compute_time.limit
and self.optimizer.budget.compute_time.used >= self.optimizer.budget.compute_time.limit
self.optimizer.budget.compute_time.used
and self.optimizer.budget.compute_time.limit
and self.optimizer.budget.compute_time.used >= self.optimizer.budget.compute_time.limit
):
logger.warning(
"Optimizer task reached compute time limit (used {:.2f} out of {:.2f})".format(
@@ -1767,8 +1991,9 @@ class HyperParameterOptimizer(object):
self._report_resources(task_logger, counter)
# collect a summary of all the jobs and their final objective values
cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - \
{j.task_id() for j in self.optimizer.get_running_jobs()}
cur_completed_jobs = set(self.optimizer.get_created_jobs_ids().keys()) - {
j.task_id() for j in self.optimizer.get_running_jobs()
}
self._report_completed_status(completed_jobs, cur_completed_jobs, task_logger, title)
self._report_completed_tasks_best_results(set(completed_jobs.keys()), task_logger, title, counter)
@@ -1790,10 +2015,14 @@ class HyperParameterOptimizer(object):
def _report_completed_status(self, completed_jobs, cur_completed_jobs, task_logger, title, force=False):
job_ids_sorted_by_objective = self.__sort_jobs_by_objective(completed_jobs)
best_experiment = \
(self.objective_metric.get_normalized_objective(job_ids_sorted_by_objective[0]),
job_ids_sorted_by_objective[0]) \
if job_ids_sorted_by_objective else (float('-inf'), None)
best_experiment = (
(
self._objective_metric.get_normalized_objective(job_ids_sorted_by_objective[0]),
job_ids_sorted_by_objective[0],
)
if job_ids_sorted_by_objective
else ([float("-inf")], None)
)
if force or cur_completed_jobs != set(completed_jobs.keys()):
pairs = []
labels = []
@@ -1801,13 +2030,14 @@ class HyperParameterOptimizer(object):
created_jobs_tasks = self.optimizer.get_created_jobs_tasks()
id_status = {j_id: j_run.status() for j_id, j_run in created_jobs_tasks.items()}
for i, (job_id, params) in enumerate(created_jobs.items()):
value = self.objective_metric.get_objective(job_id)
value = self._objective_metric.get_objective(job_id)
if job_id in completed_jobs:
if value != completed_jobs[job_id][0]:
iteration_value = self.objective_metric.get_current_raw_objective(job_id)
iteration_value = self._objective_metric.get_current_raw_objective(job_id)
iteration = [it_[0] if it_ else -1 for it_ in iteration_value]
completed_jobs[job_id] = (
value,
iteration_value[0] if iteration_value else -1,
iteration,
copy(dict(status=id_status.get(job_id), **params)))
elif completed_jobs.get(job_id):
completed_jobs[job_id] = (completed_jobs[job_id][0],
@@ -1815,43 +2045,98 @@ class HyperParameterOptimizer(object):
copy(dict(status=id_status.get(job_id), **params)))
pairs.append((i, completed_jobs[job_id][0]))
labels.append(str(completed_jobs[job_id][2])[1:-1])
elif value is not None:
elif value is not None and all(v is not None for v in value):
pairs.append((i, value))
labels.append(str(params)[1:-1])
iteration_value = self.objective_metric.get_current_raw_objective(job_id)
iteration_value = self._objective_metric.get_current_raw_objective(job_id)
iteration = [it_[0] if it_ else -1 for it_ in iteration_value]
completed_jobs[job_id] = (
value,
iteration_value[0] if iteration_value else -1,
iteration,
copy(dict(status=id_status.get(job_id), **params)))
# callback new experiment completed
if self._experiment_completed_cb:
normalized_value = self.objective_metric.get_normalized_objective(job_id)
if normalized_value is not None and normalized_value > best_experiment[0]:
normalized_value = self._objective_metric.get_normalized_objective(job_id)
if self._objective_metric.len == 1 and normalized_value is not None and \
normalized_value[0] > best_experiment[0][0]:
best_experiment = normalized_value, job_id
elif self._objective_metric.len != 1 and normalized_value is not None and \
all(n is not None for n in normalized_value) and (best_experiment[0] == float("-inf") or
MultiObjective._dominates(
normalized_value,
best_experiment[0])): # noqa
best_experiment = normalized_value, job_id
c = completed_jobs[job_id]
self._experiment_completed_cb(job_id, c[0], c[1], c[2], best_experiment[1])
if pairs:
print('Updating job performance summary plot/table')
# update scatter plot
task_logger.report_scatter2d(
title='Optimization Objective', series=title,
scatter=pairs, iteration=0, labels=labels,
mode='markers', xaxis='job #', yaxis='objective')
print("Updating job performance summary plot/table")
if isinstance(title, list):
for i, title_ in enumerate(title):
# update scatter plot
task_logger.report_scatter2d(
title="Optimization Objective",
series=title_,
scatter=[(p[0], p[1][i]) for p in pairs],
iteration=0,
labels=labels,
mode="markers",
xaxis="job #",
yaxis="objective",
)
else:
task_logger.report_scatter2d(
title="Optimization Objective",
series=title,
scatter=pairs,
iteration=0,
labels=labels,
mode="markers",
xaxis="job #",
yaxis="objective",
)
# update summary table
job_ids = list(completed_jobs.keys())
job_ids_sorted_by_objective = sorted(
job_ids, key=lambda x: completed_jobs[x][0], reverse=bool(self.objective_metric.sign >= 0))
job_ids_sorted_by_objective = self.__sort_jobs_by_objective(completed_jobs)
# sort the columns except for 'objective', 'iteration'
columns = list(sorted(set([c for k, v in completed_jobs.items() for c in v[2].keys()])))
# add the index column (task id) and the first two columns 'objective', 'iteration' then the rest
table_values = [['task id', 'objective', 'iteration'] + columns]
table_values += \
[([job, completed_jobs[job][0], completed_jobs[job][1]] +
[completed_jobs[job][2].get(c, '') for c in columns]) for job in job_ids_sorted_by_objective]
concat_iterations = True
if self._objective_metric.len == 1:
# add the index column (task id) and the first two columns 'objective', 'iteration' then the rest
table_values = [['task id', 'objective', 'iteration'] + columns]
table_values += \
[([job, completed_jobs[job][0][0], completed_jobs[job][1][0]] +
[completed_jobs[job][2].get(c, '') for c in columns]) for job in job_ids_sorted_by_objective]
else:
table_values = ['task id']
for job in job_ids_sorted_by_objective:
if not all(iter_ == completed_jobs[job][1][0] for iter_ in completed_jobs[job][1]):
concat_iterations = False
break
if concat_iterations:
for objective in self._objective_metric.objectives:
table_values.append(objective.title + "/" + objective.series)
table_values.append("iteration")
table_values = [table_values + columns]
for job in job_ids_sorted_by_objective:
entry = [job]
for val in completed_jobs[job][0]:
entry += [val]
entry += [completed_jobs[job][1][0]]
entry += [completed_jobs[job][2].get(c, '') for c in columns]
table_values.append(entry)
else:
for objective in self._objective_metric.objectives:
table_values.append(objective.title + "/" + objective.series)
table_values.append("iteration " + objective.title + "/" + objective.series)
table_values = [table_values + columns]
for job in job_ids_sorted_by_objective:
entry = [job]
for val, iter_ in zip(completed_jobs[job][0], completed_jobs[job][1]):
entry += [val, iter_]
entry += [completed_jobs[job][2].get(c, '') for c in columns]
table_values.append(entry)
# create links for task id in the table
task_link_template = self._task.get_output_log_web_page() \
@@ -1867,15 +2152,42 @@ class HyperParameterOptimizer(object):
task_link_template.format(project=project_id, task=task_id), task_id)
task_logger.report_table(
"summary", "job", 0, table_plot=table_values_with_links,
extra_layout={"title": "objective: {}".format(title)})
"summary",
"job",
0,
table_plot=table_values_with_links,
extra_layout={
"title": "objective: {}".format(title if not isinstance(title, list) else ", ".join(title))
},
)
# Build parallel Coordinates: convert to columns, and reorder accordingly
if len(table_values) > 1:
table_values_columns = [[row[i] for row in table_values] for i in range(len(table_values[0]))]
table_values_columns = \
[[table_values_columns[0][0]] + [c[:6]+'...' for c in table_values_columns[0][1:]]] + \
table_values_columns[2:-1] + [[title]+table_values_columns[1][1:]]
if self._objective_metric.len == 1:
table_values_columns = \
[[table_values_columns[0][0]] + [c[:6] + '...' for c in table_values_columns[0][1:]]] + \
table_values_columns[2:-1] + [[title] + table_values_columns[1][1:]]
else:
if not concat_iterations:
new_table_values_columns = []
handled = []
for i in range(1, 2 * len(self._objective_metric.objectives), 2):
handled.append(i)
new_table_values_columns.append(table_values_columns[i])
prefix = []
for i in range(len(table_values_columns)):
if i in handled or table_values_columns[i][0] == "status":
continue
prefix.append(table_values_columns[i])
table_values_columns = prefix + new_table_values_columns
else:
table_values_columns = ([table_values_columns[0]] +
table_values_columns[len(self._objective_metric.objectives) + 1:-1] +
table_values_columns[1:len(self._objective_metric.objectives) + 1]
)
for i in range(len(table_values_columns[0]) - 1):
table_values_columns[0][i + 1] = table_values_columns[0][i + 1][:6] + "..."
pcc_dims = []
for col in table_values_columns:
# test if all values are numbers:
@@ -1896,16 +2208,22 @@ class HyperParameterOptimizer(object):
pcc_dims.append(d)
# report parallel coordinates
plotly_pcc = dict(
data=[dict(
type='parcoords',
line=dict(colorscale='Viridis',
reversescale=bool(self.objective_metric.sign >= 0),
color=table_values_columns[-1][1:]),
dimensions=pcc_dims)],
layout={})
task_logger.report_plotly(
title='Parallel Coordinates', series='',
iteration=0, figure=plotly_pcc)
data=[
dict(
type="parcoords",
line=dict(
colorscale="Viridis",
reversescale=(
self._objective_metric.len == 1 and self._objective_metric.objectives[0].sign >= 0,
),
color=table_values_columns[-1][1:],
),
dimensions=pcc_dims,
)
],
layout={},
)
task_logger.report_plotly(title="Parallel Coordinates", series="", iteration=0, figure=plotly_pcc)
# upload summary as artifact
if force:
@@ -1937,21 +2255,20 @@ class HyperParameterOptimizer(object):
if not completed_jobs:
return
value_func, series_name = (max, "max") if self.objective_metric.get_objective_sign() > 0 else \
(min, "min")
latest_completed, obj_values = self._get_latest_completed_task_value(completed_jobs, series_name)
if latest_completed:
val = value_func(obj_values)
task_logger.report_scalar(
title=title,
series=series_name,
iteration=counter,
value=val)
task_logger.report_scalar(
title=title,
series="last reported",
iteration=counter,
value=latest_completed)
objectives = self._objective_metric.objectives
if not isinstance(title, list):
title = [title]
for objective, title_ in zip(objectives, title):
value_func, series_name = (max, "max") if objective.get_objective_sign() > 0 else (min, "min")
latest_completed, obj_values = self._get_latest_completed_task_value(
completed_jobs, series_name, objective.title, objective.series
)
if latest_completed:
val = value_func(obj_values)
task_logger.report_scalar(title=title_, series=series_name, iteration=counter, value=val)
task_logger.report_scalar(
title=title_, series="last reported", iteration=counter, value=latest_completed
)
def _report_resources(self, task_logger, iteration):
# type: (Logger, int) -> ()
@@ -1990,7 +2307,7 @@ class HyperParameterOptimizer(object):
title="resources", series=series,
iteration=iteration, value=val)
def _get_latest_completed_task_value(self, cur_completed_jobs, series_name):
def _get_latest_completed_task_value(self, cur_completed_jobs, series_name, title, series):
# type: (Set[str], str) -> (float, List[float])
completed_value = None
latest_completed = None
@@ -2003,16 +2320,15 @@ class HyperParameterOptimizer(object):
continue
completed_time = datetime_from_isoformat(response.response_data["task"]["completed"].partition("+")[0])
completed_time = completed_time.timestamp()
completed_values = self._get_last_value(response)
completed_values = self._get_last_value(response, title, series)
obj_values.append(completed_values['max_value'] if series_name == "max" else completed_values['min_value'])
if not latest_completed or completed_time > latest_completed:
latest_completed = completed_time
completed_value = completed_values['value']
return completed_value, obj_values
def _get_last_value(self, response):
metrics, title, series, values = ClearmlJob.get_metric_req_params(self.objective_metric.title,
self.objective_metric.series)
def _get_last_value(self, response, title, series):
metrics, title, series, values = ClearmlJob.get_metric_req_params(title, series)
last_values = response.response_data["task"]['last_metrics'][title][series]
return last_values
@@ -2061,6 +2377,14 @@ class HyperParameterOptimizer(object):
def __sort_jobs_by_objective(self, completed_jobs):
if not completed_jobs:
return []
job_ids_sorted_by_objective = list(sorted(
completed_jobs.keys(), key=lambda x: completed_jobs[x][0], reverse=bool(self.objective_metric.sign >= 0)))
return job_ids_sorted_by_objective
if self._objective_metric.len != 1:
# noinspection PyProtectedMember
return self._objective_metric._sort_jobs_by_domination(completed_jobs)
else:
return list(
sorted(
completed_jobs.keys(),
key=lambda x: completed_jobs[x][0],
reverse=bool(self._objective_metric.objectives[0].sign >= 0),
)
)

View File

@@ -55,21 +55,29 @@ class OptunaObjective(object):
if not is_pending:
# noinspection PyProtectedMember
iteration_value = self.optimizer._objective_metric.get_current_raw_objective(current_job)
if not iteration_value:
if not self.optimizer.monitor_job(current_job):
break
continue
# make sure we skip None objective values
if iteration_value and iteration_value[1] is not None:
# update budget
trial.report(value=iteration_value[1], step=iteration_value[0])
if not any(val is None or val[1] is None for val in iteration_value):
iteration = max(iv[0] for iv in iteration_value)
# trial pruning based on intermediate values not supported when using multi-objective
# noinspection PyProtectedMember
if self.optimizer._objective_metric.len == 1:
# update budget
trial.report(value=iteration_value[0][1], step=iteration)
# Handle pruning based on the intermediate value.
if trial.should_prune() and (
not self.min_iteration_per_job or
iteration_value[0] >= self.min_iteration_per_job):
current_job.abort()
raise optuna.TrialPruned()
# Handle pruning based on the intermediate value.
if trial.should_prune() and (
not self.min_iteration_per_job or
iteration >= self.min_iteration_per_job):
current_job.abort()
raise optuna.TrialPruned()
# check if we exceeded this job budget
if self.max_iteration_per_job and iteration_value[0] >= self.max_iteration_per_job:
if self.max_iteration_per_job and iteration >= self.max_iteration_per_job:
current_job.abort()
break
@@ -79,6 +87,10 @@ class OptunaObjective(object):
# noinspection PyProtectedMember
objective_metric = self.optimizer._objective_metric.get_objective(current_job)
# noinspection PyProtectedMember
if self.optimizer._objective_metric.len == 1:
objective_metric = objective_metric[0]
iteration_value = iteration_value[0]
print('OptunaObjective result metric={}, iteration {}'.format(objective_metric, iteration_value))
# noinspection PyProtectedMember
self.optimizer._current_jobs.remove(current_job)
@@ -157,13 +169,22 @@ class OptimizerOptuna(SearchStrategy):
This function returns only after optimization is completed or :meth:`stop` was called.
"""
self._study = optuna.create_study(
direction="minimize" if self._objective_metric.get_objective_sign() < 0 else "maximize",
load_if_exists=False,
sampler=self._optuna_sampler,
pruner=self._optuna_pruner,
study_name=self._optimizer_task.id if self._optimizer_task else None,
)
if self._objective_metric.len != 1:
self._study = optuna.create_study(
directions=["minimize" if sign_ < 0 else "maximize" for sign_ in self._objective_metric.get_objective_sign()],
load_if_exists=False,
sampler=self._optuna_sampler,
pruner=self._optuna_pruner,
study_name=self._optimizer_task.id if self._optimizer_task else None,
)
else:
self._study = optuna.create_study(
direction="minimize" if self._objective_metric.get_objective_sign()[0] < 0 else "maximize",
load_if_exists=False,
sampler=self._optuna_sampler,
pruner=self._optuna_pruner,
study_name=self._optimizer_task.id if self._optimizer_task else None,
)
config_space = self._convert_hyper_parameters_to_optuna()
self._objective = OptunaObjective(
base_task_id=self._base_task_id,

View File

@@ -9,6 +9,7 @@ from attr import attrs, attrib
from dateutil.relativedelta import relativedelta
from .job import ClearmlJob
from .controller import PipelineController
from ..backend_interface.util import datetime_from_isoformat, datetime_to_isoformat, mutually_exclusive
from ..task import Task
@@ -59,6 +60,23 @@ class BaseScheduleJob(object):
self._executed_instances = []
self._executed_instances.append(str(task_id))
def get_resolved_target_project(self):
if not self.base_task_id or not self.target_project:
return self.target_project
# noinspection PyBroadException
try:
task = Task.get_task(task_id=self.base_task_id)
# noinspection PyProtectedMember
if (
PipelineController._tag in task.get_system_tags()
and "/{}/".format(PipelineController._project_section) not in self.target_project
):
# noinspection PyProtectedMember
return "{}/{}/{}".format(self.target_project, PipelineController._project_section, task.name)
except Exception:
pass
return self.target_project
@attrs
class ScheduleJob(BaseScheduleJob):
@@ -367,7 +385,13 @@ class BaseScheduler(object):
:param queue: Remote queue to run the scheduler on, default 'services' queue.
"""
# make sure we serialize the current state if we are running locally
if self._task.running_locally():
self._serialize_state()
self._serialize()
# launch on the remote agent
self._task.execute_remotely(queue_name=queue, exit_process=True)
# we will be deserializing the state inside `start`
self.start()
def _update_execution_plots(self):
@@ -447,7 +471,7 @@ class BaseScheduler(object):
task_overrides=job.task_overrides,
disable_clone_task=not job.clone_task,
allow_caching=False,
target_project=job.target_project,
target_project=job.get_resolved_target_project(),
tags=[add_tags] if add_tags and isinstance(add_tags, str) else add_tags,
)
self._log('Scheduling Job {}, Task {} on queue {}.'.format(

View File

@@ -38,6 +38,7 @@ from .defs import (
from .request import Request, BatchRequest # noqa: F401
from .token_manager import TokenManager
from ..utils import get_http_session_with_retry, urllib_log_warning_setup
from ...backend_config import ConfigurationError
from ...backend_config.defs import get_config_file
from ...debugging import get_logger
from ...debugging.log import resolve_logging_level
@@ -773,7 +774,7 @@ class Session(TokenManager):
# noinspection PyBroadException
try:
cls()
except MissingConfigError:
except (MissingConfigError, ConfigurationError):
if raise_error and not ENV_IGNORE_MISSING_CONFIG.get():
raise
except LoginError:

View File

@@ -1,5 +1,4 @@
import base64
from distutils.util import strtobool
from typing import Union, Optional, Any, TypeVar, Callable, Tuple
import six
@@ -14,6 +13,22 @@ except ImportError:
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError("invalid truth value %r" % (val,))
def base64_to_text(value):
# type: (Any) -> Text
return base64.b64decode(value).decode("utf-8")

View File

@@ -528,8 +528,10 @@ class _Arguments(object):
"Failed parsing task parameter {}={} keeping default {}={}".format(k, param, k, v)
)
# assume more general purpose type int -> float
if v_type == int:
# if parameter is empty and default value is None, keep as None
if param == '' and v is None:
v_type = type(None)
elif v_type == int: # assume more general purpose type int -> float
if v is not None and int(v) != float(v):
v_type = float
elif v_type == bool:

View File

@@ -17,16 +17,14 @@ from ...task import Task
class CreateAndPopulate(object):
_VCS_SSH_REGEX = \
"^" \
"(?:(?P<user>{regular}*?)@)?" \
"(?P<host>{regular}*?)" \
":" \
"(?P<path>{regular}.*)?" \
"$" \
.format(
regular=r"[^/@:#]"
)
_VCS_SSH_REGEX = (
"^"
"(?:(?P<user>{regular}*?)@)?"
"(?P<host>{regular}*?)"
":"
"(?P<path>{regular}.*)?"
"$".format(regular=r"[^/@:#]")
)
def __init__(
self,
@@ -199,9 +197,9 @@ class CreateAndPopulate(object):
# if there is nothing to populate, return
if not any([
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
):
self.folder, self.commit, self.branch, self.repo, self.script, self.cwd,
self.packages, self.requirements_file, self.base_task_id] + (list(self.docker.values()))
):
return task
# clear the script section
@@ -219,7 +217,7 @@ class CreateAndPopulate(object):
if self.cwd:
self.cwd = self.cwd
cwd = self.cwd if Path(self.cwd).is_dir() else (
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
Path(repo_info.script['repo_root']) / self.cwd).as_posix()
if not Path(cwd).is_dir():
raise ValueError("Working directory \'{}\' could not be found".format(cwd))
cwd = Path(cwd).relative_to(repo_info.script['repo_root']).as_posix()
@@ -577,6 +575,7 @@ if __name__ == '__main__':
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
_sanitize_function=None, # type: Optional[Callable[[str], str]]
_sanitize_helper_functions=None, # type: Optional[Callable[[str], str]]
skip_global_imports=False # type: bool
):
# type: (...) -> Optional[Dict, Task]
"""
@@ -659,6 +658,9 @@ if __name__ == '__main__':
return dill.loads(bytes_)
:param _sanitize_function: Sanitization function for the function string.
:param _sanitize_helper_functions: Sanitization function for the helper function string.
:param skip_global_imports: If True, the global imports will not be fetched from the function's file, otherwise
all global imports will be automatically imported in a safe manner at the beginning of the function's
execution. Default is False
:return: Newly created Task object
"""
# not set -> equals True
@@ -671,7 +673,7 @@ if __name__ == '__main__':
assert (not auto_connect_arg_parser or isinstance(auto_connect_arg_parser, (bool, dict)))
function_source, function_name = CreateFromFunction.__extract_function_information(
a_function, sanitize_function=_sanitize_function
a_function, sanitize_function=_sanitize_function, skip_global_imports=skip_global_imports
)
# add helper functions on top.
for f in (helper_functions or []):
@@ -846,11 +848,133 @@ if __name__ == '__main__':
return function_source
@staticmethod
def __extract_function_information(function, sanitize_function=None):
# type: (Callable, Optional[Callable]) -> (str, str)
function_name = str(function.__name__)
function_source = inspect.getsource(function)
def __extract_imports(func):
def add_import_guard(import_):
return ("try:\n "
+ import_.replace("\n", "\n ", import_.count("\n") - 1)
+ "\nexcept Exception as e:\n print('Import error: ' + str(e))\n"
)
# noinspection PyBroadException
try:
import ast
func_module = inspect.getmodule(func)
source = inspect.getsource(func_module)
parsed_source = ast.parse(source)
imports = []
for parsed_source_entry in parsed_source.body:
# we only include global imports (i.e. at col_offset 0)
if parsed_source_entry.col_offset != 0:
continue
if isinstance(parsed_source_entry, ast.ImportFrom):
for sub_entry in parsed_source_entry.names:
import_str = "from {} import {}".format(parsed_source_entry.module, sub_entry.name)
if sub_entry.asname:
import_str += " as {}".format(sub_entry.asname)
imports.append(import_str)
elif isinstance(parsed_source_entry, ast.Import):
for sub_entry in parsed_source_entry.names:
import_str = "import {}".format(sub_entry.name)
if sub_entry.asname:
import_str += " as {}".format(sub_entry.asname)
imports.append(import_str)
imports = [add_import_guard(import_) for import_ in imports]
return "\n".join(imports)
except Exception as e:
getLogger().warning('Could not fetch function imports: {}'.format(e))
return ""
@staticmethod
def _extract_wrapped(decorated):
if not decorated.__closure__:
return None
closure = (c.cell_contents for c in decorated.__closure__)
if not closure:
return None
return next((c for c in closure if inspect.isfunction(c)), None)
@staticmethod
def _deep_extract_wrapped(decorated):
while True:
# noinspection PyProtectedMember
func = CreateFromFunction._extract_wrapped(decorated)
if not func:
return decorated
decorated = func
@staticmethod
def __sanitize(func_source, sanitize_function=None):
if sanitize_function:
function_source = sanitize_function(function_source)
function_source = CreateFromFunction.__sanitize_remove_type_hints(function_source)
return function_source, function_name
func_source = sanitize_function(func_source)
return CreateFromFunction.__sanitize_remove_type_hints(func_source)
@staticmethod
def __get_func_members(module):
result = []
try:
import ast
source = inspect.getsource(module)
parsed = ast.parse(source)
for f in parsed.body:
if isinstance(f, ast.FunctionDef):
result.append(f.name)
except Exception as e:
name = getattr(module, "__name__", module)
getLogger().warning('Could not fetch function declared in {}: {}'.format(name, e))
return result
@staticmethod
def __get_source_with_decorators(func, original_module=None, sanitize_function=None):
if original_module is None:
original_module = inspect.getmodule(func)
func_members = CreateFromFunction.__get_func_members(original_module)
try:
func_members_dict = dict(inspect.getmembers(original_module, inspect.isfunction))
except Exception as e:
name = getattr(original_module, "__name__", original_module)
getLogger().warning('Could not fetch functions from {}: {}'.format(name, e))
func_members_dict = {}
decorated_func = CreateFromFunction._deep_extract_wrapped(func)
decorated_func_source = CreateFromFunction.__sanitize(
inspect.getsource(decorated_func),
sanitize_function=sanitize_function
)
try:
import ast
parsed_decorated = ast.parse(decorated_func_source)
for body_elem in parsed_decorated.body:
if not isinstance(body_elem, ast.FunctionDef):
continue
for decorator in body_elem.decorator_list:
name = None
if isinstance(decorator, ast.Name):
name = decorator.id
elif isinstance(decorator, ast.Call):
name = decorator.func.id
if not name:
continue
decorator_func = func_members_dict.get(name)
if name not in func_members or not decorator_func:
continue
decorated_func_source = CreateFromFunction.__get_source_with_decorators(
decorator_func,
original_module=original_module,
sanitize_function=sanitize_function
) + "\n\n" + decorated_func_source
break
except Exception as e:
getLogger().warning('Could not fetch full definition of function {}: {}'.format(func.__name__, e))
return decorated_func_source
@staticmethod
def __extract_function_information(function, sanitize_function=None, skip_global_imports=False):
# type: (Callable, Optional[Callable], bool) -> (str, str)
function = CreateFromFunction._deep_extract_wrapped(function)
function_source = CreateFromFunction.__get_source_with_decorators(function, sanitize_function=sanitize_function)
if not skip_global_imports:
imports = CreateFromFunction.__extract_imports(function)
else:
imports = ""
return imports + "\n" + function_source, function.__name__

View File

@@ -97,7 +97,12 @@ class ScriptRequirements(object):
for fname, lines in skimage.items():
modules.add('scikit_image', fname, lines)
# if we have torch and it supports tensorboard, we should add that as well
if 'tensorflow-intel' in modules:
tfmodule = modules.pop('tensorflow-intel', {})
for fname, lines in tfmodule.items():
modules.add('tensorflow', fname, lines)
# if we have torch, and it supports tensorboard, we should add that as well
# (because it will not be detected automatically)
if 'torch' in modules and 'tensorboard' not in modules and 'tensorboardX' not in modules:
# noinspection PyBroadException
@@ -331,14 +336,14 @@ class _JupyterObserver(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from nbconvert.exporters import PythonExporter
from nbconvert.exporters import PythonExporter # noqa
_script_exporter = PythonExporter()
except Exception:
_script_exporter = None
if _script_exporter is None:
# noinspection PyPackageRequirements
from nbconvert.exporters.script import ScriptExporter
from nbconvert.exporters.script import ScriptExporter # noqa
_script_exporter = ScriptExporter()
except Exception as ex:
@@ -617,7 +622,7 @@ class ScriptInfo(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from notebook.notebookapp import list_running_servers # <= Notebook v6
from notebook.notebookapp import list_running_servers # noqa <= Notebook v6
# noinspection PyBroadException
try:
jupyter_servers += list(list_running_servers())
@@ -632,7 +637,7 @@ class ScriptInfo(object):
# noinspection PyBroadException
try:
# noinspection PyPackageRequirements
from jupyter_server.serverapp import list_running_servers
from jupyter_server.serverapp import list_running_servers # noqa
# noinspection PyBroadException
try:
jupyter_servers += list(list_running_servers())
@@ -719,7 +724,7 @@ class ScriptInfo(object):
is_google_colab = False
log_history = False
colab_name = None
# check if this is google.colab, then there is no local file
# check if this is `google.colab`, then there is no local file
is_google_colab = ScriptInfo.is_google_colab()
if is_google_colab:
@@ -748,7 +753,7 @@ class ScriptInfo(object):
if not entry_point.exists():
# noinspection PyBroadException
try:
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5])+'.ipynb'
alternative_entry_point = '-'.join(entry_point_filename.split('-')[:-5]) + '.ipynb'
# now we should try to find the actual file
entry_point_alternative = (Path.cwd() / alternative_entry_point).absolute()
if not entry_point_alternative.is_file():
@@ -823,7 +828,7 @@ class ScriptInfo(object):
# returns tuple (notebook name, raw string notebook)
# None, None if fails
try:
from google.colab import _message
from google.colab import _message # noqa
notebook = _message.blocking_request('get_ipynb', timeout_sec=timeout)['ipynb']
notebook_name = notebook.get("metadata", {}).get("colab", {}).get("name", "colab.ipynb")
@@ -990,6 +995,10 @@ class ScriptInfo(object):
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
# check if we are running with torch distributed, or transformers accelerate
# make sure we change the entry point to reflect it.
entry_point = cls._detect_distributed_execution(entry_point, log)
if check_uncommitted:
# if we have a jupyter notebook, always store the entire notebook (instead of the git diff)
if jupyter_filepath:
@@ -1005,7 +1014,7 @@ class ScriptInfo(object):
if len(diff) > cls.max_diff_size_bytes:
messages.append(
"======> WARNING! Git diff too large to store "
"({}kb), skipping uncommitted changes <======".format(len(diff)//1024))
"({}kb), skipping uncommitted changes <======".format(len(diff) // 1024))
auxiliary_git_diff = diff
diff = '# WARNING! git diff too large to store, clear this section to execute without it.\n' \
'# full git diff available in Artifacts/auxiliary_git_diff\n' \
@@ -1060,8 +1069,54 @@ class ScriptInfo(object):
return (ScriptInfoResult(script=script_info, warning_messages=messages, auxiliary_git_diff=auxiliary_git_diff),
script_requirements)
@classmethod
def _detect_distributed_execution(cls, entry_point, log):
# check if we are running with torch distributed, or transformers accelerate
# make sure we change the entry point to reflect it.
is_torch_distributed = os.environ.get("TORCHELASTIC_RUN_ID") is not None
is_transformers_distributed = os.environ.get("ACCELERATE_DYNAMO_MODE") is not None
if not is_torch_distributed and not is_transformers_distributed:
return entry_point
# this torch distributed
# noinspection PyBroadException
try:
from psutil import Process # noqa
cmdline = Process().parent().cmdline()
# first find the torch model call "torch.distributed.run" or "torch.distributed.launch"
if is_torch_distributed:
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("torch.distributed."))
elif is_transformers_distributed:
cmdstart_i = next(i for i, c in enumerate(cmdline) if c.lower().startswith("accelerate.commands."))
else:
raise Exception() # we should not get here
cmdline = cmdline[cmdstart_i:]
# reverse look into the paths
cmdend_i = next(i for i, c in enumerate(cmdline) if Path(c).stem == Path(entry_point).stem)
filearg = cmdline[cmdend_i]
# notice --args (script args) are passed on the Args section, we skip detecting them here
# we are also already removing the filearg from the cmd (it is the last before script args)
new_cmd = cmdline[:cmdend_i]
# we assume our entrypoint is the last parameter of the execution cmd line
if Path(filearg).stem == Path(entry_point).stem:
entry_point = "-m {} {}".format(" ".join(new_cmd), entry_point)
if log:
log.info(
"{} execution detected: adjusting entrypoint to "
"reflect distributed execution arguments".format(
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate")
)
except Exception:
if log:
log.warning("{} execution detected: Failed Detecting launch arguments, skipping".format(
"Torch Distributed" if is_torch_distributed else "Transformers Accelerate"))
return entry_point
@staticmethod
def __legacy_jupyter_notebook_server_json_parsing(self):
def __legacy_jupyter_notebook_server_json_parsing():
# noinspection PyBroadException
try:
# on some jupyter notebook versions this function can crash on parsing the json file,

View File

@@ -1464,52 +1464,6 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
execution.docker_cmd = image + (' {}'.format(arguments) if arguments else '')
self._edit(execution=execution)
def set_packages(self, packages):
    # type: (Union[str, Sequence[str]]) -> ()
    """
    Manually set the task's required python packages.

    :param packages: Either an explicit list of requirement specifiers
        (e.g. ``["tqdm>=2.1", "scikit-learn"]``) or the path to a local
        requirements.txt file (e.g. ``"./requirements.txt"``).
    """
    if not packages:
        return
    points_to_file = isinstance(packages, str) and os.path.exists(packages)
    if not points_to_file:
        # an explicit collection of requirement specifiers
        # noinspection PyProtectedMember
        self._update_requirements(packages)
        return
    # a local requirements file - read one specifier per line
    with open(packages) as requirements_file:
        entries = [entry.strip() for entry in requirements_file.readlines()]
    # noinspection PyProtectedMember
    self._update_requirements(entries)
def set_repo(self, repo, branch=None, commit=None):
# type: (str, Optional[str], Optional[str]) -> ()
"""
Specify a repository to attach to the function.
Allow users to execute the task inside the specified repository, enabling them to load modules/script
from the repository. Notice the execution work directory will be the repository root folder.
Supports both git repo url link, and local repository path (automatically converted into the remote
git/commit as is currently checkout).
Example remote url: 'https://github.com/user/repo.git'.
Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy.
:param repo: Remote URL for the repository to use, OR path to local copy of the git repository
Example: 'https://github.com/allegroai/clearml.git' or '~/project/repo'
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
"""
if not repo:
return
with self._edit_lock:
self.reload()
self.data.script.repository = repo
if branch:
self.data.script.branch = branch
if commit:
self.data.script.version_num = commit
self._edit(script=self.data.script)
def get_base_docker(self):
# type: () -> str
"""Get the base Docker command (image) that is set for this experiment."""
@@ -2326,7 +2280,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
import pkg_resources
except ImportError:
get_logger("task").warning(
"Requirement file %s skipped since pkg_resources is not installed" % package_name)
"Requirement file `{}` skipped since pkg_resources is not installed".format(package_name))
else:
with Path(package_name).open() as requirements_txt:
for req in pkg_resources.parse_requirements(requirements_txt):

View File

@@ -175,10 +175,11 @@ class PatchOsFork(object):
try:
return PatchOsFork._original_process_run(self, *args, **kwargs)
finally:
if task:
if task and patched_worker:
try:
if patched_worker:
# remove at exit hooks, we will deadlock when the
# noinspection PyProtectedMember
if task._report_subprocess_enabled:
# just in case, remove at exit hooks, we will deadlock when the
# main Pool manager will terminate this process, and it will...
# noinspection PyProtectedMember
task._at_exit_called = True
@@ -214,12 +215,30 @@ class PatchOsFork(object):
if not task:
return
if not Task._report_subprocess_enabled:
# https://stackoverflow.com/a/34507557
# NOTICE: subprocesses do not exit through exit we have to register signals
if task._Task__exit_hook:
task._Task__exit_hook.register_signal_and_exception_hooks()
else:
# noinspection PyProtectedMember
task._remove_signal_hooks()
# noinspection PyProtectedMember
if Task._report_subprocess_enabled:
# noinspection PyProtectedMember
task._remove_exception_hooks()
PatchOsFork._current_task = task
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
# noinspection PyProtectedMember
if not bool(task._report_subprocess_enabled):
BackgroundMonitor.start_all(task=task)
# if we are reporting into a subprocess, no need to further patch the exit functions
if Task._report_subprocess_enabled:
return
# The signal handler method is Not enough, for the time being, we have both
# even though it makes little sense
# # if we got here patch the os._exit of our instance to call us
@@ -244,6 +263,10 @@ class PatchOsFork(object):
# noinspection PyProtectedMember, PyUnresolvedReferences
os._org_exit = os._exit
# noinspection PyProtectedMember
# https://stackoverflow.com/a/34507557
# NOTICE: subprocesses do not exit through exit, and in most cases not with _exit,
# this means at_exit calls are Not registered respected
os._exit = _at_exit_callback
@staticmethod
@@ -261,3 +284,23 @@ class PatchOsFork(object):
PatchOsFork._fork_callback_after_child()
return ret
@staticmethod
def unpatch_fork():
try:
if PatchOsFork._original_fork and os._exit != PatchOsFork._original_fork:
os._exit = PatchOsFork._original_fork
PatchOsFork._original_fork = None
except Exception:
pass
@staticmethod
def unpatch_process_run():
try:
from multiprocessing.process import BaseProcess
if PatchOsFork._original_process_run and BaseProcess.run != PatchOsFork._original_process_run:
BaseProcess.run = PatchOsFork._original_process_run
PatchOsFork._original_process_run = None
except Exception:
pass

View File

@@ -131,7 +131,7 @@ class PatchHydra(object):
# noinspection PyBroadException
try:
override = PatchHydra._parse_override(override)
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
group_exists = hydra_context.config_loader.repository.group_exists(override.key_or_group)
return group_exists
except Exception:
if not PatchHydra._config_group_warning_sent:

View File

@@ -31,6 +31,7 @@ class PatchJsonArgParse(object):
_special_fields = ["config", "subcommand"]
_section_name = "Args"
_allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
_ignore_ui_overrides = "_ignore_ui_overrides_"
__remote_task_params = {}
__remote_task_params_dict = {}
__patched = False
@@ -43,7 +44,7 @@ class PatchJsonArgParse(object):
cls.patch(task)
@classmethod
def patch(cls, task):
def patch(cls, task=None):
if ArgumentParser is None:
return
PatchJsonArgParse._update_task_args()
@@ -72,7 +73,9 @@ class PatchJsonArgParse(object):
if not verify_basic_type(v, basic_types=(float, int, bool, str, type(None))) and v:
# noinspection PyBroadException
try:
if isinstance(v, Namespace) or (isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)):
if isinstance(v, Namespace) or (
isinstance(v, list) and all(isinstance(sub_v, Namespace) for sub_v in v)
):
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_namespace(v))
args_type[key_with_section] = PatchJsonArgParse.namespace_type
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
@@ -86,9 +89,9 @@ class PatchJsonArgParse(object):
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
if have_config_file:
cls._current_task.set_parameter(
cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
cls._section_name + cls._args_sep + cls._ignore_ui_overrides,
False,
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority", # noqa
)
@staticmethod
@@ -110,8 +113,6 @@ class PatchJsonArgParse(object):
@staticmethod
def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
return original_fn(obj, *args, **kwargs)
if len(args) == 1:
kwargs["args"] = args[0]
args = []
@@ -124,13 +125,19 @@ class PatchJsonArgParse(object):
params_namespace = Namespace()
for k, v in params.items():
params_namespace[k] = v
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
allow_jsonargparse_overrides_value = True
if PatchJsonArgParse._allow_jsonargparse_overrides in params:
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides)
if PatchJsonArgParse._ignore_ui_overrides in params:
allow_jsonargparse_overrides_value = not params.pop(PatchJsonArgParse._ignore_ui_overrides)
if not allow_jsonargparse_overrides_value:
params_namespace = PatchJsonArgParse.__restore_args(
obj,
params_namespace,
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
obj, params_namespace, subcommand=params_namespace.get(PatchJsonArgParse._command_name)
)
if PatchJsonArgParse._allow_jsonargparse_overrides in params_namespace:
del params_namespace[PatchJsonArgParse._allow_jsonargparse_overrides]
if PatchJsonArgParse._ignore_ui_overrides in params_namespace:
del params_namespace[PatchJsonArgParse._ignore_ui_overrides]
return params_namespace
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
@@ -149,6 +156,7 @@ class PatchJsonArgParse(object):
except ImportError:
try:
import pytorch_lightning
lightning = pytorch_lightning
except ImportError:
lightning = None
@@ -178,20 +186,14 @@ class PatchJsonArgParse(object):
params_dict = t.get_parameters(backwards_compatibility=False, cast=True)
for key, section_param in cls.__remote_task_params[cls._section_name].items():
if section_param.type == cls.namespace_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_namespace_from_json(section_param.value)
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_namespace_from_json(section_param.value)
elif section_param.type == cls.path_type:
params_dict[
"{}/{}".format(cls._section_name, key)
] = cls._get_path_from_json(section_param.value)
params_dict["{}/{}".format(cls._section_name, key)] = cls._get_path_from_json(section_param.value)
elif (not section_param.type or section_param.type == "NoneType") and not section_param.value:
params_dict["{}/{}".format(cls._section_name, key)] = None
skip = len(cls._section_name) + 1
cls.__remote_task_params_dict = {
k[skip:]: v
for k, v in params_dict.items()
if k.startswith(cls._section_name + cls._args_sep)
k[skip:]: v for k, v in params_dict.items() if k.startswith(cls._section_name + cls._args_sep)
}
cls.__update_remote_task_params_dict_based_on_paths(parser)
@@ -200,9 +202,7 @@ class PatchJsonArgParse(object):
paths = PatchJsonArgParse.__get_paths_from_dict(cls.__remote_task_params_dict)
for path in paths:
args = PatchJsonArgParse.__get_args_from_path(
parser,
path,
subcommand=cls.__remote_task_params_dict.get("subcommand")
parser, path, subcommand=cls.__remote_task_params_dict.get("subcommand")
)
for subarg_key, subarg_value in args.items():
if subarg_key not in cls.__remote_task_params_dict:
@@ -222,7 +222,12 @@ class PatchJsonArgParse(object):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
(
(subcommand + PatchJsonArgParse._commands_sep)
if k not in PatchJsonArgParse._special_fields
else ""
)
+ k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg
@@ -252,3 +257,7 @@ class PatchJsonArgParse(object):
if isinstance(json_, list):
return [Path(**dict_) for dict_ in json_]
return Path(**json_)
# patch jsonargparse before anything else
PatchJsonArgParse.patch()

View File

@@ -186,7 +186,15 @@ def get_node_id(default=0):
if node_id is None and (mpi_world_rank is not None or mpi_rank is not None):
node_id = mpi_world_rank if mpi_world_rank is not None else mpi_rank
# if node is is till None, use the default
# if node is still None, use the global RANK
if node_id is None:
# noinspection PyBroadException
try:
node_id = int(os.environ.get("RANK"))
except Exception:
pass
# if node is still None, use the default
if node_id is None:
node_id = default

View File

@@ -355,6 +355,20 @@ def stdout_print(*args, **kwargs):
sys.stdout.write(line)
def debug_print(*args, **kwargs):
    """
    Print directly to stdout, prefixed with the process id and the elapsed
    time since the previous debug_print call.
    Example: [pid=123, t=0.003] message here
    """
    global tic
    # first call: lazily initialize the reference timestamp
    tic = globals().get('tic', time.time())
    if not args:
        payload = ""
    elif len(args) == 1:
        payload = args[0]
    else:
        payload = args
    prefix = "\033[1;33m[pid={}, t={:.04f}] ".format(os.getpid(), time.time() - tic)
    stdout_print(prefix + str(payload) + "\033[0m", **kwargs)
    tic = time.time()
if __name__ == '__main__':
# from clearml import Task
# task = Task.init(project_name="examples", task_name="trace test")

View File

@@ -26,7 +26,7 @@ from .storage.helper import StorageHelper
from .utilities.plotly_reporter import SeriesInfo
# Make sure that DeprecationWarning within this package always gets printed
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
warnings.filterwarnings("always", category=DeprecationWarning, module=__name__)
if TYPE_CHECKING:
@@ -56,9 +56,10 @@ class Logger(object):
"""
SeriesInfo = SeriesInfo
_tensorboard_logging_auto_group_scalars = False
_tensorboard_single_series_per_graph = deferred_config('metrics.tensorboard_single_series_per_graph', False)
_tensorboard_single_series_per_graph = deferred_config("metrics.tensorboard_single_series_per_graph", False)
def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
"""
@@ -66,8 +67,9 @@ class Logger(object):
**Do not construct Logger manually!**
Please use :meth:`Logger.get_current`
"""
assert isinstance(private_task, _Task), \
'Logger object cannot be instantiated externally, use Logger.current_logger()'
assert isinstance(
private_task, _Task
), "Logger object cannot be instantiated externally, use Logger.current_logger()"
super(Logger, self).__init__()
self._task = private_task
self._default_upload_destination = None
@@ -75,16 +77,19 @@ class Logger(object):
self._report_worker = None
self._graph_titles = {}
self._tensorboard_series_force_prefix = None
self._task_handler = TaskHandler(task=self._task, capacity=100) \
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging) else None
self._task_handler = (
TaskHandler(task=self._task, capacity=100)
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging)
else None
)
self._connect_std_streams = connect_stdout or connect_stderr
self._connect_logging = connect_logging
self._default_max_sample_history = None
# Make sure urllib is never in debug/info,
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
disable_urllib3_info = config.get("log.disable_urllib3_info", True)
if disable_urllib3_info and logging.getLogger("urllib3").isEnabledFor(logging.INFO):
logging.getLogger("urllib3").setLevel(logging.WARNING)
if self._task.is_main_task():
StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
@@ -112,6 +117,7 @@ class Logger(object):
:return: The Logger object (a singleton) for the current running Task.
"""
from .task import Task
task = Task.current_task()
if not task:
return None
@@ -181,20 +187,20 @@ class Logger(object):
:param name: Metric's name
:param value: Metric's value
"""
return self.report_scalar(title="Summary", series=name, value=value, iteration=-2**31)
return self.report_scalar(title="Summary", series=name, value=value, iteration=-(2**31))
def report_vector(
self,
title, # type: str
series, # type: str
values, # type: Sequence[Union[int, float]]
iteration=None, # type: Optional[int]
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
values, # type: Sequence[Union[int, float]]
iteration=None, # type: Optional[int]
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a vector as (default stacked) histogram.
@@ -229,27 +235,36 @@ class Logger(object):
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
"""
warnings.warn(
":meth:`Logger.report_vector` is deprecated;"
"use :meth:`Logger.report_histogram` instead.",
DeprecationWarning
":meth:`Logger.report_vector` is deprecated; use :meth:`Logger.report_histogram` instead.",
DeprecationWarning,
)
self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout)
return self.report_histogram(
title,
series,
values,
iteration or 0,
labels=labels,
xlabels=xlabels,
xaxis=xaxis,
yaxis=yaxis,
mode=mode,
extra_layout=extra_layout
)
def report_histogram(
self,
title, # type: str
series, # type: str
values, # type: Sequence[Union[int, float]]
iteration=None, # type: Optional[int]
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None, # type: Optional[str]
data_args=None, # type: Optional[dict]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
values, # type: Sequence[Union[int, float]]
iteration=None, # type: Optional[int]
labels=None, # type: Optional[List[str]]
xlabels=None, # type: Optional[List[str]]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
mode=None, # type: Optional[str]
data_args=None, # type: Optional[dict]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a (default grouped) histogram.
@@ -301,21 +316,21 @@ class Logger(object):
xlabels=xlabels,
xtitle=xaxis,
ytitle=yaxis,
mode=mode or 'group',
mode=mode or "group",
data_args=data_args,
layout_config=extra_layout,
)
def report_table(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
csv=None, # type: Optional[str]
url=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
extra_data=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
table_plot=None, # type: Optional[pd.DataFrame, Sequence[Sequence]]
csv=None, # type: Optional[str]
url=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
extra_data=None, # type: Optional[dict]
):
"""
For explicit reporting, report a table plot.
@@ -374,10 +389,7 @@ class Logger(object):
)
"""
mutually_exclusive(
UsageError, _check_none=True,
table_plot=table_plot, csv=csv, url=url
)
mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
table = table_plot
if url or csv:
if not pd:
@@ -412,16 +424,16 @@ class Logger(object):
)
def report_line_plot(
self,
title, # type: str
series, # type: Sequence[SeriesInfo]
xaxis, # type: str
yaxis, # type: str
mode='lines', # type: str
iteration=None, # type: Optional[int]
reverse_xaxis=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: Sequence[SeriesInfo]
xaxis, # type: str
yaxis, # type: str
mode="lines", # type: str
iteration=None, # type: Optional[int]
reverse_xaxis=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot one or more series as lines.
@@ -464,7 +476,7 @@ class Logger(object):
# if task was not started, we have to start it
self._start_task_if_needed()
self._touch_title_series(title, series[0].name if series else '')
self._touch_title_series(title, series[0].name if series else "")
# noinspection PyProtectedMember
return self._task._reporter.report_line_plot(
title=title,
@@ -479,17 +491,17 @@ class Logger(object):
)
def report_scatter2d(
self,
title, # type: str
series, # type: str
scatter, # type: Union[Sequence[Tuple[float, float]], np.ndarray]
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]]
mode='lines', # type: str
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
scatter, # type: Union[Sequence[Tuple[float, float]], np.ndarray]
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]]
mode="lines", # type: str
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, report a 2d scatter plot.
@@ -558,19 +570,19 @@ class Logger(object):
)
def report_scatter3d(
self,
title, # type: str
series, # type: str
scatter, # type: Union[Sequence[Tuple[float, float, float]], np.ndarray]
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
zaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]]
mode='markers', # type: str
fill=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
scatter, # type: Union[Sequence[Tuple[float, float, float]], np.ndarray]
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
zaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]]
mode="markers", # type: str
fill=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a 3d scatter graph (with markers).
@@ -605,16 +617,9 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
"""
# check if multiple series
multi_series = (
isinstance(scatter, list)
and (
isinstance(scatter[0], np.ndarray)
or (
scatter[0]
and isinstance(scatter[0], list)
and isinstance(scatter[0][0], list)
)
)
multi_series = isinstance(scatter, list) and (
isinstance(scatter[0], np.ndarray)
or (scatter[0] and isinstance(scatter[0], list) and isinstance(scatter[0][0], list))
)
if not multi_series:
@@ -647,18 +652,18 @@ class Logger(object):
)
def report_confusion_matrix(
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a heat-map matrix.
@@ -690,7 +695,7 @@ class Logger(object):
matrix = np.array(matrix)
if extra_layout is None:
extra_layout = {'texttemplate': '%{z}'}
extra_layout = {"texttemplate": "%{z}"}
# if task was not started, we have to start it
self._start_task_if_needed()
@@ -711,17 +716,17 @@ class Logger(object):
)
def report_matrix(
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
yaxis_reversed=False, # type: bool
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, plot a confusion matrix.
@@ -744,30 +749,37 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
"""
warnings.warn(
":meth:`Logger.report_matrix` is deprecated;"
"use :meth:`Logger.report_confusion_matrix` instead.",
":meth:`Logger.report_matrix` is deprecated;" "use :meth:`Logger.report_confusion_matrix` instead.",
DeprecationWarning
)
self._touch_title_series(title, series)
return self.report_confusion_matrix(title, series, matrix, iteration or 0,
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels,
yaxis_reversed=yaxis_reversed,
extra_layout=extra_layout)
return self.report_confusion_matrix(
title,
series,
matrix,
iteration or 0,
xaxis=xaxis,
yaxis=yaxis,
xlabels=xlabels,
ylabels=ylabels,
yaxis_reversed=yaxis_reversed,
extra_layout=extra_layout
)
def report_surface(
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
zaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
camera=None, # type: Optional[Sequence[float]]
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
self,
title, # type: str
series, # type: str
matrix, # type: np.ndarray
iteration=None, # type: Optional[int]
xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str]
zaxis=None, # type: Optional[str]
xlabels=None, # type: Optional[List[str]]
ylabels=None, # type: Optional[List[str]]
camera=None, # type: Optional[Sequence[float]]
comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict]
):
"""
For explicit reporting, report a 3d surface plot.
@@ -821,16 +833,16 @@ class Logger(object):
)
def report_image(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
local_path=None, # type: Optional[str]
image=None, # type: Optional[Union[np.ndarray, Image.Image]]
matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
url=None # type: Optional[str]
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
local_path=None, # type: Optional[str]
image=None, # type: Optional[Union[np.ndarray, Image.Image]]
matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
url=None, # type: Optional[str]
):
"""
For explicit reporting, report an image and upload its contents.
@@ -875,8 +887,7 @@ class Logger(object):
- ``False`` - Do not delete after upload. (default)
"""
mutually_exclusive(
UsageError, _check_none=True,
local_path=local_path or None, url=url or None, image=image, matrix=matrix
UsageError, _check_none=True, local_path=local_path or None, url=url or None, image=image, matrix=matrix
)
if matrix is not None:
warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning)
@@ -902,7 +913,7 @@ class Logger(object):
else:
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)
@@ -925,16 +936,16 @@ class Logger(object):
)
def report_media(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
local_path=None, # type: Optional[str]
stream=None, # type: Optional[Union[six.BytesIO, six.StringIO]]
file_extension=None, # type: Optional[str]
max_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
url=None # type: Optional[str]
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
local_path=None, # type: Optional[str]
stream=None, # type: Optional[Union[six.BytesIO, six.StringIO]]
file_extension=None, # type: Optional[str]
max_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
url=None, # type: Optional[str]
):
"""
Report media upload its contents, including images, audio, and video.
@@ -966,8 +977,11 @@ class Logger(object):
"""
mutually_exclusive(
UsageError, _check_none=True,
local_path=local_path or None, url=url or None, stream=stream,
UsageError,
_check_none=True,
local_path=local_path or None,
url=url or None,
stream=stream
)
if stream is not None and not file_extension:
raise ValueError("No file extension provided for stream media upload")
@@ -989,7 +1003,7 @@ class Logger(object):
else:
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)
@@ -1009,11 +1023,11 @@ class Logger(object):
)
def report_plotly(
self,
title, # type: str
series, # type: str
figure, # type: Union[Dict, "Figure"] # noqa: F821
iteration=None, # type: Optional[int]
self,
title, # type: str
series, # type: str
figure, # type: Union[Dict, "Figure"] # noqa: F821
iteration=None, # type: Optional[int]
):
"""
Report a ``Plotly`` figure (plot) directly
@@ -1033,7 +1047,7 @@ class Logger(object):
plot = figure if isinstance(figure, dict) else figure.to_plotly_json()
# noinspection PyBroadException
try:
plot['layout']['title'] = series
plot["layout"]["title"] = series
except Exception:
pass
# noinspection PyProtectedMember
@@ -1045,13 +1059,13 @@ class Logger(object):
)
def report_matplotlib_figure(
self,
title, # type: str
series, # type: str
figure, # type: Union[MatplotlibFigure, pyplot]
iteration=None, # type: Optional[int]
report_image=False, # type: bool
report_interactive=True, # type: bool
self,
title, # type: str
series, # type: str
figure, # type: Union[MatplotlibFigure, pyplot]
iteration=None, # type: Optional[int]
report_image=False, # type: bool
report_interactive=True, # type: bool
):
"""
Report a ``matplotlib`` figure / plot directly
@@ -1078,8 +1092,7 @@ class Logger(object):
figure=figure,
iter=iteration or 0,
logger=self,
force_save_as_image=False if report_interactive and not report_image
else ('png' if report_image else True),
force_save_as_image=False if report_interactive and not report_image else ("png" if report_image else True),
)
def set_default_upload_destination(self, uri):
@@ -1201,21 +1214,28 @@ class Logger(object):
return int(UploadEvent._file_history_size)
def report_image_and_upload(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
matrix=None, # type: Optional[Union[np.ndarray, Image.Image]]
max_image_history=None, # type: Optional[int]
delete_after_upload=False # type: bool
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
matrix=None, # type: Optional[Union[np.ndarray, Image.Image]]
max_image_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
):
"""
.. deprecated:: 0.13.0
Use :meth:`Logger.report_image` instead
"""
self.report_image(title=title, series=series, iteration=iteration or 0, local_path=path, image=matrix,
max_image_history=max_image_history, delete_after_upload=delete_after_upload)
self.report_image(
title=title,
series=series,
iteration=iteration or 0,
local_path=path,
image=matrix,
max_image_history=max_image_history,
delete_after_upload=delete_after_upload
)
def capture_logging(self):
# type: () -> "_LoggingContext"
@@ -1224,6 +1244,7 @@ class Logger(object):
:return: a ContextManager
"""
class _LoggingContext(object):
def __init__(self, a_logger):
self.logger = a_logger
@@ -1285,6 +1306,7 @@ class Logger(object):
:param force: If True, all matplotlib figures are converted automatically to non-interactive plots.
"""
from clearml.backend_interface.metrics import Reporter
Reporter.matplotlib_force_report_non_interactive(force=force)
@classmethod
@@ -1327,8 +1349,7 @@ class Logger(object):
try:
return int(level)
except (TypeError, ValueError):
self._task.log.log(level=logging.ERROR,
msg='Logger failed casting log level "%s" to integer' % str(level))
self._task.log.log(level=logging.ERROR, msg='Logger failed casting log level "%s" to integer' % str(level))
return logging.INFO
def _console(self, msg, level=logging.INFO, omit_console=False, force_send=False, *args, **_):
@@ -1356,7 +1377,7 @@ class Logger(object):
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
"console", level=level, fn="", lno=0, func="", msg=msg, args=args, exc_info=None
)
# find the task handler that matches our task
self._task_handler.emit(record)
@@ -1366,7 +1387,8 @@ class Logger(object):
try:
# make sure we are writing to the original stdout
StdStreamPatch.stderr_original_write(
'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg))
'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)
)
except Exception:
pass
else:
@@ -1379,7 +1401,7 @@ class Logger(object):
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
StdStreamPatch.stdout_original_write(str(msg) + '\n')
StdStreamPatch.stdout_original_write(str(msg) + "\n")
except Exception:
pass
else:
@@ -1389,14 +1411,14 @@ class Logger(object):
self._start_task_if_needed()
def _report_image_plot_and_upload(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int]
delete_after_upload=False # type: bool
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
):
"""
Report an image, upload its contents, and present in plots section using plotly
@@ -1418,7 +1440,7 @@ class Logger(object):
self._start_task_if_needed()
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)
@@ -1438,13 +1460,13 @@ class Logger(object):
)
def _report_file_and_upload(
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
max_file_history=None, # type: Optional[int]
delete_after_upload=False # type: bool
self,
title, # type: str
series, # type: str
iteration=None, # type: Optional[int]
path=None, # type: Optional[str]
max_file_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool
):
"""
Upload a file and report it as link in the debug images section.
@@ -1465,7 +1487,7 @@ class Logger(object):
self._start_task_if_needed()
upload_uri = self.get_default_upload_destination()
if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images'
upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination
upload_uri = str(upload_uri)

View File

@@ -2347,6 +2347,7 @@ class OutputModel(BaseModel):
iteration=None, # type: Optional[int]
update_comment=True, # type: bool
is_package=False, # type: bool
async_enable=True, # type: bool
):
# type: (...) -> str
"""
@@ -2374,6 +2375,8 @@ class OutputModel(BaseModel):
- ``True`` - Update model comment (Default)
- ``False`` - Do not update
:param bool is_package: Mark the weights file as compressed package, usually a zip file.
:param bool async_enable: Whether to upload the model in the background or to block.
If False, an error will be raised in the main thread if the weights failed to upload.
:return: The uploaded URI.
"""
@@ -2421,6 +2424,7 @@ class OutputModel(BaseModel):
target_filename=target_filename or Path(weights_filename).name,
auto_delete_file=auto_delete_file,
iteration=iteration,
async_enable=async_enable
)
# make sure we delete the previous file, if it exists
@@ -2502,7 +2506,7 @@ class OutputModel(BaseModel):
output_uri = model.update_and_upload(
model_file=weights_filename,
task_id=self._task.id,
async_enable=True,
async_enable=async_enable,
target_filename=target_filename,
framework=self.framework or framework,
comment=comment,
@@ -2535,6 +2539,7 @@ class OutputModel(BaseModel):
target_filename=None, # type: Optional[str]
auto_delete_file=True, # type: bool
iteration=None, # type: Optional[int]
async_enable=True, # type: bool
):
# type: (...) -> str
"""
@@ -2559,6 +2564,8 @@ class OutputModel(BaseModel):
- ``False`` - Do not delete
:param int iteration: The iteration number.
:param bool async_enable: Whether to upload the model in the background or to block.
If False, an error will be raised in the main thread if the weights failed to upload.
:return: The uploaded URI for the weights package.
"""
@@ -2626,6 +2633,7 @@ class OutputModel(BaseModel):
target_filename=target_filename or "model_package.zip",
iteration=iteration,
update_comment=False,
async_enable=async_enable
)
# set the model tag (by now we should have a model object) so we know we have packaged file
self._set_package_tag()

View File

@@ -4,12 +4,19 @@ from time import time
from typing import Optional, AnyStr, IO
from ..config import config
try:
from tqdm import tqdm # noqa
except ImportError:
tqdm = None
class ProgressReport(object):
report_upload_chunk_size_mb = None
report_download_chunk_size_mb = None
def __init__(self, verbose, total_size, log, report_chunk_size_mb):
def __init__(self, verbose, total_size, log, report_chunk_size_mb,
description_prefix=None, description_suffix=None,
max_time_between_reports_sec=10.0, report_start=None):
self.current_status_mb = 0.
self.last_reported = 0.
self._tic = time()
@@ -18,45 +25,117 @@ class ProgressReport(object):
self._log = log
self._log_flag = False
self._total_size = total_size
self._description_prefix = description_prefix
self._description_suffix = description_suffix
self._max_time_between_reports_sec = max_time_between_reports_sec
self._report_start = report_start if report_start is not None else bool(tqdm is not None)
self._tqdm = None
self._tqdm_init = False
def close(self, report_completed=False, report_summary=False, report_prefix=None, report_suffix=None):
# call this one when we are done
if self._tqdm is not None:
# if we created a self._tqdm object we need to close it
if report_completed:
self._tqdm.update(
self._tqdm.total - min(self._tqdm.total, self.last_reported)
)
self._tqdm.close()
self._tqdm = None
if report_summary:
self._log.info(
'{} {:.2f} MB successfully {}'.format(
report_prefix or self._description_prefix, self._total_size,
report_suffix or self._description_suffix).strip()
)
def _get_tqdm(self):
if self._tqdm_init:
return self._tqdm
self._tqdm_init = True
# create the tqdm progress bar
if tqdm:
# noinspection PyBroadException
try:
self._tqdm = tqdm(
total=round(float(self._total_size), 2),
# desc="{} {}".format(description_prefix, description_suffix).strip(),
unit="MB",
unit_scale=False,
ncols=80,
bar_format="{bar} {percentage:3.0f}% | {n_fmt}/{total_fmt} MB "
"[{elapsed}<{remaining}, {rate_fmt}{postfix}]: {desc}",
)
except Exception:
# failed initializing TQDM (maybe interface changed?)
self._tqdm = None
return self._tqdm
def __call__(self, chunk_size, *_, **__):
chunk_size /= 1024. * 1024.
self.current_status_mb += chunk_size
last_part = self.current_status_mb - self.last_reported
if self._verbose or (last_part >= self._report_chunk_size):
if (self._verbose or (last_part >= self._report_chunk_size) or
(self.last_reported and self.current_status_mb >= self._total_size-0.01) or
(time()-self._tic > self._max_time_between_reports_sec)):
time_diff = time() - self._tic
self.speed = (last_part / time_diff) if time_diff != 0 else 0
self._report(self._total_size, self.current_status_mb, self.speed)
self._tic = time()
self.last_reported = self.current_status_mb
self._report(self._total_size, self.current_status_mb, self.speed)
def _report(self, total_mb, current_mb, speed_mbps):
# type: (float, float, float) -> None
pass
if self._report_start and self.last_reported <= 0:
# first time - print before initializing the tqdm bar
self._log.info(
"{}: {:.2f}MB {}".format(
self._description_prefix, total_mb, self._description_suffix).strip(" :")
)
# initialize or reuse the bar
_tqdm = self._get_tqdm()
if _tqdm:
# make sure we do not spill over due to rounding
if round(float(current_mb), 2) >= _tqdm.total:
_tqdm.update(_tqdm.total - self.last_reported)
else:
_tqdm.update(current_mb - self.last_reported)
else:
self._log.info(
"{}: {:.2f}MB / {:.2f}MB @ {:.2f}MBs {}".format(
self._description_prefix,
current_mb,
total_mb,
speed_mbps,
self._description_suffix
).strip(" :")
)
class UploadProgressReport(ProgressReport):
def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=None):
def __init__(self, filename, verbose, total_size, log, report_chunk_size_mb=None, report_start=None):
report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
else ProgressReport.report_upload_chunk_size_mb or \
int(config.get("storage.log.report_upload_chunk_size_mb", 5))
super(UploadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
self._filename = filename
def _report(self, total_mb, current_mb, speed_mbps):
# type: (float, float, float) -> None
self._log.info(
'Uploading: %.2fMB / %.2fMB @ %.2fMBs from %s' %
(current_mb, total_mb, speed_mbps, self._filename)
super(UploadProgressReport, self).__init__(
verbose, total_size, log, report_chunk_size_mb,
description_prefix="Uploading", description_suffix="to {}".format(filename),
report_start=report_start,
)
self._filename = filename
@classmethod
def from_stream(cls, stream, filename, verbose, log):
# type: (IO[AnyStr], str, bool, logging.Logger) -> Optional[UploadProgressReport]
if hasattr(stream, 'seek'):
total_size = cls._get_stream_length(stream)
return UploadProgressReport(filename, verbose, total_size, log)
total_size_mb = cls._get_stream_length(stream) // (1024 * 1024)
return UploadProgressReport(filename, verbose, total_size_mb, log)
@classmethod
def from_file(cls, filename, verbose, log):
@@ -78,14 +157,13 @@ class UploadProgressReport(ProgressReport):
class DownloadProgressReport(ProgressReport):
def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=None):
def __init__(self, total_size, verbose, remote_path, log, report_chunk_size_mb=None, report_start=None):
report_chunk_size_mb = report_chunk_size_mb if report_chunk_size_mb is not None \
else ProgressReport.report_download_chunk_size_mb or \
int(config.get("storage.log.report_download_chunk_size_mb", 5))
super(DownloadProgressReport, self).__init__(verbose, total_size, log, report_chunk_size_mb)
super(DownloadProgressReport, self).__init__(
verbose, total_size, log, report_chunk_size_mb,
description_prefix="Downloading", description_suffix="from {}".format(remote_path),
report_start=report_start,
)
self._remote_path = remote_path
def _report(self, total_mb, current_mb, speed_mbps):
# type: (float, float, float) -> None
self._log.info('Downloading: %.2fMB / %.2fMB @ %.2fMBs from %s' %
(current_mb, total_mb, speed_mbps, self._remote_path))

View File

@@ -615,7 +615,11 @@ class _Boto3Driver(_Driver):
def async_download(a_obj, a_stream, cb, cfg):
try:
a_obj.download_fileobj(a_stream, Callback=cb, Config=cfg)
if cb:
cb.close(report_completed=True)
except Exception as ex:
if cb:
cb.close()
(log or self.get_logger()).error('Failed downloading: %s' % ex)
a_stream.close()
@@ -780,8 +784,8 @@ class _GoogleCloudStorageDriver(_Driver):
class _Container(object):
def __init__(self, name, cfg):
try:
from google.cloud import storage
from google.oauth2 import service_account
from google.cloud import storage # noqa
from google.oauth2 import service_account # noqa
except ImportError:
raise UsageError(
'Google cloud driver not found. '
@@ -862,7 +866,7 @@ class _GoogleCloudStorageDriver(_Driver):
object.delete()
except Exception as ex:
try:
from google.cloud.exceptions import NotFound
from google.cloud.exceptions import NotFound # noqa
if isinstance(ex, NotFound):
return False
except ImportError:
@@ -949,7 +953,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
except ImportError:
try:
from azure.storage.blob import BlockBlobService # noqa
from azure.common import AzureHttpError # noqa: F401
from azure.common import AzureHttpError # noqa
self.__legacy = True
except ImportError:
@@ -1193,6 +1197,7 @@ class _AzureBlobServiceStorageDriver(_Driver):
obj.blob_name,
progress_callback=cb,
)
cb.close()
if container.is_legacy():
return blob.content
else:
@@ -1663,7 +1668,7 @@ class _FileStorageDriver(_Driver):
try:
os.unlink(path)
except Exception:
except Exception: # noqa
return False
# # Check and delete all the empty parent folders
@@ -1767,14 +1772,14 @@ class _FileStorageDriver(_Driver):
if six.PY3:
from io import FileIO as file
if isinstance(iterator, (file)):
if isinstance(iterator, file):
get_data = iterator.read
args = (chunk_size,)
else:
get_data = next
args = (iterator,)
data = bytes('')
data = bytes(b'')
empty = False
while not empty or len(data) > 0:
@@ -2320,7 +2325,7 @@ class StorageHelper(object):
return self._get_object_size_bytes(obj, silence_errors)
def _get_object_size_bytes(self, obj, silence_errors=False):
# type: (object) -> [int, None]
# type: (object, bool) -> [int, None]
"""
Auxiliary function for `get_object_size_bytes`.
Get size of the remote object in bytes.
@@ -2448,6 +2453,10 @@ class StorageHelper(object):
stream.seek(0)
except Exception:
pass
if cb:
cb.close(report_completed=not bool(last_ex))
if last_ex:
raise last_ex
@@ -2601,9 +2610,10 @@ class StorageHelper(object):
return direct_access_path
temp_local_path = None
cb = None
try:
if verbose:
self._log.info('Start downloading from %s' % remote_path)
self._log.info("Start downloading from {}".format(remote_path))
if not overwrite_existing and Path(local_path).is_file():
self._log.debug(
'File {} already exists, no need to download, thread id = {}'.format(
@@ -2643,8 +2653,9 @@ class StorageHelper(object):
# if driver supports download with callback, use it (it might be faster)
if hasattr(self._driver, 'download_object'):
# callback
cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self._log)
# callback if verbose we already reported download start, no need to do that again
cb = DownloadProgressReport(total_size_mb, verbose, remote_path, self._log,
report_start=True if verbose else None)
self._driver.download_object(obj, temp_local_path, callback=cb)
download_reported = bool(cb.last_reported)
dl_total_mb = cb.current_status_mb
@@ -2686,15 +2697,28 @@ class StorageHelper(object):
raise Exception('Failed renaming partial file, downloaded file exists and a 0-sized file')
# report download if we are on the second chunk
if verbose or download_reported:
if cb:
cb.close(
report_completed=True,
report_summary=verbose or download_reported,
report_prefix="Downloaded",
report_suffix="from {} , saved to {}".format(remote_path, local_path)
)
elif verbose or download_reported:
self._log.info(
'Downloaded %.2f MB successfully from %s , saved to %s' % (dl_total_mb, remote_path, local_path))
"Downloaded {:.2f} MB successfully from {} , saved to {}".format(
dl_total_mb, remote_path, local_path)
)
return local_path
except DownloadError:
if cb:
cb.close()
raise
except Exception as e:
if cb:
cb.close()
self._log.error("Could not download {} , err: {} ".format(remote_path, e))
if delete_on_failure:
if delete_on_failure and temp_local_path:
# noinspection PyBroadException
try:
os.remove(temp_local_path)
@@ -2880,7 +2904,9 @@ class StorageHelper(object):
def _do_async_upload(self, data):
assert isinstance(data, self._UploadData)
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path, extra=data.extra, cb=data.callback, verbose=True, retries=data.retries, return_canonized=data.return_canonized)
return self._do_upload(data.src_path, data.dest_path, data.canonized_dest_path,
extra=data.extra, cb=data.callback, verbose=True,
retries=data.retries, return_canonized=data.return_canonized)
def _upload_from_file(self, local_path, dest_path, extra=None):
if not hasattr(self._driver, 'upload_object'):
@@ -2897,9 +2923,12 @@ class StorageHelper(object):
object_name=object_name,
callback=cb,
extra=extra)
if cb:
cb.close()
return res
def _do_upload(self, src_path, dest_path, canonized_dest_path, extra=None, cb=None, verbose=False, retries=1, return_canonized=False):
def _do_upload(self, src_path, dest_path, canonized_dest_path,
extra=None, cb=None, verbose=False, retries=1, return_canonized=False):
object_name = self._normalize_object_name(canonized_dest_path)
if cb:
try:

View File

@@ -1,9 +1,7 @@
import atexit
import copy
import json
import os
import shutil
import signal
import sys
import threading
import time
@@ -42,7 +40,8 @@ from .backend_config.defs import get_active_config_file, get_config_file
from .backend_api.services import tasks, projects, events
from .backend_api.session.session import (
Session, ENV_ACCESS_KEY, ENV_SECRET_KEY, ENV_HOST, ENV_WEB_HOST, ENV_FILES_HOST, )
from .backend_api.session.defs import ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG, ENV_OFFLINE_MODE, MissingConfigError
from .backend_api.session.defs import (ENV_DEFERRED_TASK_INIT, ENV_IGNORE_MISSING_CONFIG,
ENV_OFFLINE_MODE, MissingConfigError)
from .backend_interface.metrics import Metrics
from .backend_interface.model import Model as BackendModel
from .backend_interface.base import InterfaceBase
@@ -99,13 +98,17 @@ from .utilities.proxy_object import (
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.lowlevel.distributed import get_torch_local_rank, get_torch_distributed_anchor_task_id, \
create_torch_distributed_anchor
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.process.exit_hooks import ExitHooks
from .utilities.matching import matches_any_wildcard
from .utilities.parallel import FutureTaskCaller
from .utilities.networking import get_private_ip
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
if TYPE_CHECKING:
import pandas
import numpy
@@ -487,8 +490,6 @@ class Task(_Task):
# unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember
cls.__main_task.__register_at_exit(cls.__main_task._at_exit)
# TODO: Check if the signal handler method is safe enough, for the time being, do not unhook
# cls.__main_task.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
# start all reporting threads
BackgroundMonitor.start_all(task=cls.__main_task)
@@ -530,10 +531,16 @@ class Task(_Task):
is_deferred = False
try:
if not running_remotely():
# check remote status
_local_rank = get_torch_local_rank()
if _local_rank is not None and _local_rank > 0:
is_sub_process_task_id = get_torch_distributed_anchor_task_id(timeout=30)
# only allow if running locally and creating the first Task
# otherwise we ignore and perform in order
if ENV_DEFERRED_TASK_INIT.get():
deferred_init = True
if not is_sub_process_task_id and deferred_init and deferred_init != cls.__nested_deferred_init_flag:
def completed_cb(x):
Task.__main_task = x
@@ -574,6 +581,11 @@ class Task(_Task):
not auto_connect_frameworks.get('detect_repository', True)) else True,
auto_connect_streams=auto_connect_streams,
)
# check if we are local rank 0 (local master),
# create an anchor with task ID for the other processes
if _local_rank == 0:
create_torch_distributed_anchor(task_id=task.id)
except MissingConfigError as e:
if not ENV_IGNORE_MISSING_CONFIG.get():
raise
@@ -636,7 +648,11 @@ class Task(_Task):
# register at exist only on the real (none deferred) Task
if not is_deferred:
# register the main task for at exit hooks (there should only be one)
# noinspection PyProtectedMember
task.__register_at_exit(task._at_exit)
# noinspection PyProtectedMember
if cls.__exit_hook:
cls.__exit_hook.register_signal_and_exception_hooks()
# always patch OS forking because of ProcessPool and the alike
PatchOsFork.patch_fork(task)
@@ -1552,6 +1568,69 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def set_packages(self, packages):
# type: (Union[str, Path, Sequence[str]]) -> ()
"""
Manually specify a list of required packages or a local requirements.txt file.
When running remotely this call is ignored
:param packages: The list of packages or the path to the requirements.txt file.
Example: ["tqdm>=2.1", "scikit-learn"] or "./requirements.txt" or ""
Use an empty string (packages="") to clear the requirements section (remote execution will use
requirements.txt from the git repository if the file exists)
"""
if running_remotely() or packages is None:
return
self._wait_for_repo_detection(timeout=300.)
if packages and isinstance(packages, (str, Path)) and Path(packages).is_file():
with open(Path(packages).as_posix(), "rt") as f:
# noinspection PyProtectedMember
self._update_requirements([line.strip() for line in f.readlines()])
return
# noinspection PyProtectedMember
self._update_requirements(packages or "")
def set_repo(self, repo=None, branch=None, commit=None):
# type: (Optional[str], Optional[str], Optional[str]) -> ()
"""
Specify a repository to attach to the function.
Allow users to execute the task inside the specified repository, enabling them to load modules/script
from the repository. Notice the execution work directory will be the repository root folder.
Supports both git repo url link, and local repository path (automatically converted into the remote
git/commit as is currently checkout).
Example remote url: "https://github.com/user/repo.git".
Example local repo copy: "./repo" -> will automatically store the remote
repo url and commit ID based on the locally cloned copy.
When executing remotely, this call will not override the repository data (it is ignored)
:param repo: Optional, remote URL for the repository to use, OR path to local copy of the git repository.
Use an empty string to clear the repo.
Example: "https://github.com/allegroai/clearml.git" or "~/project/repo" or ""
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used).
Use an empty string to clear the branch.
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used).
Use an empty string to clear the commit.
"""
if running_remotely():
return
self._wait_for_repo_detection(timeout=300.)
with self._edit_lock:
self.reload()
if repo is not None:
# we cannot have None on the value itself
self.data.script.repository = repo or ""
if branch is not None:
# we cannot have None on the value itself
self.data.script.branch = branch or ""
if commit is not None:
# we cannot have None on the value itself
self.data.script.version_num = commit or ""
self._edit(script=self.data.script)
def connect_configuration(self, configuration, name=None, description=None):
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str]) -> Union[dict, Path, str]
"""
@@ -2015,6 +2094,8 @@ class Task(_Task):
# unregister atexit callbacks and signal hooks, if we are the main task
if is_main:
self.__register_at_exit(None)
self._remove_signal_hooks()
self._remove_exception_hooks()
if not is_sub_process:
# make sure we enable multiple Task.init callas with reporting sub-processes
BackgroundMonitor.clear_main_process(self)
@@ -2715,41 +2796,6 @@ class Task(_Task):
docker_setup_bash_script=docker_setup_bash_script
)
def set_packages(self, packages):
# type: (Union[str, Sequence[str]]) -> ()
"""
Manually specify a list of required packages or a local requirements.txt file.
When running remotely the call is ignored
:param packages: The list of packages or the path to the requirements.txt file.
Example: ["tqdm>=2.1", "scikit-learn"] or "./requirements.txt"
"""
if running_remotely():
return
super(Task, self).set_packages(packages)
def set_repo(self, repo, branch=None, commit=None):
# type: (str, Optional[str], Optional[str]) -> ()
"""
Specify a repository to attach to the function.
Allow users to execute the task inside the specified repository, enabling them to load modules/script
from the repository. Notice the execution work directory will be the repository root folder.
Supports both git repo url link, and local repository path (automatically converted into the remote
git/commit as is currently checkout).
Example remote url: 'https://github.com/user/repo.git'.
Example local repo copy: './repo' -> will automatically store the remote
repo url and commit ID based on the locally cloned copy.
When executing remotely, this call will not override the repository data (it is ignored)
:param repo: Remote URL for the repository to use, OR path to local copy of the git repository
Example: 'https://github.com/allegroai/clearml.git' or '~/project/repo'
:param branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
"""
if running_remotely():
return
super(Task, self).set_repo(repo, branch=branch, commit=commit)
def set_resource_monitor_iteration_timeout(self, seconds_from_start=1800):
# type: (float) -> bool
"""
@@ -4192,172 +4238,21 @@ class Task(_Task):
if not is_sub_process and BackgroundMonitor.is_subprocess_enabled():
BackgroundMonitor.wait_for_sub_process(self)
# we are done
return
@classmethod
def __register_at_exit(cls, exit_callback, only_remove_signal_and_exception_hooks=False):
class ExitHooks(object):
_orig_exit = None
_orig_exc_handler = None
remote_user_aborted = False
def _remove_exception_hooks(cls):
if cls.__exit_hook:
cls.__exit_hook.remove_exception_hooks()
def __init__(self, callback):
self.exit_code = None
self.exception = None
self.signal = None
self._exit_callback = callback
self._org_handlers = {}
self._signal_recursion_protection_flag = False
self._except_recursion_protection_flag = False
self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")
def update_callback(self, callback):
if self._exit_callback and not six.PY2:
# noinspection PyBroadException
try:
atexit.unregister(self._exit_callback)
except Exception:
pass
self._exit_callback = callback
if callback:
self.hook()
else:
# un register int hook
if self._orig_exc_handler:
sys.excepthook = self._orig_exc_handler
self._orig_exc_handler = None
for h in self._org_handlers:
# noinspection PyBroadException
try:
signal.signal(h, self._org_handlers[h])
except Exception:
pass
self._org_handlers = {}
def hook(self):
if self._orig_exit is None:
self._orig_exit = sys.exit
sys.exit = self.exit
if self._orig_exc_handler is None:
self._orig_exc_handler = sys.excepthook
sys.excepthook = self.exc_handler
if self._exit_callback:
atexit.register(self._exit_callback)
# TODO: check if sub-process hooks are safe enough, for the time being allow it
if not self._org_handlers: # ## and not Task._Task__is_subprocess():
if sys.platform == 'win32':
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE]
else:
catch_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGSEGV, signal.SIGABRT,
signal.SIGILL, signal.SIGFPE, signal.SIGQUIT]
for c in catch_signals:
# noinspection PyBroadException
try:
self._org_handlers[c] = signal.getsignal(c)
signal.signal(c, self.signal_handler)
except Exception:
pass
def exit(self, code=0):
self.exit_code = code
self._orig_exit(code)
def exc_handler(self, exctype, value, traceback, *args, **kwargs):
if self._except_recursion_protection_flag:
# noinspection PyArgumentList
return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = True
self.exception = value
try:
# remove us from import errors
if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
prev = cur = traceback
while cur is not None:
tb_next = cur.tb_next
# if this is the import frame, we should remove it
if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
# remove this frame by connecting the previous one to the next one
prev.tb_next = tb_next
cur.tb_next = None
del cur
cur = prev
prev = cur
cur = tb_next
except: # noqa
pass
if self._orig_exc_handler:
# noinspection PyArgumentList
ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
else:
# noinspection PyNoneFunctionAssignment, PyArgumentList
ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
self._except_recursion_protection_flag = False
return ret
def signal_handler(self, sig, frame):
self.signal = sig
org_handler = self._org_handlers.get(sig)
signal.signal(sig, org_handler or signal.SIG_DFL)
# if this is a sig term, we wait until __at_exit is called (basically do nothing)
if sig == signal.SIGINT:
# return original handler result
return org_handler if not callable(org_handler) else org_handler(sig, frame)
if self._signal_recursion_protection_flag:
# call original
os.kill(os.getpid(), sig)
return org_handler if not callable(org_handler) else org_handler(sig, frame)
self._signal_recursion_protection_flag = True
# call exit callback
if self._exit_callback:
# noinspection PyBroadException
try:
self._exit_callback()
except Exception:
pass
# remove stdout logger, just in case
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
Logger._remove_std_logger()
except Exception:
pass
# noinspection PyUnresolvedReferences
os.kill(os.getpid(), sig)
self._signal_recursion_protection_flag = False
# return handler result
return org_handler if not callable(org_handler) else org_handler(sig, frame)
# we only remove the signals since this will hang subprocesses
if only_remove_signal_and_exception_hooks:
if not cls.__exit_hook:
return
if cls.__exit_hook._orig_exc_handler:
sys.excepthook = cls.__exit_hook._orig_exc_handler
cls.__exit_hook._orig_exc_handler = None
for s in cls.__exit_hook._org_handlers:
# noinspection PyBroadException
try:
signal.signal(s, cls.__exit_hook._org_handlers[s])
except Exception:
pass
cls.__exit_hook._org_handlers = {}
return
@classmethod
def _remove_signal_hooks(cls):
if cls.__exit_hook:
cls.__exit_hook.remove_signal_hooks()
@classmethod
def __register_at_exit(cls, exit_callback):
if cls.__exit_hook is None:
# noinspection PyBroadException
try:
@@ -4368,12 +4263,6 @@ class Task(_Task):
else:
cls.__exit_hook.update_callback(exit_callback)
def _remove_at_exit_callbacks(self):
self.__register_at_exit(None, only_remove_signal_and_exception_hooks=True)
# noinspection PyProtectedMember
atexit.unregister(self.__exit_hook._exit_callback)
self._at_exit_called = True
@classmethod
def __get_task(
cls,

View File

@@ -0,0 +1,336 @@
#
# distutils/version.py
#
# Implements multiple version numbering conventions for the
# Python Module Distribution Utilities.
#
# $Id$
#
"""Provides classes to represent module version numbers (one class for
each style of version numbering). There are currently two such classes
implemented: StrictVersion and LooseVersion.
Every version number class implements the following interface:
* the 'parse' method takes a string and parses it to some internal
representation; if the string is an invalid version number,
'parse' raises a ValueError exception
* the class constructor takes an optional string argument which,
if supplied, is passed to 'parse'
* __str__ reconstructs the string that was passed to 'parse' (or
an equivalent string -- ie. one that will generate an equivalent
version number instance)
* __repr__ generates Python code to recreate the version number instance
* _cmp compares the current instance with either another instance
of the same class or a string (which will be parsed to an instance
of the same class, thus must follow the same rules)
"""
import re
class Version:
    """Abstract base class for version numbering classes.

    Supplies the shared constructor (``__init__``) and reproducer
    (``__repr__``) and routes every rich comparison operator through the
    subclass-provided ``_cmp`` hook, which returns -1/0/1 (or
    ``NotImplemented`` when the operands are not comparable).
    """

    def __init__(self, vstring=None):
        if vstring:
            self.parse(vstring)

    def __repr__(self):
        return "%s ('%s')" % (self.__class__.__name__, str(self))

    def __eq__(self, other):
        outcome = self._cmp(other)
        return outcome if outcome is NotImplemented else outcome == 0

    def __lt__(self, other):
        outcome = self._cmp(other)
        return outcome if outcome is NotImplemented else outcome < 0

    def __le__(self, other):
        outcome = self._cmp(other)
        return outcome if outcome is NotImplemented else outcome <= 0

    def __gt__(self, other):
        outcome = self._cmp(other)
        return outcome if outcome is NotImplemented else outcome > 0

    def __ge__(self, other):
        outcome = self._cmp(other)
        return outcome if outcome is NotImplemented else outcome >= 0


# Interface for version-number classes -- must be implemented by the
# concrete classes below (Version itself should be treated as abstract):
# __init__ (string)  - create and take the same action as 'parse'
#                      (string parameter is optional)
# parse (string)     - convert a string representation to whatever internal
#                      representation suits this style of version numbering
# __str__ (self)     - convert back to a string; should closely match the
#                      string supplied to parse
# __repr__ (self)    - generate Python code to recreate the instance
# _cmp (self, other) - compare two version numbers ('other' may be an
#                      unparsed version string or another instance)
class StrictVersion(Version):
    """Version numbering for anal retentives and software idealists.

    A version number consists of two or three dot-separated numeric
    components, with an optional "pre-release" tag on the end: the letter
    'a' or 'b' followed by a number.  When the numeric components of two
    version numbers are equal, the one carrying a pre-release tag is always
    deemed earlier (lesser) than the one without.

    Valid, in ascending order:  0.4 == 0.4.0, 0.4.1, 0.5a1, 0.5b3, 0.5,
    0.9.6, 1.0, 1.0.4a3, 1.0.4b1, 1.0.4
    Invalid examples: 1, 2.7.2.2, 1.3.a4, 1.3pl1, 1.3c4
    """

    version_re = re.compile(r"^(\d+) \. (\d+) (\. (\d+))? ([ab](\d+))?$", re.VERBOSE | re.ASCII)

    def parse(self, vstring):
        """Parse *vstring* into self.version / self.prerelease.

        :raises ValueError: if *vstring* is not a valid strict version number
        """
        match = self.version_re.match(vstring)
        if not match:
            raise ValueError("invalid version number '%s'" % vstring)
        (major, minor, patch, prerelease, prerelease_num) = match.group(1, 2, 4, 5, 6)
        if patch:
            self.version = tuple(map(int, [major, minor, patch]))
        else:
            # normalize "X.Y" to (X, Y, 0) so "1.2" compares equal to "1.2.0"
            self.version = tuple(map(int, [major, minor])) + (0,)
        if prerelease:
            # e.g. "b3" -> ('b', 3)
            self.prerelease = (prerelease[0], int(prerelease_num))
        else:
            self.prerelease = None

    def __str__(self):
        # drop a trailing ".0" patch component for brevity
        if self.version[2] == 0:
            vstring = ".".join(map(str, self.version[0:2]))
        else:
            vstring = ".".join(map(str, self.version))
        if self.prerelease:
            vstring = vstring + self.prerelease[0] + str(self.prerelease[1])
        return vstring

    def _cmp(self, other):
        """Return -1/0/1 comparing with *other* (string or StrictVersion)."""
        if isinstance(other, str):
            other = StrictVersion(other)
        elif not isinstance(other, StrictVersion):
            # Fix: defer to the other operand's reflected comparison instead
            # of raising AttributeError on unrelated types (matches the
            # stdlib distutils implementation).
            return NotImplemented
        if self.version != other.version:
            # numeric versions don't match
            # prerelease stuff doesn't matter
            if self.version < other.version:
                return -1
            else:
                return 1
        # have to compare prerelease
        # case 1: neither has prerelease; they're equal
        # case 2: self has prerelease, other doesn't; other is greater
        # case 3: self doesn't have prerelease, other does: self is greater
        # case 4: both have prerelease: must compare them!
        if not self.prerelease and not other.prerelease:
            return 0
        elif self.prerelease and not other.prerelease:
            return -1
        elif not self.prerelease and other.prerelease:
            return 1
        elif self.prerelease and other.prerelease:
            if self.prerelease == other.prerelease:
                return 0
            elif self.prerelease < other.prerelease:
                return -1
            else:
                return 1
        else:
            assert False, "never get here"
# end class StrictVersion
# The rules according to Greg Stein:
# 1) a version number has 1 or more numbers separated by a period or by
# sequences of letters. If only periods, then these are compared
# left-to-right to determine an ordering.
# 2) sequences of letters are part of the tuple for comparison and are
# compared lexicographically
# 3) recognize the numeric components may have leading zeroes
#
# The LooseVersion class below implements these rules: a version number
# string is split up into a tuple of integer and string components, and
# comparison is a simple tuple comparison. This means that version
# numbers behave in a predictable and obvious way, but a way that might
# not necessarily be how people *want* version numbers to behave. There
# wouldn't be a problem if people could stick to purely numeric version
# numbers: just split on period and compare the numbers as tuples.
# However, people insist on putting letters into their version numbers;
# the most common purpose seems to be:
# - indicating a "pre-release" version
# ('alpha', 'beta', 'a', 'b', 'pre', 'p')
# - indicating a post-release patch ('p', 'pl', 'patch')
# but of course this can't cover all version number schemes, and there's
# no way to know what a programmer means without asking him.
#
# The problem is what to do with letters (and other non-numeric
# characters) in a version number. The current implementation does the
# obvious and predictable thing: keep them as strings and compare
# lexically within a tuple comparison. This has the desired effect if
# an appended letter sequence implies something "post-release":
# eg. "0.99" < "0.99pl14" < "1.0", and "5.001" < "5.001m" < "5.002".
#
# However, if letters in a version number imply a pre-release version,
# the "obvious" thing isn't correct. Eg. you would expect that
# "1.5.1" < "1.5.2a2" < "1.5.2", but under the tuple/lexical comparison
# implemented here, this just isn't so.
#
# Two possible solutions come to mind. The first is to tie the
# comparison algorithm to a particular set of semantic rules, as has
# been done in the StrictVersion class above. This works great as long
# as everyone can go along with bondage and discipline. Hopefully a
# (large) subset of Python module programmers will agree that the
# particular flavour of bondage and discipline provided by StrictVersion
# provides enough benefit to be worth using, and will submit their
# version numbering scheme to its domination. The free-thinking
# anarchists in the lot will never give in, though, and something needs
# to be done to accommodate them.
#
# Perhaps a "moderately strict" version class could be implemented that
# lets almost anything slide (syntactically), and makes some heuristic
# assumptions about non-digits in version number strings. This could
# sink into special-case-hell, though; if I was as talented and
# idiosyncratic as Larry Wall, I'd go ahead and implement a class that
# somehow knows that "1.2.1" < "1.2.2a2" < "1.2.2" < "1.2.2pl3", and is
# just as happy dealing with things like "2g6" and "1.13++". I don't
# think I'm smart enough to do it right though.
#
# In any case, I've coded the test suite for this module (see
# ../test/test_version.py) specifically to fail on things like comparing
# "1.2a2" and "1.2". That's not because the *code* is doing anything
# wrong, it's because the simple, obvious design doesn't match my
# complicated, hairy expectations for real-world version numbers. It
# would be a snap to fix the test suite to say, "Yep, LooseVersion does
# the Right Thing" (ie. the code matches the conception). But I'd rather
# have a conception that matches common notions about version numbers.
class LooseVersion(Version):
    """Version numbering for anarchists and software realists.

    A version number consists of a series of numbers separated by periods
    or strings of letters.  Numeric components compare numerically and
    alphabetic components lexically.  Any string is a valid version number
    under this scheme (e.g. "1.5.1", "2.2beta29", "1.13++", "2g6"); the
    comparison rules are simple and predictable, but may not always give
    the results you want.
    """

    # splits a version string into numeric runs, letter runs and "." separators
    component_re = re.compile(r"(\d+ | [a-z]+ | \.)", re.VERBOSE)

    def __init__(self, vstring=None):
        if vstring:
            self.parse(vstring)

    def parse(self, vstring):
        # I've given up on thinking I can reconstruct the version string
        # from the parsed tuple -- so I just store the string here for
        # use by __str__
        self.vstring = vstring
        # keep alternating numeric / alphabetic components, drop "." separators
        components = [x for x in self.component_re.split(vstring) if x and x != "."]
        for i, obj in enumerate(components):
            try:
                components[i] = int(obj)
            except ValueError:
                pass
        self.version = components

    def __str__(self):
        return self.vstring

    def __repr__(self):
        return "LooseVersion ('%s')" % str(self)

    def _cmp(self, other):
        """Return -1/0/1 comparing with *other* (string or LooseVersion)."""
        if isinstance(other, str):
            other = LooseVersion(other)
        elif not isinstance(other, LooseVersion):
            # Fix: defer to the other operand's reflected comparison instead
            # of raising AttributeError on unrelated types (matches the
            # stdlib distutils implementation).
            return NotImplemented
        if self.version == other.version:
            return 0
        if self.version < other.version:
            return -1
        if self.version > other.version:
            return 1
# end class LooseVersion

View File

@@ -0,0 +1,96 @@
import os
from logging import getLogger
from time import sleep, time
from pathlib2 import Path
def get_torch_local_rank():
    """
    Return the local rank of this process as an int, or None when torch
    distributed (torchelastic) is not running or the rank cannot be parsed.
    Note: local rank 0 does not imply global rank 0.
    """
    # torchelastic advertises itself via TORCHELASTIC_RUN_ID
    if os.environ.get("TORCHELASTIC_RUN_ID") is None:
        return None
    # noinspection PyBroadException
    try:
        return int(os.environ.get("LOCAL_RANK"))
    except Exception:
        # LOCAL_RANK missing or not an integer
        return None
def create_torch_distributed_anchor(task_id):
    """
    Write a temporary anchor file holding the Task ID created by local rank 0,
    so the sibling torch-distributed worker processes can discover it.
    Only call when running locally (i.e. without an agent); when running
    remotely the Task ID is passed externally and no anchor file is needed.
    """
    # only the local-rank-0 process publishes its Task ID
    if get_torch_local_rank() != 0:
        return
    anchor_file_name = ".clearml_torch_distributed_id"
    anchor_location = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not anchor_location:
        return
    # noinspection PyBroadException
    try:
        # the anchor lives three directory levels above the torchelastic error file
        anchor_location = Path(anchor_location).parent.parent.parent
        with open(anchor_location / anchor_file_name, "wt") as f:
            f.write(str(task_id) + "\n")
    except Exception:
        # best effort - only warn if the anchor could not be written
        getLogger().warning("Failed creating torch task ID anchor file: {}".format(anchor_location))
def get_torch_distributed_anchor_task_id(timeout=None):
    """
    Wait for the anchor file written by local rank 0 and read the Task ID from it.
    Only call when running locally (i.e. without an agent); when running
    remotely the Task ID is passed externally instead.

    :param timeout: optional number of seconds to wait for the anchor file
    :return: Task ID of the local task to report to, or None
    """
    local_rank = get_torch_local_rank()
    # rank 0 (or no torch distributed at all) creates its own Task - nothing to read
    if not local_rank:
        return
    anchor_file_name = ".clearml_torch_distributed_id"
    anchor_location = os.environ.get("TORCHELASTIC_ERROR_FILE")
    if not anchor_location:
        return
    task_id = None
    # noinspection PyBroadException
    try:
        anchor_path = Path(anchor_location).parent.parent.parent / anchor_file_name
        start_time = time()
        # poll until the anchor file written by rank 0 shows up
        while not anchor_path.is_file():
            if timeout is not None and time() - start_time > timeout:
                getLogger().warning("Failed detecting rank zero clearml Task ID, creating a new Task")
                return None
            sleep(0.25)
        with open(anchor_path, "rt") as f:
            task_id = f.read().strip(" \n")
    except Exception:
        # best effort - fall through with task_id still None
        pass
    getLogger().warning("Torch Distributed Local Rank {} Task ID {} detected".format(local_rank, task_id))
    return task_id

View File

@@ -1,7 +1,7 @@
import warnings
import itertools
from contextlib import contextmanager
from distutils.version import LooseVersion
from clearml.utilities.distutils_version import LooseVersion
import numpy as np
import matplotlib as mpl

View File

@@ -0,0 +1,187 @@
import atexit
import os
import signal
import sys
import six
from ...logger import Logger
class ExitHooks(object):
    """
    Installs hooks on the process termination paths - sys.exit(), unhandled
    exceptions (sys.excepthook) and OS signals - so that a user supplied
    callback is invoked before the process actually dies.  The original
    handlers are kept so they can be chained to and/or restored later.
    """
    _orig_exit = None  # original sys.exit callable, captured once by hook()
    _orig_exc_handler = None  # original sys.excepthook, captured by register_signal_and_exception_hooks()
    remote_user_aborted = False  # set externally when the run was aborted by a remote user

    def __init__(self, callback):
        self.exit_code = None  # exit code passed to sys.exit(), if any
        self.exception = None  # last unhandled exception value seen by exc_handler()
        self.signal = None  # last OS signal number handled by signal_handler()
        self._exit_callback = callback  # callable invoked on exit/signal (may be None)
        self._org_handlers = {}  # signal number -> original signal handler
        self._signal_recursion_protection_flag = False  # guards against signal_handler() re-entry
        self._except_recursion_protection_flag = False  # guards against exc_handler() re-entry
        # path suffix identifying our import wrapper, used to strip its frames from ImportError tracebacks
        self._import_bind_path = os.path.join("clearml", "binding", "import_bind.py")

    def update_callback(self, callback):
        """Replace the exit callback; passing None removes the installed hooks."""
        if self._exit_callback and not six.PY2:
            # drop the previously registered atexit entry (atexit.unregister is PY3-only)
            # noinspection PyBroadException
            try:
                atexit.unregister(self._exit_callback)
            except Exception:
                pass
        self._exit_callback = callback
        if callback:
            self.hook()
        else:
            # un register int hook
            if self._orig_exc_handler:
                sys.excepthook = self._orig_exc_handler
                self._orig_exc_handler = None
            # restore the original OS signal handlers
            for h in self._org_handlers:
                # noinspection PyBroadException
                try:
                    signal.signal(h, self._org_handlers[h])
                except Exception:
                    pass
            self._org_handlers = {}

    def hook(self):
        """Patch sys.exit and register the callback with atexit (idempotent)."""
        if self._orig_exit is None:
            self._orig_exit = sys.exit
            sys.exit = self.exit
        if self._exit_callback:
            atexit.register(self._exit_callback)

    def register_signal_and_exception_hooks(self):
        """Install sys.excepthook and OS signal handlers, keeping the originals."""
        if self._orig_exc_handler is None:
            self._orig_exc_handler = sys.excepthook
            sys.excepthook = self.exc_handler
        if not self._org_handlers:
            # SIGQUIT does not exist on Windows
            if sys.platform == "win32":
                catch_signals = [
                    signal.SIGINT,
                    signal.SIGTERM,
                    signal.SIGSEGV,
                    signal.SIGABRT,
                    signal.SIGILL,
                    signal.SIGFPE,
                ]
            else:
                catch_signals = [
                    signal.SIGINT,
                    signal.SIGTERM,
                    signal.SIGSEGV,
                    signal.SIGABRT,
                    signal.SIGILL,
                    signal.SIGFPE,
                    signal.SIGQUIT,
                ]
            for c in catch_signals:
                # noinspection PyBroadException
                try:
                    self._org_handlers[c] = signal.getsignal(c)
                    signal.signal(c, self.signal_handler)
                except Exception:
                    pass

    def remove_signal_hooks(self):
        """Restore the OS signal handlers replaced by register_signal_and_exception_hooks()."""
        for org_handler_k, org_handler_v in self._org_handlers.items():
            # noinspection PyBroadException
            try:
                signal.signal(org_handler_k, org_handler_v)
            except Exception:
                pass
        self._org_handlers = {}

    def remove_exception_hooks(self):
        """Restore the original sys.excepthook."""
        if self._orig_exc_handler:
            sys.excepthook = self._orig_exc_handler
            self._orig_exc_handler = None

    def exit(self, code=0):
        """Replacement for sys.exit - record the exit code, then delegate to the original."""
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exctype, value, traceback, *args, **kwargs):
        """Replacement for sys.excepthook - record the exception and chain to the original hook."""
        if self._except_recursion_protection_flag or not self._orig_exc_handler:
            # already inside this handler (or nothing to chain to) - use the interpreter default
            # noinspection PyArgumentList
            return sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
        self._except_recursion_protection_flag = True
        self.exception = value
        try:
            # remove us from import errors
            if six.PY3 and isinstance(exctype, type) and issubclass(exctype, ImportError):
                prev = cur = traceback
                while cur is not None:
                    tb_next = cur.tb_next
                    # if this is the import frame, we should remove it
                    if cur.tb_frame.f_code.co_filename.endswith(self._import_bind_path):
                        # remove this frame by connecting the previous one to the next one
                        prev.tb_next = tb_next
                        cur.tb_next = None
                        del cur
                        cur = prev
                    prev = cur
                    cur = tb_next
        except:  # noqa
            pass
        if self._orig_exc_handler:
            # noinspection PyArgumentList
            ret = self._orig_exc_handler(exctype, value, traceback, *args, **kwargs)
        else:
            # noinspection PyNoneFunctionAssignment, PyArgumentList
            ret = sys.__excepthook__(exctype, value, traceback, *args, **kwargs)
        self._except_recursion_protection_flag = False
        return ret

    def signal_handler(self, sig, frame):
        """Replacement OS signal handler - run the exit callback once, then re-raise the signal."""
        org_handler = self._org_handlers.get(sig)
        if not org_handler:
            # we never replaced this signal - behave as the default handler
            return signal.SIG_DFL
        self.signal = sig
        # put the original handler back before re-raising the signal below
        signal.signal(sig, org_handler or signal.SIG_DFL)
        # if this is a sig term, we wait until __at_exit is called (basically do nothing)
        if sig == signal.SIGINT:
            # return original handler result
            return org_handler if not callable(org_handler) else org_handler(sig, frame)
        if self._signal_recursion_protection_flag:
            # call original
            os.kill(os.getpid(), sig)
            return org_handler if not callable(org_handler) else org_handler(sig, frame)
        self._signal_recursion_protection_flag = True
        # call exit callback
        if self._exit_callback:
            # noinspection PyBroadException
            try:
                self._exit_callback()
            except Exception:
                pass
        # remove stdout logger, just in case
        # noinspection PyBroadException
        try:
            # noinspection PyProtectedMember
            Logger._remove_std_logger()
        except Exception:
            pass
        # re-raise the signal so the restored original handler / default disposition runs
        # noinspection PyUnresolvedReferences
        os.kill(os.getpid(), sig)
        self._signal_recursion_protection_flag = False
        # return handler result
        return org_handler if not callable(org_handler) else org_handler(sig, frame)

View File

@@ -541,6 +541,15 @@ class BackgroundMonitor(object):
self._event.set()
if isinstance(self._thread, Thread):
# should we wait for the thread to finish
# noinspection PyBroadException
try:
# there is a race here, and if someone else closes the
# thread it can become True/None and we will fail, it is fine
self._thread.join()
except BaseException:
pass
try:
self._get_instances().remove(self)
except ValueError:
@@ -669,21 +678,23 @@ class BackgroundMonitor(object):
@classmethod
def _background_process_start(cls, task_obj_id, event_start=None, parent_pid=None):
# type: (int, Optional[SafeEvent], Optional[int]) -> None
# noinspection PyProtectedMember
is_debugger_running = bool(getattr(sys, 'gettrace', None) and sys.gettrace())
# make sure we update the pid to our own
cls._main_process = os.getpid()
cls._main_process_proc_obj = psutil.Process(cls._main_process)
# restore original signal, this will prevent any deadlocks
# Do not change the exception we need to catch base exception as well
# noinspection PyBroadException
try:
from ... import Task
# make sure we do not call Task.current_task() it will create a Task object for us on a subprocess!
if Task._Task__current_task and Task._Task__current_task._Task__exit_hook: # noqa
Task._Task__current_task._Task__exit_hook.register_signal_and_exception_hooks() # noqa
# noinspection PyProtectedMember
if Task._has_current_task_obj():
# noinspection PyProtectedMember
Task.current_task()._remove_at_exit_callbacks()
from ...binding.environ_bind import PatchOsFork
PatchOsFork.unpatch_fork()
PatchOsFork.unpatch_process_run()
except: # noqa
# Do not change the exception we need to catch base exception as well
pass
# if a debugger is running, wait for it to attach to the subprocess

View File

@@ -1 +1 @@
__version__ = '1.13.2'
__version__ = '1.14.1'

View File

@@ -0,0 +1,111 @@
import sys
from clearml import Task
from clearml.automation import DiscreteParameterRange, HyperParameterOptimizer, UniformIntegerParameterRange
try:
from clearml.automation.optuna import OptimizerOptuna # noqa
except ImportError:
print("Multi-objective HPO is currently only supported via Optuna")
sys.exit(0)
if __name__ == "__main__":
    # Controller task for the multi-objective hyper-parameter optimization run
    task = Task.init(
        project_name="Hyper-Parameter Optimization",
        task_name="Multi-objective HPO",
        task_type=Task.TaskTypes.optimizer,
        reuse_last_task_id=False,
    )
    # experiment template to optimize in the hyper-parameter optimization
    args = {
        "template_task_id": None,
        "run_as_service": False,
    }
    args = task.connect(args)
    # Get the template task experiment that we want to optimize
    if not args["template_task_id"]:
        args["template_task_id"] = Task.get_task(project_name="examples", task_name="Keras HP optimization base").id
    # Set default queue name for the Training tasks themselves.
    # later can be overridden in the UI
    execution_queue = "1xGPU"
    an_optimizer = HyperParameterOptimizer(
        # This is the experiment we want to optimize
        base_task_id=args["template_task_id"],
        # here we define the hyper-parameters to optimize
        # Notice: The parameter name should exactly match what you see in the UI: <section_name>/<parameter>
        # For Example, here we see in the base experiment a section Named: "General"
        # under it a parameter named "batch_size", this becomes "General/batch_size"
        # If you have `argparse` for example, then arguments will appear under the "Args" section,
        # and you should instead pass "Args/batch_size"
        hyper_parameters=[
            UniformIntegerParameterRange("General/layer_1", min_value=128, max_value=512, step_size=128),
            UniformIntegerParameterRange("General/layer_2", min_value=128, max_value=512, step_size=128),
            DiscreteParameterRange("General/batch_size", values=[96, 128, 160]),
            DiscreteParameterRange("General/epochs", values=[30]),
        ],
        # this is the objectives' metric/series we want to maximize/minimize
        # (multi-objective: the lists below are index-aligned with objective_metric_sign)
        objective_metric_title=["evaluate", "evaluate"],
        objective_metric_series=["score", "accuracy"],
        # now we decide if we want to maximize it or minimize them
        # in this case, we want to minimize evaluate/score and maximize evaluate/accuracy
        objective_metric_sign=["min", "max"],
        # let us limit the number of concurrent experiments,
        # this in turn will make sure we do not bombard the scheduler with experiments.
        # if we have an auto-scaler connected, this, by proxy, will limit the number of machine
        max_number_of_concurrent_tasks=1,
        # optimizer_class has to be OptimizerOptuna
        optimizer_class=OptimizerOptuna,
        # Select an execution queue to schedule the experiments for execution
        execution_queue=execution_queue,
        # If specified all Tasks created by the HPO process will be created under the `spawned_project` project
        spawn_project=None,  # 'HPO spawn project',
        # If specified only the top K performing Tasks will be kept, the others will be automatically archived
        save_top_k_tasks_only=None,  # 5,
        # Optional: Limit the execution time of a single experiment, in minutes.
        # (this is optional, and if using OptimizerBOHB, it is ignored)
        time_limit_per_job=10.0,
        # Check the experiments every 12 seconds is way too often, we should probably set it to 5 min,
        # assuming a single experiment is usually hours...
        pool_period_min=0.2,
        # set the maximum number of jobs to launch for the optimization, default (None) unlimited
        # If OptimizerBOHB is used, it defined the maximum budget in terms of full jobs
        # basically the cumulative number of iterations will not exceed total_max_jobs * max_iteration_per_job
        total_max_jobs=10,
        # set the minimum number of iterations for an experiment, before early stopping.
        # Does not apply for simple strategies such as RandomSearch or GridSearch
        min_iteration_per_job=10,
        # Set the maximum number of iterations for an experiment to execute
        # (This is optional, unless using OptimizerBOHB where this is a must)
        max_iteration_per_job=30,
    )
    # if we are running as a service, just enqueue ourselves into the services queue and let it run the optimization
    if args["run_as_service"]:
        # if this code is executed by `clearml-agent` the function call does nothing.
        # if executed locally, the local process will be terminated, and a remote copy will be executed instead
        task.execute_remotely(queue_name="services", exit_process=True)
    # report every 12 seconds, this is way too often, but we are testing here
    an_optimizer.set_report_period(0.2)
    # start the optimization process, callback function to be called every time an experiment is completed
    # this function returns immediately
    an_optimizer.start()
    # You can also use the line below instead to run all the optimizer tasks locally, without using queues or agent
    # an_optimizer.start_locally(job_complete_callback=job_complete_callback)
    # set the time limit for the optimization process (2 hours)
    an_optimizer.set_time_limit(in_minutes=120.0)
    # wait until process is done (notice we are controlling the optimization process in the background)
    an_optimizer.wait()
    # optimization is completed, print the top performing experiments id
    top_exp = an_optimizer.get_top_experiments(top_k=3)
    print([t.id for t in top_exp])
    # make sure background optimization stopped
    an_optimizer.stop()
    print("We are done, good bye")

View File

@@ -0,0 +1,24 @@
from clearml import PipelineDecorator
def our_decorator(func):
    """Wrap *func* so that its return value is incremented by one."""
    def function_wrapper(*args, **kwargs):
        original_result = func(*args, **kwargs)
        return original_result + 1
    return function_wrapper
@PipelineDecorator.component()
@our_decorator
def step():
    # The pipeline component wraps the already-decorated function, so the
    # executed step returns 1 + 1 == 2 (verifies user decorators are
    # preserved by PipelineDecorator.component).
    return 1
@PipelineDecorator.pipeline(name="test_decorated", project="test_decorated")
def pipeline():
    # step() returns 1 incremented by our_decorator, so we expect 2
    result = step()
    assert result == 2
if __name__ == "__main__":
    # run the pipeline controller and all its steps in the local process
    PipelineDecorator.run_locally()
    pipeline()

View File

@@ -0,0 +1,27 @@
from clearml import PipelineController
def our_decorator(func):
    """Wrap *func*, adding 1 to whatever the wrapped call returns."""
    def function_wrapper(*args, **kwargs):
        wrapped_result = func(*args, **kwargs)
        return wrapped_result + 1
    return function_wrapper
@our_decorator
def step():
    # returns 1, incremented to 2 by our_decorator
    return 1
def evaluate(step_return):
    # the decorated step should have produced 1 + 1 == 2
    assert step_return == 2
if __name__ == "__main__":
    pipeline = PipelineController(name="test_decorated", project="test_decorated")
    # first step: run the decorated function and store its return value as an artifact
    pipeline.add_function_step(name="step", function=step, function_return=["step_return"])
    # second step: consume the first step's artifact and verify the decorator ran
    pipeline.add_function_step(
        name="evaluate",
        function=evaluate,
        function_kwargs=dict(step_return='${step.step_return}')
    )
    # execute the controller and both steps in the local process
    pipeline.start_locally(run_pipeline_steps_locally=True)

View File

@@ -0,0 +1,179 @@
from clearml import PipelineDecorator, Task
@PipelineDecorator.component(cache=True)
def create_dataset(source_url: str, project: str, dataset_name: str) -> str:
    """
    Download a CSV from *source_url* and register it as a new ClearML dataset.

    :param source_url: remote URL of the raw (headerless) CSV file
    :param project: project the new dataset is created under
    :param dataset_name: name for the new dataset
    :return: ID of the newly created dataset
    """
    print("starting create_dataset")
    from clearml import StorageManager, Dataset
    import pandas as pd
    # fetch a local copy of the remote file
    local_file = StorageManager.get_local_copy(source_url)
    df = pd.read_csv(local_file, header=None)
    df.to_csv(path_or_buf="./dataset.csv", index=False)
    dataset = Dataset.create(dataset_project=project, dataset_name=dataset_name)
    dataset.add_files("./dataset.csv")
    # log a preview of the data on the dataset itself
    dataset.get_logger().report_table(title="sample", series="head", table_plot=df.head())
    dataset.finalize(auto_upload=True)
    print("done create_dataset")
    return dataset.id
@PipelineDecorator.component(cache=True)
def preprocess_dataset(dataset_id: str):
    """
    Add column headers to the raw CSV and store the result as a new child dataset.

    :param dataset_id: ID of the raw dataset produced by create_dataset
    :return: ID of the newly created (preprocessed) dataset
    """
    print("starting preprocess_dataset")
    from clearml import Dataset
    from pathlib import Path
    import pandas as pd
    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv", header=None)
    # "preprocessing" - adding columns
    df.columns = [
        'age', 'workclass', 'fnlwgt', 'degree', 'education-yrs', 'marital-status',
        'occupation', 'relationship', 'ethnicity', 'gender', 'capital-gain',
        'capital-loss', 'hours-per-week', 'native-country', 'income-cls',
    ]
    df.to_csv(path_or_buf="./dataset.csv", index=False)
    # store in a new dataset, keeping the raw dataset as its parent
    new_dataset = Dataset.create(
        dataset_project=dataset.project, dataset_name="{} v2".format(dataset.name),
        parent_datasets=[dataset]
    )
    new_dataset.add_files("./dataset.csv")
    new_dataset.get_logger().report_table(title="sample", series="head", table_plot=df.head())
    new_dataset.finalize(auto_upload=True)
    print("done preprocess_dataset")
    return new_dataset.id
@PipelineDecorator.component(cache=True)
def verify_dataset_integrity(dataset_id: str, expected_num_columns: int):
    """
    Sanity-check a dataset: assert the CSV holds the expected number of columns
    and log an "age" histogram on the current task.

    :param dataset_id: ID of the dataset to verify
    :param expected_num_columns: number of columns the CSV must contain
    :return: True when verification passes (AssertionError aborts the step otherwise)
    """
    print("starting verify_dataset_integrity")
    from clearml import Dataset, Logger
    from pathlib import Path
    import numpy as np
    import pandas as pd
    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv")
    print("Verifying dataset")
    assert len(df.columns) == expected_num_columns
    print("PASSED")
    # log some stats on the age column
    Logger.current_logger().report_histogram(
        title="histogram", series="age", values=np.histogram(df["age"])
    )
    print("done verify_dataset_integrity")
    return True
@PipelineDecorator.component(output_uri=True)
def train_model(dataset_id: str, training_args: dict):
    """
    Train an XGBoost regressor predicting "age" from the preprocessed dataset.

    :param dataset_id: ID of the preprocessed dataset
    :param training_args: training configuration; "num_boost_round" (default 100)
        controls the number of boosting rounds
    :return: dict with "error" (L2 norm of the prediction error on the test
        split) and "model_id" (ID of the model registered on this step's task)
    """
    print("starting train_model")
    from clearml import Dataset, Task  # removed unused OutputModel import
    from pathlib import Path
    import pandas as pd
    import numpy as np
    import xgboost as xgb
    from sklearn.model_selection import train_test_split
    import matplotlib.pyplot as plt
    dataset = Dataset.get(dataset_id=dataset_id)
    local_folder = dataset.get_local_copy()
    df = pd.read_csv(Path(local_folder) / "dataset.csv")
    # prepare data (i.e. select specific columns)
    columns = ["age", "fnlwgt", "education-yrs", "capital-gain", "capital-loss", "hours-per-week"]
    X = df[columns].drop("age", axis=1)
    y = df["age"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # create matrix
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)
    # train with XGBoost
    params = {"objective": "reg:squarederror", "eval_metric": "rmse"}
    bst = xgb.train(
        params,
        dtrain,
        num_boost_round=training_args.get("num_boost_round", 100),
        evals=[(dtrain, "train"), (dtest, "test")],
        verbose_eval=0,
    )
    # evaluate
    y_pred = bst.predict(dtest)
    # NOTE(review): nothing here calls plt.show()/savefig - presumably ClearML's
    # matplotlib binding captures these plots on the step task; confirm
    plt.plot(y_test, 'r')
    plt.plot(y_pred, 'b')
    # let's store the eval score
    error = np.linalg.norm(y_test - y_pred)
    # saving the model registers it on the task (uploaded because output_uri=True)
    bst.save_model("a_model.xgb")
    # refresh the task object so the freshly saved model shows under models['output']
    Task.current_task().reload()
    model_id = Task.current_task().models['output'][-1].id
    print("done train_model")
    return dict(error=error, model_id=model_id)
@PipelineDecorator.component(monitor_models=["best"])
def select_best_model(models_score: list):
    """
    Pick the model with the lowest error and publish it on the pipeline task.

    :param models_score: list of dicts with keys "error" and "model_id"
        (as returned by train_model)
    :return: ID of the OutputModel registered under the name "best"
    """
    print("starting select_best_model:", models_score)
    from clearml import OutputModel, Task
    best_model = None
    # linear scan for the entry with the smallest error
    for m in models_score:
        if not best_model or m["error"] < best_model["error"]:
            best_model = m
    print("The best model is {}".format(best_model))
    # lets store it on the pipeline
    best_model = OutputModel(base_model_id=best_model["model_id"])
    # let's make sure we have it
    best_model.connect(task=Task.current_task(), name="best")
    print("done select_best_model")
    return best_model.id
@PipelineDecorator.pipeline(
    name='xgboost_pipeline',
    project='xgboost_pipe_demo',
    version='0.1'
)
def pipeline(data_url: str, project: str):
    """
    End-to-end flow: create dataset -> preprocess -> verify -> train two
    models with different boosting-round budgets -> select the best one.

    :param data_url: URL of the raw CSV to ingest
    :param project: project name for the created datasets
    :return: False when dataset verification fails, otherwise None
    """
    dataset_id = create_dataset(source_url=data_url, project=project, dataset_name="mock")
    preprocessed_dataset_id = preprocess_dataset(dataset_id=dataset_id)
    # the preprocessed CSV carries 15 named columns (see preprocess_dataset)
    if not bool(verify_dataset_integrity(
        dataset_id=preprocessed_dataset_id,
        expected_num_columns=15)
    ):
        print("Verification Failed!")
        return False
    print("start training models")
    models_score = []
    # train one model per boosting-round budget
    for i in [100, 150]:
        model_score = train_model(
            dataset_id=preprocessed_dataset_id, training_args=dict(num_boost_round=i)
        )
        models_score.append(model_score)
    model_id = select_best_model(models_score=models_score)
    print("selected model_id = {}".format(model_id))
if __name__ == '__main__':
    # UCI "Adult" census income dataset (raw CSV, no header row)
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
    # comment to run the entire pipeline remotely
    if Task.running_locally():
        # this is for demonstration purpose only,
        # it will run the entire pipeline logic and components locally
        PipelineDecorator.run_locally()
    pipeline(data_url=url, project="xgboost_pipe_demo")

View File

@@ -2,4 +2,5 @@ joblib>=0.13.2
matplotlib >= 3.1.1 ; python_version >= '3.6'
matplotlib >= 2.2.4 ; python_version < '3.6'
scikit-learn
pandas
clearml

View File

@@ -46,7 +46,7 @@ def report_plots(logger, iteration=0):
# report confusion matrix
confusion = np.random.randint(10, size=(10, 10))
logger.report_matrix(
logger.report_confusion_matrix(
"example_confusion",
"ignored",
iteration=iteration,
@@ -56,7 +56,7 @@ def report_plots(logger, iteration=0):
)
# report confusion matrix with 0,0 is at the top left
logger.report_matrix(
logger.report_confusion_matrix(
"example_confusion_0_0_at_top",
"ignored",
iteration=iteration,