Improve pipeline controller, add callbacks

Author: allegroai
Date:   2021-04-10 22:26:05 +03:00
Parent: cf28551d21
Commit: 0013c5851e
3 changed files with 229 additions and 64 deletions


@@ -1,7 +1,7 @@
 from .parameters import UniformParameterRange, DiscreteParameterRange, UniformIntegerParameterRange, ParameterSet
 from .optimization import GridSearch, RandomSearch, HyperParameterOptimizer, Objective
-from .job import TrainsJob
+from .job import ClearmlJob
 from .controller import PipelineController
 __all__ = ["UniformParameterRange", "DiscreteParameterRange", "UniformIntegerParameterRange", "ParameterSet",
-           "GridSearch", "RandomSearch", "HyperParameterOptimizer", "Objective", "TrainsJob", "PipelineController"]
+           "GridSearch", "RandomSearch", "HyperParameterOptimizer", "Objective", "ClearmlJob", "PipelineController"]


@@ -8,9 +8,11 @@ from time import time
 from attr import attrib, attrs
 from typing import Sequence, Optional, Mapping, Callable, Any, Union
+from ..backend_interface.util import get_or_create_project
+from ..config import get_remote_task_id
 from ..debugging.log import LoggerRoot
 from ..task import Task
-from ..automation import TrainsJob
+from ..automation import ClearmlJob
 from ..model import BaseModel
@@ -34,9 +36,10 @@ class PipelineController(object):
         parents = attrib(type=list, default=[])
         timeout = attrib(type=float, default=None)
         parameters = attrib(type=dict, default={})
+        task_overrides = attrib(type=dict, default={})
         executed = attrib(type=str, default=None)
         clone_task = attrib(type=bool, default=True)
-        job = attrib(type=TrainsJob, default=None)
+        job = attrib(type=ClearmlJob, default=None)
         skip_job = attrib(type=bool, default=False)

     def __init__(
@@ -47,6 +50,9 @@
             auto_connect_task=True,  # type: Union[bool, Task]
             always_create_task=False,  # type: bool
             add_pipeline_tags=False,  # type: bool
+            target_project=None,  # type: Optional[str]
+            pipeline_name=None,  # type: Optional[str]
+            pipeline_project=None,  # type: Optional[str]
     ):
         # type: (...) -> ()
         """
@@ -70,6 +76,11 @@
             - ``False`` - Use the :py:meth:`task.Task.current_task` (if exists) to report statistics.
         :param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
             steps (Tasks) created by this pipeline.
+        :param str target_project: If provided, all pipeline steps are cloned into the target project
+        :param pipeline_name: Optional, provide the pipeline name if the main Task is not present
+            (default: the current date)
+        :param pipeline_project: Optional, provide the project storing the pipeline if the main Task is not present
         """
         self._nodes = {}
         self._running_nodes = []
@@ -81,14 +92,17 @@
         self._stop_event = None
         self._experiment_created_cb = None
         self._experiment_completed_cb = None
+        self._pre_step_callbacks = {}
+        self._post_step_callbacks = {}
+        self._target_project = target_project or ''
         self._add_pipeline_tags = add_pipeline_tags
         self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
         self._step_ref_pattern = re.compile(self._step_pattern)
         self._reporting_lock = RLock()
         if not self._task and always_create_task:
             self._task = Task.init(
-                project_name='Pipelines',
-                task_name='Pipeline {}'.format(datetime.now()),
+                project_name=pipeline_project or 'Pipelines',
+                task_name=pipeline_name or 'Pipeline {}'.format(datetime.now()),
                 task_type=Task.TaskTypes.controller,
             )
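Taken together, the new constructor arguments let the controller name its auto-created Task and collect all step clones under one project. A minimal usage sketch, with illustrative names and values that are not part of this commit:

    from clearml.automation import PipelineController

    pipe = PipelineController(
        always_create_task=True,           # create a controller Task if none exists
        pipeline_name='nightly pipeline',  # replaces the date-based default name
        pipeline_project='my project',     # project holding the controller Task itself
        target_project='pipeline steps',   # all cloned step Tasks are created here
    )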
@@ -103,11 +117,15 @@
             base_task_id=None,  # type: Optional[str]
             parents=None,  # type: Optional[Sequence[str]]
             parameter_override=None,  # type: Optional[Mapping[str, Any]]
+            task_overrides=None,  # type: Optional[Mapping[str, Any]]
             execution_queue=None,  # type: Optional[str]
             time_limit=None,  # type: Optional[float]
             base_task_project=None,  # type: Optional[str]
             base_task_name=None,  # type: Optional[str]
             clone_base_task=True,  # type: bool
+            pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
+            post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
     ):
         # type: (...) -> bool
         """
@@ -131,6 +149,15 @@
                 parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' }
             Task ID
                 parameter_override={'Args/input_file': '${stage3.id}' }
+        :param dict task_overrides: Optional task section overriding dictionary.
+            The dict values can reference a previously executed step using the following form '${step_name}'
+            Examples:
+                clear git repository commit ID
+                    task_overrides={'script.version_num': '' }
+                git repository commit branch
+                    task_overrides={'script.branch': '${stage1.script.branch}' }
+                container image
+                    task_overrides={'container.image': '${stage1.container.image}' }
         :param str execution_queue: Optional, the queue to use for executing this specific step.
             If not provided, the task will be sent to the default execution queue, as defined on the class
         :param float time_limit: Default None, no time limit.
@@ -141,10 +168,49 @@
             use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
         :param bool clone_base_task: If True (default) the pipeline will clone the base task, and modify/enqueue
             the cloned Task. If False, the base-task is used directly, notice it has to be in draft-mode (created).
+        :param Callable pre_execute_callback: Callback function, called when the step (Task) is created
+            and before it is sent for execution. Allows a user to modify the Task before launch.
+            Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
+            `parameters` are the configuration arguments passed to the ClearmlJob.
+            If the callback returned value is `False`,
+            the Node is skipped and so is any node in the DAG that relies on this node.
+            Notice the `parameters` are already parsed,
+            e.g. `${step1.parameters.Args/param}` is replaced with its relevant value.
+
+            .. code-block:: py
+
+                def step_created_callback(
+                    pipeline,  # type: PipelineController,
+                    node,  # type: PipelineController.Node,
+                    parameters,  # type: dict
+                ):
+                    pass
+
+        :param Callable post_execute_callback: Callback function, called when a step (Task) is completed
+            and before the next steps are executed. Allows a user to modify the Task status after completion.
+
+            .. code-block:: py
+
+                def step_completed_callback(
+                    pipeline,  # type: PipelineController,
+                    node,  # type: PipelineController.Node,
+                ):
+                    pass
+
         :return: True if successful
         """
+        # always store callback functions (even when running remotely)
+        if pre_execute_callback:
+            self._pre_step_callbacks[name] = pre_execute_callback
+        if post_execute_callback:
+            self._post_step_callbacks[name] = post_execute_callback
+
         # when running remotely do nothing, we will deserialize ourselves when we start
-        if self._has_stored_configuration():
+        # if we are not cloning a Task, we assume this step is created from code, not from the configuration
+        if clone_base_task and self._has_stored_configuration():
             return True

         if name in self._nodes:
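The two new callbacks are registered per step name. A hedged sketch of `add_step` with the new arguments, where step names, projects, and parameter keys are illustrative only:

    def before_launch(pipeline, node, parameters):
        # returning False skips this node and every node that depends on it
        return parameters.get('Args/skip') != 'true'

    def after_completion(pipeline, node):
        print('step {} done, task id: {}'.format(
            node.name, node.job.task_id() if node.job else node.executed))

    # pipe is the PipelineController instance from above
    pipe.add_step(
        name='stage2', parents=['stage1'],
        base_task_project='examples', base_task_name='step 2',
        parameter_override={'Args/input_file': '${stage1.parameters.Args/input_file}'},
        task_overrides={'script.branch': '${stage1.script.branch}'},
        pre_execute_callback=before_launch,
        post_execute_callback=after_completion,
    )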
@@ -168,13 +234,16 @@
                 queue=execution_queue, timeout=time_limit,
                 parameters=parameter_override or {},
                 clone_task=clone_base_task,
+                task_overrides=task_overrides,
             )
+        if self._task and not self._task.running_locally():
+            self.update_execution_plot()

         return True

     def start(
             self,
-            run_remotely=False,  # type: Union[bool, str]
             step_task_created_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
             step_task_completed_callback=None  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
     ):
@@ -183,13 +252,10 @@
         Start the pipeline controller.
         If the calling process is stopped, then the controller stops as well.

-        :param bool run_remotely: (default False), If True stop the current process and continue execution
-            on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'.
-            If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution.
         :param Callable step_task_created_callback: Callback function, called when a step (Task) is created
             and before it is sent for execution. Allows a user to modify the Task before launch.
-            Use `node.job` to access the TrainsJob object, or `node.job.task` to directly access the Task object.
-            `parameters` are the configuration arguments passed to the TrainsJob.
+            Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
+            `parameters` are the configuration arguments passed to the ClearmlJob.
             If the callback returned value is `False`,
             the Node is skipped and so is any node in the DAG that relies on this node.
@@ -224,14 +290,8 @@
         if self._thread:
             return True

-        # serialize pipeline state
-        pipeline_dag = self._serialize()
-        self._task.connect_configuration(pipeline_dag, name=self._config_section)
-        params = {'continue_pipeline': False,
-                  'default_queue': self._default_execution_queue,
-                  'add_pipeline_tags': self._add_pipeline_tags,
-                  }
-        self._task.connect(params, name=self._config_section)
+        params, pipeline_dag = self._serialize_pipeline_task()

         # deserialize back pipeline state
         if not params['continue_pipeline']:
             for k in pipeline_dag:
@@ -239,6 +299,7 @@
         self._default_execution_queue = params['default_queue']
         self._add_pipeline_tags = params['add_pipeline_tags']
+        self._target_project = params['target_project'] or ''
         self._deserialize(pipeline_dag)

         # if we continue the pipeline, make sure that we re-execute failed tasks
@@ -252,11 +313,6 @@
                 "it has either inaccessible nodes, or contains cycles")

         self.update_execution_plot()
-        print('update_execution_plot!!!!')
-
-        if run_remotely:
-            self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
-            # we will not get here if we are not running remotely

         self._start_time = time()
         self._stop_event = Event()
@@ -267,6 +323,38 @@
         self._thread.start()
         return True

+    def start_remotely(self, queue='services', exit_process=True):
+        # type: (str, bool) -> Task
+        """
+        Start the current pipeline remotely (on the selected services queue)
+        The current process will be stopped if exit_process is True.
+
+        :param queue: queue name to launch the pipeline on
+        :param exit_process: If True, exit the current process after enqueuing on the execution queue
+
+        :return: The remote Task object
+        """
+        if not self._task:
+            raise ValueError(
+                "Could not find main Task, "
+                "PipelineController must be created with `always_create_task=True`")
+
+        # serialize state only if we are running locally
+        if Task.running_locally() or not self._task.is_main_task():
+            self._serialize_pipeline_task()
+            self.update_execution_plot()
+
+        # stop current Task and execute remotely or no-op
+        self._task.execute_remotely(queue_name=queue, exit_process=exit_process, clone=False)
+
+        if not Task.running_locally() and self._task.is_main_task():
+            self.start()
+            self.wait()
+            self.stop()
+            exit(0)
+        else:
+            return self._task
+
     def stop(self, timeout=None):
         # type: (Optional[float]) -> ()
         """
@@ -337,7 +425,7 @@
         {
             'stage1' : Node() {
                 name: 'stage1'
-                job: TrainsJob
+                job: ClearmlJob
                 ...
             },
         }
@@ -363,6 +451,27 @@
         """
         return {k: n for k, n in self._nodes.items() if k in self._running_nodes}

+    def _serialize_pipeline_task(self):
+        # type: () -> (dict, dict)
+        """
+        Serialize current pipeline state into the main Task
+
+        :return: params, pipeline_dag
+        """
+        params = {'continue_pipeline': False,
+                  'default_queue': self._default_execution_queue,
+                  'add_pipeline_tags': self._add_pipeline_tags,
+                  'target_project': self._target_project,
+                  }
+        pipeline_dag = self._serialize()
+
+        # serialize pipeline state
+        if self._task:
+            self._task.connect_configuration(pipeline_dag, name=self._config_section)
+            self._task.connect(params, name=self._config_section)
+
+        return params, pipeline_dag
+
     def _serialize(self):
         # type: () -> dict
         """
@@ -382,7 +491,9 @@
         This will be used to create the DAG from the dict stored on the Task, when running remotely.
         :return:
         """
-        self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()}
+        self._nodes = {
+            k: self.Node(name=k, **v) if not v.get('clone_task') or k not in self._nodes else self._nodes[k]
+            for k, v in dag_dict.items()}

     def _has_stored_configuration(self):
         """
@@ -461,7 +572,7 @@
     def _launch_node(self, node):
         # type: (PipelineController.Node) -> ()
         """
-        Launch a single node (create and enqueue a TrainsJob)
+        Launch a single node (create and enqueue a ClearmlJob)

         :param node: Node to launch
         :return: Return True if a new job was launched
@@ -473,13 +584,31 @@
         for k, v in node.parameters.items():
             updated_hyper_parameters[k] = self._parse_step_ref(v)

-        node.job = TrainsJob(
+        task_overrides = self._parse_task_overrides(node.task_overrides) if node.task_overrides else None
+
+        extra_args = dict()
+        if self._target_project:
+            extra_args['project'] = get_or_create_project(
+                session=self._task.session if self._task else Task.default_session,
+                project_name=self._target_project)
+
+        skip_node = None
+        if self._pre_step_callbacks.get(node.name):
+            skip_node = self._pre_step_callbacks[node.name](self, node, updated_hyper_parameters)
+
+        if skip_node is False:
+            node.skip_job = True
+            return True
+
+        node.job = ClearmlJob(
             base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
             tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None,
             parent=self._task.id if self._task else None,
             disable_clone_task=not node.clone_task,
+            task_overrides=task_overrides,
+            **extra_args
         )

-        skip_node = None
         if self._experiment_created_cb:
             skip_node = self._experiment_created_cb(self, node, updated_hyper_parameters)
@@ -508,6 +637,9 @@
         """
         Update sankey diagram of the current pipeline
         """
+        if not self._task:
+            return
+
         sankey_node = dict(
             label=[],
             color=[],
@@ -618,6 +750,7 @@
     def _force_task_configuration_update(self):
         pipeline_dag = self._serialize()
+        if self._task:
             # noinspection PyProtectedMember
             self._task._set_configuration(
                 name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)
@@ -666,9 +799,15 @@
                 continue

             # callback on completed jobs
-            if self._experiment_completed_cb:
+            if self._experiment_completed_cb or self._post_step_callbacks:
                 for job in completed_jobs:
-                    self._experiment_completed_cb(self, job)
+                    job_node = self._nodes.get(job)
+                    if not job_node:
+                        continue
+                    if self._experiment_completed_cb:
+                        self._experiment_completed_cb(self, job_node)
+                    if self._post_step_callbacks.get(job_node.name):
+                        self._post_step_callbacks[job_node.name](self, job_node)

             # Pull the next jobs in the pipeline, based on the completed list
             next_nodes = []
@@ -795,16 +934,24 @@
             raise ValueError("Could not parse reference '{}'".format(step_ref_string))
         prev_step = parts[0]
         input_type = parts[1].lower()
-        if prev_step not in self._nodes or not self._nodes[prev_step].job:
-            raise ValueError("Could not parse reference '{}', step {} could not be found".format(
+        if prev_step not in self._nodes or (
+                not self._nodes[prev_step].job and
+                not self._nodes[prev_step].executed and
+                not self._nodes[prev_step].base_task_id
+        ):
+            raise ValueError("Could not parse reference '{}', step '{}' could not be found".format(
                 step_ref_string, prev_step))
-        if input_type not in ('artifacts', 'parameters', 'models', 'id'):
-            raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type))
+
+        if input_type not in (
+                'artifacts', 'parameters', 'models', 'id',
+                'script', 'execution', 'container', 'output',
+                'comment', 'models', 'tags', 'system_tags', 'project'):
+            raise ValueError("Could not parse reference '{}', type '{}' not valid".format(step_ref_string, input_type))
         if input_type != 'id' and len(parts) < 3:
-            raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts))
+            raise ValueError("Could not parse reference '{}', missing fields in '{}'".format(step_ref_string, parts))

         task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \
-            else Task.get_task(task_id=self._nodes[prev_step].executed)
+            else Task.get_task(task_id=self._nodes[prev_step].executed or self._nodes[prev_step].base_task_id)
         task.reload()
         if input_type == 'artifacts':
             # fix \. to use . in artifacts
@@ -842,9 +989,14 @@
                 '.'.join(parts[1:]), prev_step, parts[3]))
             return str(getattr(model, parts[4]))
         elif input_type == 'id':
             return task.id
+        elif input_type in (
+                'script', 'execution', 'container', 'output',
+                'comment', 'models', 'tags', 'system_tags', 'project'):
+            # noinspection PyProtectedMember
+            return task._get_task_property('.'.join(parts[1:]))
         return None

     def _parse_step_ref(self, value):
@@ -864,6 +1016,19 @@
             updated_value = updated_value.replace(g, new_val, 1)
         return updated_value

+    def _parse_task_overrides(self, task_overrides):
+        # type: (dict) -> dict
+        """
+        Parse the dict values for step references, e.g. "${step1.script.branch}"
+
+        :param dict task_overrides: the task fields to override
+        :return: the parsed overrides dict
+        """
+        updated_overrides = {}
+        for k, v in task_overrides.items():
+            updated_overrides[k] = self._parse_step_ref(v)
+        return updated_overrides
+
     @classmethod
     def __get_node_status(cls, a_node):
         # type: (PipelineController.Node) -> str
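For reference, the step-reference grammar that `_parse_step_ref` (and, through it, `_parse_task_overrides`) accepts now covers task properties alongside the original artifact/parameter/model/id forms. A hedged sketch, where the step name `stage1` and the artifact name `data` are illustrative:

    task_overrides = {
        'script.branch': '${stage1.script.branch}',      # task property reference
        'container.image': '${stage1.container.image}',  # task property reference
        'script.version_num': '',                        # plain value, clears the pinned commit
    }
    parameter_override = {
        'Args/input_file': '${stage1.artifacts.data.url}',  # artifact reference
        'Args/parent_id': '${stage1.id}',                   # task ID reference
    }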


@@ -9,7 +9,7 @@ from threading import Thread, Event
 from time import time
 from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable
-from .job import TrainsJob
+from .job import ClearmlJob
 from .parameters import Parameter
 from ..backend_interface.util import get_or_create_project
 from ..logger import Logger
@@ -58,7 +58,7 @@
         self.extremum = extremum

     def get_objective(self, task_id):
-        # type: (Union[str, Task, TrainsJob]) -> Optional[float]
+        # type: (Union[str, Task, ClearmlJob]) -> Optional[float]
         """
         Return a specific task scalar value based on the objective settings (title/series).
@@ -71,7 +71,7 @@
         if isinstance(task_id, Task):
             task_id = task_id.id
-        elif isinstance(task_id, TrainsJob):
+        elif isinstance(task_id, ClearmlJob):
             task_id = task_id.task_id()

         # noinspection PyBroadException
@@ -97,7 +97,7 @@
         return None

     def get_current_raw_objective(self, task):
-        # type: (Union[TrainsJob, Task]) -> (int, float)
+        # type: (Union[ClearmlJob, Task]) -> (int, float)
         """
         Return the current raw value (without sign normalization) of the objective.
@@ -108,7 +108,7 @@
         """
         if isinstance(task, Task):
             task_id = task.id
-        elif isinstance(task, TrainsJob):
+        elif isinstance(task, ClearmlJob):
             task_id = task.task_id()
         else:
             task_id = task
@@ -162,7 +162,7 @@
         return self.title, self.series

     def get_normalized_objective(self, task_id):
-        # type: (Union[str, Task, TrainsJob]) -> Optional[float]
+        # type: (Union[str, Task, ClearmlJob]) -> Optional[float]
         """
         Return a normalized task scalar value based on the objective settings (title/series).
         I.e. objective is always to maximize the returned value
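For context, the renamed type is accepted anywhere a task reference is expected. A hedged sketch of querying an objective; the `Objective(title=..., series=...)` construction is assumed from the title/series attributes used elsewhere in this file, not shown in this diff:

    obj = Objective(title='validation', series='accuracy')
    value = obj.get_objective('<task-id string>')  # str, Task, or ClearmlJob all work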
@@ -269,7 +269,7 @@
     The base search strategy class. Inherit this class to implement your custom strategy.
     """
     _tag = 'optimization'
-    _job_class = TrainsJob  # type: TrainsJob
+    _job_class = ClearmlJob  # type: ClearmlJob

     def __init__(
             self,
return bool(self._current_jobs) return bool(self._current_jobs)
def create_job(self): def create_job(self):
# type: () -> Optional[TrainsJob] # type: () -> Optional[ClearmlJob]
""" """
Abstract helper function. Implementation is not required. Default use in process_step default implementation Abstract helper function. Implementation is not required. Default use in process_step default implementation
Create a new job if needed. return the newly created job. If no job needs to be created, return ``None``. Create a new job if needed. return the newly created job. If no job needs to be created, return ``None``.
@@ -424,7 +424,7 @@
         return None

     def monitor_job(self, job):
-        # type: (TrainsJob) -> bool
+        # type: (ClearmlJob) -> bool
         """
         Helper function, Implementation is not required. Default use in process_step default implementation.
         Check if the job needs to be aborted or already completed.
@@ -434,7 +434,7 @@
         If there is a budget limitation, this call should update
         ``self.budget.compute_time.update`` / ``self.budget.iterations.update``

-        :param TrainsJob job: A ``TrainsJob`` object to monitor.
+        :param ClearmlJob job: A ``ClearmlJob`` object to monitor.

         :return: False, if the job is no longer relevant.
         """
@@ -472,7 +472,7 @@
         return abort_job

     def get_running_jobs(self):
-        # type: () -> Sequence[TrainsJob]
+        # type: () -> Sequence[ClearmlJob]
         """
         Return the current running TrainsJobs.
@@ -534,7 +534,7 @@
             parent=None,  # type: Optional[str]
             **kwargs  # type: Any
     ):
-        # type: (...) -> TrainsJob
+        # type: (...) -> ClearmlJob
         """
         Create a Job using the specified arguments, ``TrainsJob`` for details.
@@ -564,11 +564,11 @@
         return new_job

     def set_job_class(self, job_class):
-        # type: (TrainsJob) -> ()
+        # type: (ClearmlJob) -> ()
         """
         Set the class to use for the :meth:`helper_create_job` function.

-        :param TrainsJob job_class: The Job Class type.
+        :param ClearmlJob job_class: The Job Class type.
         """
         self._job_class = job_class
@@ -643,7 +643,7 @@
         return self._job_project.get(parent_task_id)

     def _get_job_iterations(self, job):
-        # type: (Union[TrainsJob, Task]) -> int
+        # type: (Union[ClearmlJob, Task]) -> int
         iteration_value = self._objective_metric.get_current_raw_objective(job)
         return iteration_value[0] if iteration_value else -1
@@ -788,7 +788,7 @@ class GridSearch(SearchStrategy):
         self._param_iterator = None

     def create_job(self):
-        # type: () -> Optional[TrainsJob]
+        # type: () -> Optional[ClearmlJob]
         """
         Create a new job if needed. Return the newly created job. If no job needs to be created, return ``None``.
@@ -863,7 +863,7 @@ class RandomSearch(SearchStrategy):
         self._hyper_parameters_collection = set()

     def create_job(self):
-        # type: () -> Optional[TrainsJob]
+        # type: () -> Optional[ClearmlJob]
         """
         Create a new job if needed. Return the newly created job. If no job needs to be created, return ``None``.
@@ -1299,11 +1299,11 @@ class HyperParameterOptimizer(object):
         return self.optimizer

     def set_default_job_class(self, job_class):
-        # type: (TrainsJob) -> ()
+        # type: (ClearmlJob) -> ()
         """
         Set the Job class to use when the optimizer spawns new Jobs.

-        :param TrainsJob job_class: The Job Class type.
+        :param ClearmlJob job_class: The Job Class type.
         """
         self.optimizer.set_job_class(job_class)
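A hedged sketch of how a custom job class plugs in after the rename; the subclass is illustrative only:

    from clearml.automation import ClearmlJob

    class MyJob(ClearmlJob):
        # illustrative: add custom spawning/monitoring behavior here
        pass

    an_optimizer.set_default_job_class(MyJob)  # on a HyperParameterOptimizer
    # or directly on a strategy instance:
    a_strategy.set_job_class(MyJob)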
@@ -1709,7 +1709,7 @@
         return completed_value, obj_values

     def _get_last_value(self, response):
-        metrics, title, series, values = TrainsJob.get_metric_req_params(self.objective_metric.title,
-                                                                         self.objective_metric.series)
+        metrics, title, series, values = ClearmlJob.get_metric_req_params(self.objective_metric.title,
+                                                                          self.objective_metric.series)
         last_values = response.response_data["task"]['last_metrics'][title][series]
         return last_values