Fix Pipeline from tasks does not propagate params overwrite

This commit is contained in:
allegroai 2023-02-28 17:03:04 +02:00
parent 08ce79002b
commit 6b32e1d33a

View File

@ -468,18 +468,12 @@ class PipelineController(object):
:return: True if successful
"""
# always store callback functions (even when running remotely)
if pre_execute_callback:
self._pre_step_callbacks[name] = pre_execute_callback
if post_execute_callback:
self._post_step_callbacks[name] = post_execute_callback
# when running remotely do nothing, we will deserialize ourselves when we start
# if we are not cloning a Task, we assume this step is created from code, not from the configuration
if not base_task_factory and clone_base_task and self._has_stored_configuration():
return True
self._verify_node_name(name)
if not base_task_factory and not base_task_id:
@ -2064,10 +2058,16 @@ class PipelineController(object):
visited.append(node.name)
idx = len(visited) - 1
parents = [visited.index(p) for p in node.parents or []]
if node.job and node.job.task_parameter_override is not None:
node.job.task_parameter_override.update(node.parameters or {})
node_params.append(
(node.job.task_parameter_override
if node.job and node.job.task_parameter_override
else node.parameters) or {})
(
node.job.task_parameter_override
if node.job and node.job.task_parameter_override
else node.parameters
)
or {}
)
# sankey_node['label'].append(node.name)
# sankey_node['customdata'].append(
# '<br />'.join('{}: {}'.format(k, v) for k, v in (node.parameters or {}).items()))
@ -2411,6 +2411,8 @@ class PipelineController(object):
self._launch_node, [self._nodes[name] for name in next_nodes])
for name, success in zip(next_nodes, node_launch_success):
if success and not self._nodes[name].skip_job:
if self._nodes[name].job and self._nodes[name].job.task_parameter_override is not None:
self._nodes[name].job.task_parameter_override.update(self._nodes[name].parameters or {})
print('Launching step: {}'.format(name))
print('Parameters:\n{}'.format(
self._nodes[name].job.task_parameter_override if self._nodes[name].job