Add support for recursive list, dict, and tuple ref parsing for pipeline controller.add step() parameter override (#1099)

* feat:  Added optional support for list, dicts, tuples in Pipeline parameter_overrides

* style: 🎨 Updated to pass the flake8 formatting guidelines

* docs: 📝 There was a small typo I noticed in the documentation. Two extra '
This commit is contained in:
natephysics 2023-09-09 21:17:49 +02:00 committed by GitHub
parent c922c40d13
commit d458924160
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -81,6 +81,7 @@ class PipelineController(object):
parents = attrib(type=list, default=None) # list of parent DAG steps
timeout = attrib(type=float, default=None) # execution timeout limit
parameters = attrib(type=dict, default=None) # Task hyper-parameters to change
recursively_parse_parameters = attrib(type=bool, default=False) # if True, recursively parse parameters in lists, dicts, or tuples
configurations = attrib(type=dict, default=None) # Task configuration objects to change
task_overrides = attrib(type=dict, default=None) # Task overrides to change
executed = attrib(type=str, default=None) # The actual executed Task ID (None if not executed yet)
@ -368,6 +369,7 @@ class PipelineController(object):
base_task_id=None, # type: Optional[str]
parents=None, # type: Optional[Sequence[str]]
parameter_override=None, # type: Optional[Mapping[str, Any]]
recursively_parse_parameters=False, # type: bool
configuration_overrides=None, # type: Optional[Mapping[str, Union[str, Mapping]]]
task_overrides=None, # type: Optional[Mapping[str, Any]]
execution_queue=None, # type: Optional[str]
@ -405,7 +407,10 @@ class PipelineController(object):
- Parameter access ``parameter_override={'Args/input_file': '${<step_name>.parameters.Args/input_file}' }``
- Pipeline Task argument (see `Pipeline.add_parameter`) ``parameter_override={'Args/input_file': '${pipeline.<pipeline_parameter>}' }``
- Task ID ``parameter_override={'Args/input_file': '${stage3.id}' }``
:param recursively_parse_parameters: If True, recursively parse parameters from parameter_override in lists, dicts, or tuples.
Example:
- ``parameter_override={'Args/input_file': ['${<step_name>.artifacts.<artifact_name>.url}', 'file2.txt']}`` will be correctly parsed.
- ``parameter_override={'Args/input_file': ('${<step_name_1>.parameters.Args/input_file}', '${<step_name_2>.parameters.Args/input_file}')}`` will be correctly parsed.
:param configuration_overrides: Optional, override Task configuration objects.
Expected dictionary of configuration object name and configuration object content.
Examples:
@ -572,6 +577,7 @@ class PipelineController(object):
name=name, base_task_id=base_task_id, parents=parents or [],
queue=execution_queue, timeout=time_limit,
parameters=parameter_override or {},
recursively_parse_parameters=recursively_parse_parameters,
configurations=configuration_overrides,
clone_task=clone_base_task,
task_overrides=task_overrides,
@ -2237,7 +2243,7 @@ class PipelineController(object):
updated_hyper_parameters = {}
for k, v in node.parameters.items():
updated_hyper_parameters[k] = self._parse_step_ref(v)
updated_hyper_parameters[k] = self._parse_step_ref(v, recursive=node.recursively_parse_parameters)
task_overrides = self._parse_task_overrides(node.task_overrides) if node.task_overrides else None
@ -2776,11 +2782,12 @@ class PipelineController(object):
except Exception:
pass
def _parse_step_ref(self, value):
def _parse_step_ref(self, value, recursive=False):
# type: (Any) -> Optional[str]
"""
Return the step reference. For example "${step1.parameters.Args/param}"
:param value: string
:param recursive: if True, recursively parse all values in the dict, list or tuple
:return:
"""
# look for all the step references
@ -2793,6 +2800,18 @@ class PipelineController(object):
if not isinstance(new_val, six.string_types):
return new_val
updated_value = updated_value.replace(g, new_val, 1)
# if we have a dict, list or tuple, we need to recursively update the values
if recursive:
if isinstance(value, dict):
updated_value = {}
for k, v in value.items():
updated_value[k] = self._parse_step_ref(v, recursive=True)
elif isinstance(value, list):
updated_value = [self._parse_step_ref(v, recursive=True) for v in value]
elif isinstance(value, tuple):
updated_value = tuple(self._parse_step_ref(v, recursive=True) for v in value)
return updated_value
def _parse_task_overrides(self, task_overrides):