mirror of
https://github.com/clearml/clearml
synced 2025-02-01 09:36:49 +00:00
Improve pipeline controller, add callbacks
This commit is contained in:
parent
cf28551d21
commit
0013c5851e
@ -1,7 +1,7 @@
|
||||
from .parameters import UniformParameterRange, DiscreteParameterRange, UniformIntegerParameterRange, ParameterSet
|
||||
from .optimization import GridSearch, RandomSearch, HyperParameterOptimizer, Objective
|
||||
from .job import TrainsJob
|
||||
from .job import ClearmlJob
|
||||
from .controller import PipelineController
|
||||
|
||||
__all__ = ["UniformParameterRange", "DiscreteParameterRange", "UniformIntegerParameterRange", "ParameterSet",
|
||||
"GridSearch", "RandomSearch", "HyperParameterOptimizer", "Objective", "TrainsJob", "PipelineController"]
|
||||
"GridSearch", "RandomSearch", "HyperParameterOptimizer", "Objective", "ClearmlJob", "PipelineController"]
|
||||
|
@ -8,9 +8,11 @@ from time import time
|
||||
from attr import attrib, attrs
|
||||
from typing import Sequence, Optional, Mapping, Callable, Any, Union
|
||||
|
||||
from ..backend_interface.util import get_or_create_project
|
||||
from ..config import get_remote_task_id
|
||||
from ..debugging.log import LoggerRoot
|
||||
from ..task import Task
|
||||
from ..automation import TrainsJob
|
||||
from ..automation import ClearmlJob
|
||||
from ..model import BaseModel
|
||||
|
||||
|
||||
@ -34,9 +36,10 @@ class PipelineController(object):
|
||||
parents = attrib(type=list, default=[])
|
||||
timeout = attrib(type=float, default=None)
|
||||
parameters = attrib(type=dict, default={})
|
||||
task_overrides = attrib(type=dict, default={})
|
||||
executed = attrib(type=str, default=None)
|
||||
clone_task = attrib(type=bool, default=True)
|
||||
job = attrib(type=TrainsJob, default=None)
|
||||
job = attrib(type=ClearmlJob, default=None)
|
||||
skip_job = attrib(type=bool, default=False)
|
||||
|
||||
def __init__(
|
||||
@ -47,6 +50,9 @@ class PipelineController(object):
|
||||
auto_connect_task=True, # type: Union[bool, Task]
|
||||
always_create_task=False, # type: bool
|
||||
add_pipeline_tags=False, # type: bool
|
||||
target_project=None, # type: Optional[str]
|
||||
pipeline_name=None, # type: Optional[str]
|
||||
pipeline_project=None, # type: Optional[str]
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@ -70,6 +76,11 @@ class PipelineController(object):
|
||||
- ``False`` - Use the :py:meth:`task.Task.current_task` (if exists) to report statistics.
|
||||
:param bool add_pipeline_tags: (default: False) if True, add `pipe: <pipeline_task_id>` tag to all
|
||||
steps (Tasks) created by this pipeline.
|
||||
|
||||
:param str target_project: If provided, all pipeline steps are cloned into the target project
|
||||
|
||||
:param pipeline_name: Optional, provide pipeline name if main Task is not present (default current date)
|
||||
:param pipeline_project: Optional, provide project storing the pipeline if main Task is not present
|
||||
"""
|
||||
self._nodes = {}
|
||||
self._running_nodes = []
|
||||
@ -81,14 +92,17 @@ class PipelineController(object):
|
||||
self._stop_event = None
|
||||
self._experiment_created_cb = None
|
||||
self._experiment_completed_cb = None
|
||||
self._pre_step_callbacks = {}
|
||||
self._post_step_callbacks = {}
|
||||
self._target_project = target_project or ''
|
||||
self._add_pipeline_tags = add_pipeline_tags
|
||||
self._task = auto_connect_task if isinstance(auto_connect_task, Task) else Task.current_task()
|
||||
self._step_ref_pattern = re.compile(self._step_pattern)
|
||||
self._reporting_lock = RLock()
|
||||
if not self._task and always_create_task:
|
||||
self._task = Task.init(
|
||||
project_name='Pipelines',
|
||||
task_name='Pipeline {}'.format(datetime.now()),
|
||||
project_name=pipeline_project or 'Pipelines',
|
||||
task_name=pipeline_name or 'Pipeline {}'.format(datetime.now()),
|
||||
task_type=Task.TaskTypes.controller,
|
||||
)
|
||||
|
||||
@ -103,11 +117,15 @@ class PipelineController(object):
|
||||
base_task_id=None, # type: Optional[str]
|
||||
parents=None, # type: Optional[Sequence[str]]
|
||||
parameter_override=None, # type: Optional[Mapping[str, Any]]
|
||||
task_overrides=None, # type: Optional[Mapping[str, Any]]
|
||||
execution_queue=None, # type: Optional[str]
|
||||
time_limit=None, # type: Optional[float]
|
||||
base_task_project=None, # type: Optional[str]
|
||||
base_task_name=None, # type: Optional[str]
|
||||
clone_base_task=True, # type: bool
|
||||
pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
|
||||
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
|
||||
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -131,6 +149,15 @@ class PipelineController(object):
|
||||
parameter_override={'Args/input_file': '${stage3.parameters.Args/input_file}' }
|
||||
Task ID
|
||||
parameter_override={'Args/input_file': '${stage3.id}' }
|
||||
:param dict task_overrides: Optional task section overriding dictionary.
|
||||
The dict values can reference a previously executed step using the following form '${step_name}'
|
||||
Examples:
|
||||
clear git repository commit ID
|
||||
parameter_override={'script.version_num': '' }
|
||||
git repository commit branch
|
||||
parameter_override={'script.branch': '${stage1.script.branch}' }
|
||||
container image
|
||||
parameter_override={'container.image': '${stage1.container.image}' }
|
||||
:param str execution_queue: Optional, the queue to use for executing this specific step.
|
||||
If not provided, the task will be sent to the default execution queue, as defined on the class
|
||||
:param float time_limit: Default None, no time limit.
|
||||
@ -141,10 +168,49 @@ class PipelineController(object):
|
||||
use the base_task_project and base_task_name combination to retrieve the base_task_id to use for the step.
|
||||
:param bool clone_base_task: If True (default) the pipeline will clone the base task, and modify/enqueue
|
||||
the cloned Task. If False, the base-task is used directly, notice it has to be in draft-mode (created).
|
||||
:param Callable pre_execute_callback: Callback function, called when the step (Task) is created
|
||||
and before it is sent for execution. Allows a user to modify the Task before launch.
|
||||
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
|
||||
`parameters` are the configuration arguments passed to the ClearmlJob.
|
||||
|
||||
If the callback returned value is `False`,
|
||||
the Node is skipped and so is any node in the DAG that relies on this node.
|
||||
|
||||
Notice the `parameters` are already parsed,
|
||||
e.g. `${step1.parameters.Args/param}` is replaced with relevant value.
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def step_created_callback(
|
||||
pipeline, # type: PipelineController,
|
||||
node, # type: PipelineController.Node,
|
||||
parameters, # type: dict
|
||||
):
|
||||
pass
|
||||
|
||||
:param Callable post_execute_callback: Callback function, called when a step (Task) is completed
|
||||
and it other jobs are executed. Allows a user to modify the Task status after completion.
|
||||
|
||||
.. code-block:: py
|
||||
|
||||
def step_completed_callback(
|
||||
pipeline, # type: PipelineController,
|
||||
node, # type: PipelineController.Node,
|
||||
):
|
||||
pass
|
||||
|
||||
:return: True if successful
|
||||
"""
|
||||
|
||||
# always store callback functions (even when running remotely)
|
||||
if pre_execute_callback:
|
||||
self._pre_step_callbacks[name] = pre_execute_callback
|
||||
if post_execute_callback:
|
||||
self._post_step_callbacks[name] = post_execute_callback
|
||||
|
||||
# when running remotely do nothing, we will deserialize ourselves when we start
|
||||
if self._has_stored_configuration():
|
||||
# if we are not cloning a Task, we assume this step is created from code, not from the configuration
|
||||
if clone_base_task and self._has_stored_configuration():
|
||||
return True
|
||||
|
||||
if name in self._nodes:
|
||||
@ -168,13 +234,16 @@ class PipelineController(object):
|
||||
queue=execution_queue, timeout=time_limit,
|
||||
parameters=parameter_override or {},
|
||||
clone_task=clone_base_task,
|
||||
task_overrides=task_overrides,
|
||||
)
|
||||
|
||||
if self._task and not self._task.running_locally():
|
||||
self.update_execution_plot()
|
||||
|
||||
return True
|
||||
|
||||
def start(
|
||||
self,
|
||||
run_remotely=False, # type: Union[bool, str]
|
||||
step_task_created_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
|
||||
step_task_completed_callback=None # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
|
||||
):
|
||||
@ -183,13 +252,10 @@ class PipelineController(object):
|
||||
Start the pipeline controller.
|
||||
If the calling process is stopped, then the controller stops as well.
|
||||
|
||||
:param bool run_remotely: (default False), If True stop the current process and continue execution
|
||||
on a remote machine. This is done by calling the Task.execute_remotely with the queue name 'services'.
|
||||
If `run_remotely` is a string, it will specify the execution queue for the pipeline remote execution.
|
||||
:param Callable step_task_created_callback: Callback function, called when a step (Task) is created
|
||||
and before it is sent for execution. Allows a user to modify the Task before launch.
|
||||
Use `node.job` to access the TrainsJob object, or `node.job.task` to directly access the Task object.
|
||||
`parameters` are the configuration arguments passed to the TrainsJob.
|
||||
Use `node.job` to access the ClearmlJob object, or `node.job.task` to directly access the Task object.
|
||||
`parameters` are the configuration arguments passed to the ClearmlJob.
|
||||
|
||||
If the callback returned value is `False`,
|
||||
the Node is skipped and so is any node in the DAG that relies on this node.
|
||||
@ -224,14 +290,8 @@ class PipelineController(object):
|
||||
if self._thread:
|
||||
return True
|
||||
|
||||
# serialize pipeline state
|
||||
pipeline_dag = self._serialize()
|
||||
self._task.connect_configuration(pipeline_dag, name=self._config_section)
|
||||
params = {'continue_pipeline': False,
|
||||
'default_queue': self._default_execution_queue,
|
||||
'add_pipeline_tags': self._add_pipeline_tags,
|
||||
}
|
||||
self._task.connect(params, name=self._config_section)
|
||||
params, pipeline_dag = self._serialize_pipeline_task()
|
||||
|
||||
# deserialize back pipeline state
|
||||
if not params['continue_pipeline']:
|
||||
for k in pipeline_dag:
|
||||
@ -239,6 +299,7 @@ class PipelineController(object):
|
||||
|
||||
self._default_execution_queue = params['default_queue']
|
||||
self._add_pipeline_tags = params['add_pipeline_tags']
|
||||
self._target_project = params['target_project'] or ''
|
||||
self._deserialize(pipeline_dag)
|
||||
|
||||
# if we continue the pipeline, make sure that we re-execute failed tasks
|
||||
@ -252,11 +313,6 @@ class PipelineController(object):
|
||||
"it has either inaccessible nodes, or contains cycles")
|
||||
|
||||
self.update_execution_plot()
|
||||
print('update_execution_plot!!!!')
|
||||
|
||||
if run_remotely:
|
||||
self._task.execute_remotely(queue_name='services' if not isinstance(run_remotely, str) else run_remotely)
|
||||
# we will not get here if we are not running remotely
|
||||
|
||||
self._start_time = time()
|
||||
self._stop_event = Event()
|
||||
@ -267,6 +323,38 @@ class PipelineController(object):
|
||||
self._thread.start()
|
||||
return True
|
||||
|
||||
def start_remotely(self, queue='services', exit_process=True):
|
||||
# type: (str, bool) -> Task
|
||||
"""
|
||||
Start the current pipeline remotely (on the selected services queue)
|
||||
The current process will be stopped if exit_process is True.
|
||||
|
||||
:param queue: queue name to launch the pipeline on
|
||||
:param exit_process: If True exit the current process after launching on the enqueuing on the queue
|
||||
|
||||
:return: The remote Task object
|
||||
"""
|
||||
if not self._task:
|
||||
raise ValueError(
|
||||
"Could not find main Task, "
|
||||
"PipelineController must be created with `always_create_task=True`")
|
||||
|
||||
# serialize state only if we are running locally
|
||||
if Task.running_locally() or not self._task.is_main_task():
|
||||
self._serialize_pipeline_task()
|
||||
self.update_execution_plot()
|
||||
|
||||
# stop current Task and execute remotely or no-op
|
||||
self._task.execute_remotely(queue_name=queue, exit_process=exit_process, clone=False)
|
||||
|
||||
if not Task.running_locally() and self._task.is_main_task():
|
||||
self.start()
|
||||
self.wait()
|
||||
self.stop()
|
||||
exit(0)
|
||||
else:
|
||||
return self._task
|
||||
|
||||
def stop(self, timeout=None):
|
||||
# type: (Optional[float]) -> ()
|
||||
"""
|
||||
@ -337,7 +425,7 @@ class PipelineController(object):
|
||||
{
|
||||
'stage1' : Node() {
|
||||
name: 'stage1'
|
||||
job: TrainsJob
|
||||
job: ClearmlJob
|
||||
...
|
||||
},
|
||||
}
|
||||
@ -363,6 +451,27 @@ class PipelineController(object):
|
||||
"""
|
||||
return {k: n for k, n in self._nodes.items() if k in self._running_nodes}
|
||||
|
||||
def _serialize_pipeline_task(self):
|
||||
# type: () -> (dict, dict)
|
||||
"""
|
||||
Serialize current pipeline state into the main Task
|
||||
|
||||
:return: params, pipeline_dag
|
||||
"""
|
||||
params = {'continue_pipeline': False,
|
||||
'default_queue': self._default_execution_queue,
|
||||
'add_pipeline_tags': self._add_pipeline_tags,
|
||||
'target_project': self._target_project,
|
||||
}
|
||||
pipeline_dag = self._serialize()
|
||||
|
||||
# serialize pipeline state
|
||||
if self._task:
|
||||
self._task.connect_configuration(pipeline_dag, name=self._config_section)
|
||||
self._task.connect(params, name=self._config_section)
|
||||
|
||||
return params, pipeline_dag
|
||||
|
||||
def _serialize(self):
|
||||
# type: () -> dict
|
||||
"""
|
||||
@ -382,7 +491,9 @@ class PipelineController(object):
|
||||
This will be used to create the DAG from the dict stored on the Task, when running remotely.
|
||||
:return:
|
||||
"""
|
||||
self._nodes = {k: self.Node(name=k, **v) for k, v in dag_dict.items()}
|
||||
self._nodes = {
|
||||
k: self.Node(name=k, **v) if not v.get('clone_task') or k not in self._nodes else self._nodes[k]
|
||||
for k, v in dag_dict.items()}
|
||||
|
||||
def _has_stored_configuration(self):
|
||||
"""
|
||||
@ -461,7 +572,7 @@ class PipelineController(object):
|
||||
def _launch_node(self, node):
|
||||
# type: (PipelineController.Node) -> ()
|
||||
"""
|
||||
Launch a single node (create and enqueue a TrainsJob)
|
||||
Launch a single node (create and enqueue a ClearmlJob)
|
||||
|
||||
:param node: Node to launch
|
||||
:return: Return True if a new job was launched
|
||||
@ -473,13 +584,31 @@ class PipelineController(object):
|
||||
for k, v in node.parameters.items():
|
||||
updated_hyper_parameters[k] = self._parse_step_ref(v)
|
||||
|
||||
node.job = TrainsJob(
|
||||
task_overrides = self._parse_task_overrides(node.task_overrides) if node.task_overrides else None
|
||||
|
||||
extra_args = dict()
|
||||
if self._target_project:
|
||||
extra_args['project'] = get_or_create_project(
|
||||
session=self._task.session if self._task else Task.default_session,
|
||||
project_name=self._target_project)
|
||||
|
||||
skip_node = None
|
||||
if self._pre_step_callbacks.get(node.name):
|
||||
skip_node = self._pre_step_callbacks[node.name](self, node, updated_hyper_parameters)
|
||||
|
||||
if skip_node is False:
|
||||
node.skip_job = True
|
||||
return True
|
||||
|
||||
node.job = ClearmlJob(
|
||||
base_task_id=node.base_task_id, parameter_override=updated_hyper_parameters,
|
||||
tags=['pipe: {}'.format(self._task.id)] if self._add_pipeline_tags and self._task else None,
|
||||
parent=self._task.id if self._task else None,
|
||||
disable_clone_task=not node.clone_task,
|
||||
task_overrides=task_overrides,
|
||||
**extra_args
|
||||
)
|
||||
skip_node = None
|
||||
|
||||
if self._experiment_created_cb:
|
||||
skip_node = self._experiment_created_cb(self, node, updated_hyper_parameters)
|
||||
|
||||
@ -508,6 +637,9 @@ class PipelineController(object):
|
||||
"""
|
||||
Update sankey diagram of the current pipeline
|
||||
"""
|
||||
if not self._task:
|
||||
return
|
||||
|
||||
sankey_node = dict(
|
||||
label=[],
|
||||
color=[],
|
||||
@ -618,9 +750,10 @@ class PipelineController(object):
|
||||
|
||||
def _force_task_configuration_update(self):
|
||||
pipeline_dag = self._serialize()
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_configuration(
|
||||
name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)
|
||||
if self._task:
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_configuration(
|
||||
name=self._config_section, config_type='dictionary', config_dict=pipeline_dag)
|
||||
|
||||
def _daemon(self):
|
||||
# type: () -> ()
|
||||
@ -666,9 +799,15 @@ class PipelineController(object):
|
||||
continue
|
||||
|
||||
# callback on completed jobs
|
||||
if self._experiment_completed_cb:
|
||||
if self._experiment_completed_cb or self._post_step_callbacks:
|
||||
for job in completed_jobs:
|
||||
self._experiment_completed_cb(self, job)
|
||||
job_node = self._nodes.get(job)
|
||||
if not job_node:
|
||||
continue
|
||||
if self._experiment_completed_cb:
|
||||
self._experiment_completed_cb(self, job_node)
|
||||
if self._post_step_callbacks.get(job_node.name):
|
||||
self._post_step_callbacks[job_node.name](self, job_node)
|
||||
|
||||
# Pull the next jobs in the pipeline, based on the completed list
|
||||
next_nodes = []
|
||||
@ -795,16 +934,24 @@ class PipelineController(object):
|
||||
raise ValueError("Could not parse reference '{}'".format(step_ref_string))
|
||||
prev_step = parts[0]
|
||||
input_type = parts[1].lower()
|
||||
if prev_step not in self._nodes or not self._nodes[prev_step].job:
|
||||
raise ValueError("Could not parse reference '{}', step {} could not be found".format(
|
||||
if prev_step not in self._nodes or (
|
||||
not self._nodes[prev_step].job and
|
||||
not self._nodes[prev_step].executed and
|
||||
not self._nodes[prev_step].base_task_id
|
||||
):
|
||||
raise ValueError("Could not parse reference '{}', step '{}' could not be found".format(
|
||||
step_ref_string, prev_step))
|
||||
if input_type not in ('artifacts', 'parameters', 'models', 'id'):
|
||||
raise ValueError("Could not parse reference '{}', type {} not valid".format(step_ref_string, input_type))
|
||||
|
||||
if input_type not in (
|
||||
'artifacts', 'parameters', 'models', 'id',
|
||||
'script', 'execution', 'container', 'output',
|
||||
'comment', 'models', 'tags', 'system_tags', 'project'):
|
||||
raise ValueError("Could not parse reference '{}', type '{}' not valid".format(step_ref_string, input_type))
|
||||
if input_type != 'id' and len(parts) < 3:
|
||||
raise ValueError("Could not parse reference '{}', missing fields in {}".format(step_ref_string, parts))
|
||||
raise ValueError("Could not parse reference '{}', missing fields in '{}'".format(step_ref_string, parts))
|
||||
|
||||
task = self._nodes[prev_step].job.task if self._nodes[prev_step].job \
|
||||
else Task.get_task(task_id=self._nodes[prev_step].executed)
|
||||
else Task.get_task(task_id=self._nodes[prev_step].executed or self._nodes[prev_step].base_task_id)
|
||||
task.reload()
|
||||
if input_type == 'artifacts':
|
||||
# fix \. to use . in artifacts
|
||||
@ -842,9 +989,14 @@ class PipelineController(object):
|
||||
'.'.join(parts[1:]), prev_step, parts[3]))
|
||||
|
||||
return str(getattr(model, parts[4]))
|
||||
|
||||
elif input_type == 'id':
|
||||
return task.id
|
||||
elif input_type in (
|
||||
'script', 'execution', 'container', 'output',
|
||||
'comment', 'models', 'tags', 'system_tags', 'project'):
|
||||
# noinspection PyProtectedMember
|
||||
return task._get_task_property('.'.join(parts[1:]))
|
||||
|
||||
return None
|
||||
|
||||
def _parse_step_ref(self, value):
|
||||
@ -864,6 +1016,19 @@ class PipelineController(object):
|
||||
updated_value = updated_value.replace(g, new_val, 1)
|
||||
return updated_value
|
||||
|
||||
def _parse_task_overrides(self, task_overrides):
|
||||
# type: (dict) -> dict
|
||||
"""
|
||||
Return the step reference. For example "${step1.parameters.Args/param}"
|
||||
:param task_overrides: string
|
||||
:return:
|
||||
"""
|
||||
updated_overrides = {}
|
||||
for k, v in task_overrides.items():
|
||||
updated_overrides[k] = self._parse_step_ref(v)
|
||||
|
||||
return updated_overrides
|
||||
|
||||
@classmethod
|
||||
def __get_node_status(cls, a_node):
|
||||
# type: (PipelineController.Node) -> str
|
||||
|
@ -9,7 +9,7 @@ from threading import Thread, Event
|
||||
from time import time
|
||||
from typing import List, Set, Union, Any, Sequence, Optional, Mapping, Callable
|
||||
|
||||
from .job import TrainsJob
|
||||
from .job import ClearmlJob
|
||||
from .parameters import Parameter
|
||||
from ..backend_interface.util import get_or_create_project
|
||||
from ..logger import Logger
|
||||
@ -58,7 +58,7 @@ class Objective(object):
|
||||
self.extremum = extremum
|
||||
|
||||
def get_objective(self, task_id):
|
||||
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
|
||||
# type: (Union[str, Task, ClearmlJob]) -> Optional[float]
|
||||
"""
|
||||
Return a specific task scalar value based on the objective settings (title/series).
|
||||
|
||||
@ -71,7 +71,7 @@ class Objective(object):
|
||||
|
||||
if isinstance(task_id, Task):
|
||||
task_id = task_id.id
|
||||
elif isinstance(task_id, TrainsJob):
|
||||
elif isinstance(task_id, ClearmlJob):
|
||||
task_id = task_id.task_id()
|
||||
|
||||
# noinspection PyBroadException, Py
|
||||
@ -97,7 +97,7 @@ class Objective(object):
|
||||
return None
|
||||
|
||||
def get_current_raw_objective(self, task):
|
||||
# type: (Union[TrainsJob, Task]) -> (int, float)
|
||||
# type: (Union[ClearmlJob, Task]) -> (int, float)
|
||||
"""
|
||||
Return the current raw value (without sign normalization) of the objective.
|
||||
|
||||
@ -108,7 +108,7 @@ class Objective(object):
|
||||
"""
|
||||
if isinstance(task, Task):
|
||||
task_id = task.id
|
||||
elif isinstance(task, TrainsJob):
|
||||
elif isinstance(task, ClearmlJob):
|
||||
task_id = task.task_id()
|
||||
else:
|
||||
task_id = task
|
||||
@ -162,7 +162,7 @@ class Objective(object):
|
||||
return self.title, self.series
|
||||
|
||||
def get_normalized_objective(self, task_id):
|
||||
# type: (Union[str, Task, TrainsJob]) -> Optional[float]
|
||||
# type: (Union[str, Task, ClearmlJob]) -> Optional[float]
|
||||
"""
|
||||
Return a normalized task scalar value based on the objective settings (title/series).
|
||||
I.e. objective is always to maximize the returned value
|
||||
@ -269,7 +269,7 @@ class SearchStrategy(object):
|
||||
The base search strategy class. Inherit this class to implement your custom strategy.
|
||||
"""
|
||||
_tag = 'optimization'
|
||||
_job_class = TrainsJob # type: TrainsJob
|
||||
_job_class = ClearmlJob # type: ClearmlJob
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -414,7 +414,7 @@ class SearchStrategy(object):
|
||||
return bool(self._current_jobs)
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
# type: () -> Optional[ClearmlJob]
|
||||
"""
|
||||
Abstract helper function. Implementation is not required. Default use in process_step default implementation
|
||||
Create a new job if needed. return the newly created job. If no job needs to be created, return ``None``.
|
||||
@ -424,7 +424,7 @@ class SearchStrategy(object):
|
||||
return None
|
||||
|
||||
def monitor_job(self, job):
|
||||
# type: (TrainsJob) -> bool
|
||||
# type: (ClearmlJob) -> bool
|
||||
"""
|
||||
Helper function, Implementation is not required. Default use in process_step default implementation.
|
||||
Check if the job needs to be aborted or already completed.
|
||||
@ -434,7 +434,7 @@ class SearchStrategy(object):
|
||||
If there is a budget limitation, this call should update
|
||||
``self.budget.compute_time.update`` / ``self.budget.iterations.update``
|
||||
|
||||
:param TrainsJob job: A ``TrainsJob`` object to monitor.
|
||||
:param ClearmlJob job: A ``TrainsJob`` object to monitor.
|
||||
|
||||
:return: False, if the job is no longer relevant.
|
||||
"""
|
||||
@ -472,7 +472,7 @@ class SearchStrategy(object):
|
||||
return abort_job
|
||||
|
||||
def get_running_jobs(self):
|
||||
# type: () -> Sequence[TrainsJob]
|
||||
# type: () -> Sequence[ClearmlJob]
|
||||
"""
|
||||
Return the current running TrainsJobs.
|
||||
|
||||
@ -534,7 +534,7 @@ class SearchStrategy(object):
|
||||
parent=None, # type: Optional[str]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> TrainsJob
|
||||
# type: (...) -> ClearmlJob
|
||||
"""
|
||||
Create a Job using the specified arguments, ``TrainsJob`` for details.
|
||||
|
||||
@ -564,11 +564,11 @@ class SearchStrategy(object):
|
||||
return new_job
|
||||
|
||||
def set_job_class(self, job_class):
|
||||
# type: (TrainsJob) -> ()
|
||||
# type: (ClearmlJob) -> ()
|
||||
"""
|
||||
Set the class to use for the :meth:`helper_create_job` function.
|
||||
|
||||
:param TrainsJob job_class: The Job Class type.
|
||||
:param ClearmlJob job_class: The Job Class type.
|
||||
"""
|
||||
self._job_class = job_class
|
||||
|
||||
@ -643,7 +643,7 @@ class SearchStrategy(object):
|
||||
return self._job_project.get(parent_task_id)
|
||||
|
||||
def _get_job_iterations(self, job):
|
||||
# type: (Union[TrainsJob, Task]) -> int
|
||||
# type: (Union[ClearmlJob, Task]) -> int
|
||||
iteration_value = self._objective_metric.get_current_raw_objective(job)
|
||||
return iteration_value[0] if iteration_value else -1
|
||||
|
||||
@ -788,7 +788,7 @@ class GridSearch(SearchStrategy):
|
||||
self._param_iterator = None
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
# type: () -> Optional[ClearmlJob]
|
||||
"""
|
||||
Create a new job if needed. Return the newly created job. If no job needs to be created, return ``None``.
|
||||
|
||||
@ -863,7 +863,7 @@ class RandomSearch(SearchStrategy):
|
||||
self._hyper_parameters_collection = set()
|
||||
|
||||
def create_job(self):
|
||||
# type: () -> Optional[TrainsJob]
|
||||
# type: () -> Optional[ClearmlJob]
|
||||
"""
|
||||
Create a new job if needed. Return the newly created job. If no job needs to be created, return ``None``.
|
||||
|
||||
@ -1299,11 +1299,11 @@ class HyperParameterOptimizer(object):
|
||||
return self.optimizer
|
||||
|
||||
def set_default_job_class(self, job_class):
|
||||
# type: (TrainsJob) -> ()
|
||||
# type: (ClearmlJob) -> ()
|
||||
"""
|
||||
Set the Job class to use when the optimizer spawns new Jobs.
|
||||
|
||||
:param TrainsJob job_class: The Job Class type.
|
||||
:param ClearmlJob job_class: The Job Class type.
|
||||
"""
|
||||
self.optimizer.set_job_class(job_class)
|
||||
|
||||
@ -1709,8 +1709,8 @@ class HyperParameterOptimizer(object):
|
||||
return completed_value, obj_values
|
||||
|
||||
def _get_last_value(self, response):
|
||||
metrics, title, series, values = TrainsJob.get_metric_req_params(self.objective_metric.title,
|
||||
self.objective_metric.series)
|
||||
metrics, title, series, values = ClearmlJob.get_metric_req_params(self.objective_metric.title,
|
||||
self.objective_metric.series)
|
||||
last_values = response.response_data["task"]['last_metrics'][title][series]
|
||||
return last_values
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user