Fix PipelineDecorator.component() ignores *args and crashes with **kwargs

This commit is contained in:
clearml 2025-04-18 16:03:56 +03:00
parent 88a828cd78
commit fa119672d8
2 changed files with 21 additions and 2 deletions

View File

@ -4691,8 +4691,13 @@ class PipelineDecorator(PipelineController):
# resolve all lazy objects if we have any:
kwargs_artifacts = {}
star_args_index = 0
for i, v in enumerate(args):
kwargs[inspect_func.args[i]] = v
if not inspect_func.args or i >= len(inspect_func.args):
kwargs[str(star_args_index)] = v
star_args_index += 1
else:
kwargs[inspect_func.args[i]] = v
# We need to remember when a pipeline step's return value is evaluated by the pipeline
# controller, but not when it's done here (as we would remember the step every time).

View File

@ -680,6 +680,7 @@ class CreateFromFunction(object):
default_task_template_header = """from clearml import Task
from clearml import TaskTypes
from clearml.automation.controller import PipelineDecorator
import inspect
"""
task_template = """{header}
from clearml.utilities.proxy_object import get_basic_type
@ -699,6 +700,11 @@ if __name__ == '__main__':
task.connect(kwargs, name='{kwargs_section}')
function_input_artifacts = {function_input_artifacts}
params = task.get_parameters() or dict()
argspec = inspect.getfullargspec({function_name})
if argspec.varkw is not None or argspec.varargs is not None:
for k, v in params.items():
if k.startswith('{kwargs_section}/'):
kwargs[k.replace('{kwargs_section}/', '', 1)] = v
return_section = '{return_section}'
for k, v in params.items():
if not v or not k.startswith('{input_artifact_section}/'):
@ -710,7 +716,15 @@ if __name__ == '__main__':
kwargs[k] = parent_task.artifacts[artifact_name].get(deserialization_function={artifact_deserialization_function_name})
else:
kwargs[k] = parent_task.get_parameters(cast=True).get(return_section + '/' + artifact_name)
results = {function_name}(**kwargs)
if '0' in kwargs: # *args arguments are present
pos_args = [kwargs.pop(arg, None) for arg in (argspec.args or [])]
other_pos_args_index = 0
while str(other_pos_args_index) in kwargs:
pos_args.append(kwargs.pop(str(other_pos_args_index)))
other_pos_args_index += 1
results = {function_name}(*pos_args, **kwargs)
else:
results = {function_name}(**kwargs)
result_names = {function_return}
if result_names:
if not isinstance(results, (tuple, list)) or len(result_names) == 1: