Add PipelineController always_create_from_code=True (False preserves the previous behavior, where the pipeline is always deserialized from the backend when running remotely; the new flow always creates the pipeline from code)

Fix pipeline decorator not reading the pipeline arguments back from the backend when running remotely
allegroai 2023-08-04 19:54:55 +03:00
parent d26ce48dbe
commit c15f012e1b

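For context, a minimal usage sketch of the new flag from the caller's side (not part of this commit; the import path follows the public ClearML examples, and the name/project/version values are illustrative assumptions):

from clearml.automation import PipelineController

# New default flow: the DAG is always rebuilt from this code,
# even when the pipeline Task is executed remotely.
pipe = PipelineController(
    name="pipeline-from-code",
    project="examples",
    version="1.0.0",
    always_create_from_code=True,
)

# Previous behavior: when running remotely, restore the DAG from the
# pipeline configuration section stored on the backend Task, so steps
# can be edited (added/removed) without changing the original codebase.
pipe_editable = PipelineController(
    name="pipeline-from-backend",
    project="examples",
    version="1.0.0",
    always_create_from_code=False,
)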

@@ -151,6 +151,7 @@ class PipelineController(object):
repo=None, # type: Optional[str]
repo_branch=None, # type: Optional[str]
repo_commit=None, # type: Optional[str]
always_create_from_code=True, # type: bool
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
):
@@ -215,6 +216,9 @@
Use empty string ("") to disable any repository auto-detection
:param repo_branch: Optional, specify the remote repository branch (Ignored, if local repo path is used)
:param repo_commit: Optional, specify the repository commit ID (Ignored, if local repo path is used)
:param always_create_from_code: If True (default), the pipeline is always constructed from code;
if False, the pipeline is generated from the pipeline configuration section on the pipeline Task itself.
This allows editing (including adding/removing) pipeline steps without changing the original codebase
:param artifact_serialization_function: A serialization function that takes one
parameter of any type which is the object to be serialized. The function should return
a `bytes` or `bytearray` object, which represents the serialized object. All parameter/return
@@ -244,6 +248,7 @@
self._start_time = None
self._pipeline_time_limit = None
self._default_execution_queue = None
self._always_create_from_code = bool(always_create_from_code)
self._version = str(version).strip() if version else None
if self._version and not Version.is_valid_version_string(self._version):
raise ValueError(
@@ -1413,7 +1418,7 @@ class PipelineController(object):
pipeline_object._nodes = {}
pipeline_object._running_nodes = []
try:
pipeline_object._deserialize(pipeline_task._get_configuration_dict(cls._config_section))
pipeline_object._deserialize(pipeline_task._get_configuration_dict(cls._config_section), force=True)
except Exception:
pass
return pipeline_object
@@ -1715,13 +1720,16 @@ class PipelineController(object):
return dag
def _deserialize(self, dag_dict):
# type: (dict) -> ()
def _deserialize(self, dag_dict, force=False):
# type: (dict, bool) -> ()
"""
Restore the DAG from a dictionary.
This will be used to create the DAG from the dict stored on the Task, when running remotely.
:return:
"""
# if we always want to load the pipeline DAG from code, skip the deserialization step
if not force and self._always_create_from_code:
return
# if we do not clone the Task, only merge the parts we can override.
for name in list(self._nodes.keys()):
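To make the interplay explicit, a simplified standalone sketch (illustrative class, not the real implementation): get() above reconstructs a controller from an existing Task, so it passes force=True to bypass the skip, while the normal remote-execution path honors the flag.

class ControllerSketch:
    def __init__(self, always_create_from_code=True):
        self._always_create_from_code = bool(always_create_from_code)
        self._nodes = {}

    def _deserialize(self, dag_dict, force=False):
        # normal remote run: keep the DAG built from code and ignore
        # the configuration section stored on the Task
        if not force and self._always_create_from_code:
            return
        # forced (e.g. reconstructing from an existing Task): restore nodes
        self._nodes = dict(dag_dict)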
@@ -3329,6 +3337,7 @@ class PipelineDecorator(PipelineController):
repo=repo,
repo_branch=repo_branch,
repo_commit=repo_commit,
always_create_from_code=False,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
)
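The decorator flow deliberately keeps the previous behavior: when running remotely, the DAG and arguments are restored from the backend so overrides are honored. A hedged sketch of a decorated pipeline (imports and decorator signatures follow the public ClearML examples; names and values are illustrative):

from clearml import PipelineDecorator

@PipelineDecorator.component(return_values=["doubled"])
def double(x):
    # executed as an independent pipeline step
    return 2 * x

@PipelineDecorator.pipeline(name="decorated-pipeline", project="examples", version="1.0.0")
def run_pipeline(x=1):
    # when executed remotely, x is read back from the backend
    # (the argument sync-back fixed below)
    return double(x)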
@@ -4310,11 +4319,6 @@ class PipelineDecorator(PipelineController):
a_pipeline._task._set_runtime_properties(
dict(multi_pipeline_counter=str(cls._multi_pipeline_call_counter)))
# sync arguments back (post deserialization and casting back)
for k in pipeline_kwargs.keys():
if k in a_pipeline.get_parameters():
pipeline_kwargs[k] = a_pipeline.get_parameters()[k]
# run the actual pipeline
if not start_controller_locally and \
not PipelineDecorator._debug_execute_step_process and pipeline_execution_queue:
@@ -4322,8 +4326,14 @@ class PipelineDecorator(PipelineController):
a_pipeline._task.execute_remotely(queue_name=pipeline_execution_queue)
# when we get here it means we are running remotely
# this will also deserialize the pipeline and arguments
a_pipeline._start(wait=False)
# sync arguments back (post deserialization and casting back)
for k in pipeline_kwargs.keys():
if k in a_pipeline.get_parameters():
pipeline_kwargs[k] = a_pipeline.get_parameters()[k]
# this time the pipeline is executed only on the remote machine
try:
pipeline_result = func(**pipeline_kwargs)
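In short, the sync-back had to move below _start(). A simplified restatement of the corrected ordering (condensed from the diff above, not verbatim source):

# before this commit the kwargs were synced before _start(), so values
# overridden on the backend never reached func(**pipeline_kwargs)
a_pipeline._start(wait=False)            # deserializes the pipeline and its arguments
params = a_pipeline.get_parameters()     # now reflects backend values
for k in pipeline_kwargs:
    if k in params:
        pipeline_kwargs[k] = params[k]   # cast-back values reach the user function
pipeline_result = func(**pipeline_kwargs)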