Fix Multi-pipeline support

allegroai 2022-01-10 00:00:38 +02:00
parent 2c214e9848
commit 51c8c84bc4


@@ -20,9 +20,11 @@ from .. import Logger
 from ..automation import ClearmlJob
 from ..backend_interface.task.populate import CreateFromFunction
 from ..backend_interface.util import get_or_create_project, exact_match_regex
+from ..config import get_remote_task_id
 from ..debugging.log import LoggerRoot
 from ..model import BaseModel, OutputModel
 from ..task import Task
+from ..utilities.process.mp import leave_process
 from ..utilities.proxy_object import LazyEvalWrapper, flatten_dictionary
@@ -131,6 +133,7 @@ class PipelineController(object):
         self._thread = None
         self._pipeline_args = dict()
         self._pipeline_args_desc = dict()
+        self._pipeline_args_type = dict()
         self._stop_event = None
         self._experiment_created_cb = None
         self._experiment_completed_cb = None
@@ -980,8 +983,8 @@ class PipelineController(object):
         # also trigger node monitor scanning
         self._scan_monitored_nodes()

-    def add_parameter(self, name, default=None, description=None):
-        # type: (str, Optional[Any], Optional[str]) -> None
+    def add_parameter(self, name, default=None, description=None, param_type=None):
+        # type: (str, Optional[Any], Optional[str], Optional[str]) -> None
         """
         Add a parameter to the pipeline Task.
         The parameter can be used as input parameter for any step in the pipeline.
@@ -992,10 +995,13 @@ class PipelineController(object):
         :param name: String name of the parameter.
         :param default: Default value to be put as the default value (can be later changed in the UI)
         :param description: String description of the parameter and its usage in the pipeline
+        :param param_type: Optional, parameter type information (to be used as hint for casting and description)
         """
-        self._pipeline_args[str(name)] = str(default or '')
+        self._pipeline_args[str(name)] = default
         if description:
             self._pipeline_args_desc[str(name)] = str(description)
+        if param_type:
+            self._pipeline_args_type[str(name)] = param_type

     def get_parameters(self):
         # type: () -> dict
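
As a usage note (not part of the commit), the new param_type argument would be passed along the lines below; the pipeline and parameter names are made up for illustration:

    from clearml import PipelineController

    pipe = PipelineController(name="example-pipeline", project="examples", version="1.0")
    # param_type is a free-form hint (e.g. "float") used for casting and description
    pipe.add_parameter(
        name="learning_rate",
        default=0.001,
        description="optimizer learning rate",
        param_type="float",
    )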
@@ -1148,6 +1154,7 @@ class PipelineController(object):
             self._task._set_parameters(
                 {'{}/{}'.format(self._args_section, k): v for k, v in params.items()},
                 __parameters_descriptions=self._pipeline_args_desc,
+                __parameters_types=self._pipeline_args_type,
                 __update=True,
             )
             params['_continue_pipeline_'] = False
@@ -2197,6 +2204,7 @@ class PipelineDecorator(PipelineController):
     _debug_execute_step_function = False
     _default_execution_queue = None
     _multi_pipeline_instances = []
+    _multi_pipeline_call_counter = -1
     _atexit_registered = False

     def __init__(
@@ -2877,6 +2885,7 @@ class PipelineDecorator(PipelineController):
         def internal_decorator(*args, **kwargs):
             pipeline_kwargs = dict(**(kwargs or {}))
+            pipeline_kwargs_types = dict()
             inspect_func = inspect.getfullargspec(func)
             if args:
                 if not inspect_func.args:
@@ -2892,12 +2901,39 @@ class PipelineDecorator(PipelineController):
                 default_kwargs.update(pipeline_kwargs)
                 pipeline_kwargs = default_kwargs

+            if inspect_func.annotations:
+                pipeline_kwargs_types = {
+                    str(k): inspect_func.annotations[k] for k in inspect_func.annotations}
+
             # run the entire pipeline locally, as python functions
             if cls._debug_execute_step_function:
                 ret_val = func(**pipeline_kwargs)
                 LazyEvalWrapper.trigger_all_remote_references()
                 return ret_val

+            # check if we are in a multi pipeline
+            force_single_multi_pipeline_call = False
+            if multi_instance_support and cls._multi_pipeline_call_counter >= 0:
+                # check if we are running remotely
+                if not Task.running_locally():
+                    # get the main Task property
+                    t = Task.get_task(task_id=get_remote_task_id())
+                    if str(t.task_type) == str(Task.TaskTypes.controller):
+                        # noinspection PyBroadException
+                        try:
+                            # noinspection PyProtectedMember
+                            multi_pipeline_call_counter = int(
+                                t._get_runtime_properties().get('multi_pipeline_counter', None))
+                            # NOTICE! if this is not our call we LEAVE immediately
+                            # check if this is our call to start, if not we will wait for the next one
+                            if multi_pipeline_call_counter != cls._multi_pipeline_call_counter:
+                                return
+                        except Exception:
+                            # this is not the one, so we should just run the first
+                            # instance and leave immediately
+                            force_single_multi_pipeline_call = True
+
             if default_queue:
                 cls.set_default_execution_queue(default_queue)
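
A simplified, self-contained sketch of the gating idea introduced above (this is not the ClearML implementation): every decorated pipeline call gets an incrementing index, the controller task records the index it was created for, and a remote re-run lets only the matching call proceed:

    class _GateSketch:
        call_counter = -1  # incremented once per decorated pipeline call

        @classmethod
        def next_call_should_run(cls, stored_counter):
            cls.call_counter += 1
            # only the call whose index matches the one recorded on the
            # controller task should actually execute; the rest return early
            return stored_counter == cls.call_counter

    # e.g. if the controller task recorded index 1, only the second call executes
    print([_GateSketch.next_call_should_run(1) for _ in range(3)])  # [False, True, False]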
@@ -2918,13 +2954,23 @@ class PipelineDecorator(PipelineController):
                 a_pipeline._clearml_job_class.register_hashing_callback(a_pipeline._adjust_task_hashing)

             # add pipeline arguments
-            if pipeline_kwargs:
-                a_pipeline.get_parameters().update(pipeline_kwargs)
+            for k in pipeline_kwargs:
+                a_pipeline.add_parameter(
+                    name=k,
+                    default=pipeline_kwargs.get(k),
+                    param_type=pipeline_kwargs_types.get(k)
+                )
+
+            # sync multi-pipeline call counter (so we know which one to skip)
+            if Task.running_locally() and multi_instance_support and cls._multi_pipeline_call_counter >= 0:
+                # noinspection PyProtectedMember
+                a_pipeline._task._set_runtime_properties(
+                    dict(multi_pipeline_counter=str(cls._multi_pipeline_call_counter)))

             # serialize / deserialize state only if we are running locally
             a_pipeline._start(wait=False)

-            # sync arguments back
+            # sync arguments back (post deserialization and casting back)
             for k in pipeline_kwargs.keys():
                 if k in a_pipeline.get_parameters():
                     pipeline_kwargs[k] = a_pipeline.get_parameters()[k]
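
To illustrate what the add_parameter loop above enables (a hedged sketch with made-up names): type annotations on the decorated pipeline function are collected via inspect.getfullargspec and stored per parameter, so values that come back as strings from the UI or backend can be cast to their original types:

    from clearml import PipelineDecorator

    @PipelineDecorator.pipeline(name="typed-pipeline", project="examples", version="0.1")
    def my_pipeline(epochs: int = 10, learning_rate: float = 0.001):
        # epochs/learning_rate are registered as pipeline parameters together
        # with their int/float annotations
        print(epochs, learning_rate)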
@@ -2968,6 +3014,12 @@ class PipelineDecorator(PipelineController):
             # now we can raise the exception
             if triggered_exception:
                 raise triggered_exception

+            # Make sure that if we do not need to run all pipelines we forcefully leave the process
+            if force_single_multi_pipeline_call:
+                leave_process()
+                # we will never get here
+
             return pipeline_result

         if multi_instance_support:
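
Assumption for illustration only: leave_process() is treated here as a hard process exit, so a remote run that was told to execute just one of the pipeline calls does not fall through and start the remaining ones. A minimal stand-in (not the ClearML helper itself) would look like:

    import os

    def leave_process_sketch(status=0):
        # os._exit terminates the interpreter immediately, skipping any
        # remaining pipeline calls in the script
        os._exit(status)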
@@ -3045,6 +3097,8 @@ class PipelineDecorator(PipelineController):
         """

         def internal_decorator(*args, **kwargs):
+            cls._multi_pipeline_call_counter += 1
+
             # if this is a debug run just call the function (no parallelization).
             if cls._debug_execute_step_function:
                 return func(*args, **kwargs)
@@ -3069,7 +3123,7 @@ class PipelineDecorator(PipelineController):
             # make sure we wait for the subprocess.
             p.daemon = False
             p.start()
-            if parallel:
+            if parallel and Task.running_locally():
                 cls._multi_pipeline_instances.append((p, queue))
                 return
             else:
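
Finally, a hedged end-to-end sketch of the behavior this commit fixes: launching the same decorated pipeline several times from one script. Function, project, and argument names below are illustrative, not taken from the commit:

    from clearml import PipelineDecorator

    @PipelineDecorator.pipeline(
        name="train-pipeline", project="examples", version="0.1",
        multi_instance_support=True)
    def train_pipeline(dataset: str = "cifar10", epochs: int = 5):
        print("training on", dataset, "for", epochs, "epochs")

    if __name__ == "__main__":
        # each call runs as its own pipeline controller (in a subprocess when
        # parallel); the per-call counter recorded on the controller task lets a
        # remote re-execution pick out exactly the instance it belongs to
        train_pipeline(dataset="cifar10")
        train_pipeline(dataset="mnist")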